Wishart分布の可視化
ベイジアンモデリングで多次元正規分布の共分散行列の事前分布を設定することがある。
その際、よく取り上げられるのがWishart分布である。Wishart分布は多次元正規分布の精度行列(分散共分散行列の逆行列)に対する共役事前分布になることが知られている。分散共分散行列の事前分布としては、InverseWishart分布を用いれば良い。
WIshart分布の確率密度関数は下記の通り。パラメータは自由度であり、である必要がある。 また、パラメータは正定値対称行列であり、scipyの実装ではscale matrixと呼ばれている。
Wishart分布の期待値は下記の通り。
(Inverse)Wishart分布は行列を生成する確率分布なので、直感的にイメージしづらい。 そこで、Wishart分布から得られたサンプルを精度行列とする2次元正規分布について、95%信頼区間を図示して可視化を試みる。
50個の精度行列をサンプリングして、それぞれについて95%信頼区間を赤線で示している。 また、青線は精度行列の期待値を用いた場合の95%信頼区間である。
まずはを単位行列として、自由度を変化させた場合。
次に、自由度を固定してを定数倍して行った場合。
を非対角行列とすることで、相関がある共分散行列を得られやすくすることもできる。
ただし、PyMC3やStanでは共分散行列の事前分布としては、LKJ分布の利用が勧められている。
PyMC3のissueにも挙げられている通り、
- 正定値対称行列の制約があるため、全ての成分を少しずつ動かしたサンプルが有効となる確率はほぼ0となる。最初に制約のない空間に変換した上でサンプリングし、制約のある空間に戻した方が好ましい。
- Wishart分布はかなり裾が重い分布であり、サンプリング効率が悪くなってしまう。LKJ分布はWishart分布ほど裾が重くない。
- LKJ分布の方がWishart分布よりも直感的に解釈しやすい。
とのことらしい。 LKJ分布についてはまた記事を書く。
プロット出力に用いたPythonコードは下記の通り。
from matplotlib.patches import Ellipse from matplotlib import pyplot as plt import seaborn as sns import numpy as np from scipy import stats # draw 95% confidence interval of multivariate normal distribution def draw_ellipse(cov, ax, **kwargs): var, U = np.linalg.eig(cov) if U[1, 0]: angle = 180. / np.pi * np.arctan(U[0, 0]/ U[1, 0]) else: angle = 0 e = Ellipse(np.zeros(2), 2 * np.sqrt(5.991 * var[0]), 2 * np.sqrt(5.991 * var[1]), angle=angle, facecolor='none', **kwargs) ax.add_artist(e) ax.set_xlim(-8, 8) ax.set_ylim(-8, 8) # visualize samples from Wishart distribution def visualize_Wishart(nu, sigma, n_sample, ax): delta_samples = stats.wishart(df=nu, scale=sigma).rvs(n_sample) for delta in delta_samples: cov = np.linalg.inv(delta) draw_ellipse(cov, ax, edgecolor='r', linewidth=0.5, alpha=0.4) # 期待値 draw_ellipse(np.linalg.inv(nu * sigma), ax, edgecolor='b', linewidth=2.0, linestyle='--', alpha=0.7) ax.set_title(f'df={nu}, scale={sigma.round(2)}') ## 様々なパラメータで可視化 # nuを変化させる nu_list = [2, 3, 4] sigma = np.eye(2) fig, axes = plt.subplots(1, 3, figsize=(12, 4)) np.random.seed(2) for i, nu in enumerate(nu_list): visualize_Wishart(nu, sigma, 50, axes[i]) plt.tight_layout() plt.savefig('Wishart_sample_nu.png') # sigmaを変化させる scale_list = [2, 3, 4] sigma = np.eye(2) nu = 3 fig, axes = plt.subplots(1, 3, figsize=(12, 4)) np.random.seed(2) for i, scale in enumerate(scale_list): visualize_Wishart(nu, scale * sigma, 50, axes[i]) plt.tight_layout() plt.savefig('Wishart_sample_sigma.png') # 相関係数0.8 nu_list = [2, 3, 4] sigma = np.linalg.inv(np.array([[1.0, 0.8], [0.8, 1.0]])) fig, axes = plt.subplots(1, 3, figsize=(12, 4)) np.random.seed(2) for i, nu in enumerate(nu_list): visualize_Wishart(nu, sigma, 50, axes[i]) plt.tight_layout() plt.savefig('Wishart_sample_nu_corr0.8.png') # 相関係数-0.4 scale_list = [2, 3, 4] nu = 3 sigma = np.linalg.inv(np.array([[1.0, -0.4], [-0.4, 1.0]])) fig, axes = plt.subplots(1, 3, figsize=(12, 4)) np.random.seed(2) for i, scale in enumerate(scale_list): visualize_Wishart(nu, scale * sigma, 50, axes[i]) plt.tight_layout() plt.savefig('Wishart_sample_sigma_corr-0.4.png')