Abnormally Distributed

統計解析担当のライフサイエンス研究者 -> データサイエンティスト@コンサル

Wishart分布の可視化

ベイジアンモデリングで多次元正規分布の共分散行列の事前分布を設定することがある。

その際、よく取り上げられるのがWishart分布である。Wishart分布は多次元正規分布の精度行列(分散共分散行列の逆行列)に対する共役事前分布になることが知られている。分散共分散行列の事前分布としては、InverseWishart分布を用いれば良い。

WIshart分布の確率密度関数は下記の通り。パラメータ\nuは自由度であり、\nu > D-1である必要がある。 また、パラメータ\Sigmaは正定値対称行列であり、scipyの実装ではscale matrixと呼ばれている。

 { \displaystyle \text{Wishart}(S|\nu,\Sigma) =C_w(\nu, \Sigma) \left|
S \right|^{(\nu - K - 1)/2} \ \exp \left(- \frac{1}{2} \
\text{tr}\left( \Sigma^{-1} S \right) \right)}

Wishart分布の期待値は下記の通り。

E(S)=\nu\Sigma

(Inverse)Wishart分布は行列を生成する確率分布なので、直感的にイメージしづらい。 そこで、Wishart分布から得られたサンプルを精度行列とする2次元正規分布について、95%信頼区間を図示して可視化を試みる。

50個の精度行列をサンプリングして、それぞれについて95%信頼区間を赤線で示している。 また、青線は精度行列の期待値を用いた場合の95%信頼区間である。  

まずは\Sigma単位行列として、自由度を変化させた場合。

f:id:kibya:20190816160824p:plain
自由度nuを変化させた場合のWishart分布
 

次に、自由度を固定して\Sigmaを定数倍して行った場合。

 

f:id:kibya:20190816161042p:plain
scale matrix (sigma)を変化させた際のWishart分布

\Sigmaを非対角行列とすることで、相関がある共分散行列を得られやすくすることもできる。

f:id:kibya:20190816161436p:plain
期待値が相関係数が0.8の共分散行列となるWishart分布

f:id:kibya:20190816161221p:plain
期待値が相関係数が-0.4の共分散行列となるWishart分布

ただし、PyMC3やStanでは共分散行列の事前分布としては、LKJ分布の利用が勧められている。

PyMC3のissueにも挙げられている通り、

  1. 正定値対称行列の制約があるため、全ての成分を少しずつ動かしたサンプルが有効となる確率はほぼ0となる。最初に制約のない空間に変換した上でサンプリングし、制約のある空間に戻した方が好ましい。
  2. Wishart分布はかなり裾が重い分布であり、サンプリング効率が悪くなってしまう。LKJ分布はWishart分布ほど裾が重くない。
  3. LKJ分布の方がWishart分布よりも直感的に解釈しやすい。

とのことらしい。 LKJ分布についてはまた記事を書く。

プロット出力に用いたPythonコードは下記の通り。

from matplotlib.patches import Ellipse
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
from scipy import stats


# draw 95% confidence interval of multivariate normal distribution
def draw_ellipse(cov, ax, **kwargs):
    var, U = np.linalg.eig(cov)
    if U[1, 0]:
        angle = 180. / np.pi * np.arctan(U[0, 0]/ U[1, 0])
    else:
        angle = 0

    e = Ellipse(np.zeros(2), 2 * np.sqrt(5.991 * var[0]), 2 * np.sqrt(5.991 * var[1]), 
                angle=angle, facecolor='none', **kwargs)

    ax.add_artist(e)
    ax.set_xlim(-8, 8)
    ax.set_ylim(-8, 8)

        
# visualize samples from Wishart distribution
def visualize_Wishart(nu, sigma, n_sample, ax):
    delta_samples = stats.wishart(df=nu, scale=sigma).rvs(n_sample)
    
    for delta in delta_samples:
        cov = np.linalg.inv(delta)
        draw_ellipse(cov, ax, edgecolor='r', linewidth=0.5, alpha=0.4)
    # 期待値
    draw_ellipse(np.linalg.inv(nu * sigma), ax, edgecolor='b', linewidth=2.0, linestyle='--', alpha=0.7)
    ax.set_title(f'df={nu}, scale={sigma.round(2)}')


## 様々なパラメータで可視化
# nuを変化させる
nu_list = [2, 3, 4]
sigma = np.eye(2)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

np.random.seed(2)
for i, nu in enumerate(nu_list):
    visualize_Wishart(nu, sigma, 50, axes[i])
plt.tight_layout()
plt.savefig('Wishart_sample_nu.png')


# sigmaを変化させる
scale_list = [2, 3, 4]
sigma = np.eye(2)
nu = 3

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

np.random.seed(2)
for i, scale in enumerate(scale_list):
    visualize_Wishart(nu, scale * sigma, 50, axes[i])
plt.tight_layout()
plt.savefig('Wishart_sample_sigma.png')


# 相関係数0.8
nu_list = [2, 3, 4]
sigma = np.linalg.inv(np.array([[1.0, 0.8], [0.8, 1.0]]))

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

np.random.seed(2)
for i, nu in enumerate(nu_list):
    visualize_Wishart(nu, sigma, 50, axes[i])
plt.tight_layout()
plt.savefig('Wishart_sample_nu_corr0.8.png')


# 相関係数-0.4
scale_list = [2, 3, 4]
nu = 3
sigma = np.linalg.inv(np.array([[1.0, -0.4], [-0.4, 1.0]]))

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

np.random.seed(2)
for i, scale in enumerate(scale_list):
    visualize_Wishart(nu, scale * sigma, 50, axes[i])
plt.tight_layout()
plt.savefig('Wishart_sample_sigma_corr-0.4.png')