본문 바로가기

통계

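2D symmetric KL divergence

$D_{\mathrm{sym}}(q, p) = \tfrac{1}{2}\left[\mathrm{KL}(q \,\|\, p) + \mathrm{KL}(p \,\|\, q)\right]$, where $\mathrm{KL}(q \,\|\, p) = \iint q(x, y) \log \frac{q(x, y)}{p(x, y)} \, dx \, dy$.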

Implementation

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def normal2d(mu: np.ndarray, sigma):
    """Return the density function of an isotropic 2-D normal distribution.

    The distribution is N(mu, sigma**2 * I), where I is the 2x2 identity.

    Args:
        mu: mean vector with two entries, given as a column vector of shape
            (2, 1) or as a flat vector of shape (2,).
        sigma: standard deviation shared by both coordinates.

    Returns:
        A callable ``f(x, y)`` returning the scalar density at the point (x, y).

    Raises:
        ValueError: if ``mu`` does not contain exactly two entries.
        numpy.linalg.LinAlgError: if ``sigma`` is 0 (singular covariance).
    """
    # Normalise the mean to a column vector. A flat (2,) mean would otherwise
    # broadcast against the (2, 1) point into a (2, 2) array.
    mu = np.asarray(mu, dtype=float).reshape((2, 1))
    # Covariance, its inverse and the normalising constant do not depend on
    # the evaluation point, so compute them once instead of on every call.
    V = sigma**2 * np.eye(2)
    V_inv = np.linalg.inv(V)
    mul = np.linalg.det(2 * np.pi * V) ** (-0.5)

    def normal2d_(x, y):
        point = np.array([x, y], dtype=float).reshape((2, 1))
        diff = point - mu
        px = -0.5 * diff.T @ V_inv @ diff  # (1, 1) quadratic form
        return mul * np.exp(px[0, 0])

    return normal2d_

def calculate_symmetric_KL(x, y, q, p):
    """Approximate the symmetric KL divergence of two 2-D densities on a grid.

    Computes 0.5 * (KL(q || p) + KL(p || q)) with a Riemann sum over the grid
    points ``(x[i], y[j])``. Each point is weighted by the area of its grid
    cell, so the sum approximates the integral ``∫∫ q log(q/p) dx dy`` (and
    its mirror). Without that weight, the sum of density values would grow
    with grid resolution.

    Args:
        x: 1-D sequence of at least two x grid coordinates.
        y: 1-D sequence of at least two y grid coordinates.
        q: callable ``q(x, y)`` returning a scalar density value.
        p: callable ``p(x, y)`` returning a scalar density value.

    Both densities are assumed strictly positive on the grid. A zero value
    makes the log term -inf/nan.

    Returns:
        The approximate symmetric KL divergence. Mass outside the grid is
        ignored, so a grid that does not cover the tails of both densities
        underestimates the true value.

    Raises:
        ValueError: if ``x`` or ``y`` is not 1-D with at least two points
            (the cell size would be undefined).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.ndim != 1 or y.ndim != 1 or x.size < 2 or y.size < 2:
        raise ValueError("x and y must be 1-D with at least two grid points each")
    # Per-point cell widths. np.gradient gives the midpoint spacing for
    # interior points and is constant on a uniform grid such as np.linspace.
    dx = np.gradient(x)
    dy = np.gradient(y)
    left = 0.0
    right = 0.0
    for i, wx in zip(x, dx):
        for j, wy in zip(y, dy):
            # Evaluate each density once per point.
            qv = q(i, j)
            pv = p(i, j)
            area = wx * wy
            left += qv * np.log(qv / pv) * area
            right += pv * np.log(pv / qv) * area
    return 0.5 * (left + right)
    
# Two isotropic Gaussians: q ~ N((1, 2), 1^2 I) and p ~ N((3, 0), 1.7^2 I).
q = normal2d(np.array([[1.], [2]]), 1)
p = normal2d(np.array([[3.], [0]]), 1.7)

# Evaluation grid, shared by the KL estimate and the plots.
x = np.linspace(-3, 7, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
# Element [row, col] holds the density at (x[col], y[row]), matching meshgrid's layout.
Z_q = np.vectorize(q)(X, Y)
Z_p = np.vectorize(p)(X, Y)

print('Symmetric KL-divergence value: ', calculate_symmetric_KL(x, y, q, p))

# Plot: filled contour maps of q (left) and p (right), each with its own colour bar.
plt.figure(figsize=(10, 6))
_panels = [(Z_q, 'Distribution q'), (Z_p, 'Distribution p')]
for _pos, (_Z, _title) in enumerate(_panels, start=1):
    plt.subplot(1, 2, _pos)
    plt.contourf(X, Y, _Z, cmap='viridis')
    plt.colorbar(label='Function value')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.title(_title)

plt.tight_layout()
plt.show()


fig = plt.figure(figsize=(12, 6))

# One 3-D surface per distribution: q in the left slot (121), p in the right (122).
ax1 = fig.add_subplot(121, projection='3d')
ax2 = fig.add_subplot(122, projection='3d')
for _ax, _Z, _title in ((ax1, Z_q, 'Distribution q'), (ax2, Z_p, 'Distribution p')):
    _ax.plot_surface(X, Y, _Z, cmap='viridis')
    _ax.set_xlabel('X')
    _ax.set_ylabel('Y')
    _ax.set_zlabel('Function value')
    _ax.set_title(_title)

plt.tight_layout()
plt.show()

Results

 

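Symmetric KL-divergence value:  313.8467848853865

Note: this printed number is the raw sum of the integrand over the 100×100 grid points. It omits the grid cell area $\Delta x\,\Delta y = (10/99)^2$. Multiplied by that area it becomes ≈ 3.20. That is close to the closed-form value ≈ 3.31 for these two Gaussians; the remaining gap comes from probability mass outside the plotted grid.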

 

 

'통계' 카테고리의 다른 글

수열 수렴 판정법  (0) 2024.06.26
LASSO, Ridge regression  (0) 2024.06.26
Rightarrow vs mapsto  (1) 2024.01.10
Metric Space  (0) 2024.01.10
[정보이론] Entropy  (0) 2023.07.11