gaiaskyの技術メモ

ベイズ、機械学習などのデータサイエンスがテーマ。言語はpythonで。

アヒル本11章ゼロ過剰ポアソン分布をpymcでやってみる。

名前のとおり、ゼロが多いポアソン分布ということで、ベルヌーイ分布とポアソン分布を組み合わせたZIP分布を使った例題。

stanではなく、pymcで実装する場合、以下が注意点として挙げられる。

  • 切片のデータを追加しておく。
  • pymcのZIP分布におけるポアソン分布のパラメータλは、リンク関数(exp)を掛けて値を入れる。 (stanの実装では、poisson(exp(x))と等価な関数でZIP関数が作られているため、リンク関数を掛けていない。 これに気づかず、中々うまく推定できなかった。)

データを読み込み、可視化する。ゼロが多いカウントデータであることが分かる。

df = pd.read_csv('./data/data-ZIP.txt')

df.Y.plot.hist()

f:id:gaiasky:20180911215640p:plain

データの整形。切片も加える。

df.Age = df.Age/10
df['Intercept'] = 1
Y = df.Y.values
X = df[['Intercept', 'Sex', 'Sake', 'Age']].values
n_shape = X.shape[1]

モデルの実装。

with pm.Model() as model:
    b1 = pm.Normal('b1', mu=0, sd=5, shape=n_shape)
    b2 = pm.Normal('b2', mu=0, sd=5, shape=n_shape)
    
    q_x = pm.math.dot(X,b1)
    q = pm.invlogit(q_x)
    lam = pm.math.exp(pm.math.dot(X,b2)) # リンク関数(exp)をかけておく
    
    y_pred = pm.ZeroInflatedPoisson('y_pred', q, lam, observed=Y)
with model:
    trace = pm.sample(2000, start=pm.find_MAP(), step=pm.NUTS())
logp = -416.06, ||grad|| = 0.11749: 100%|██████████| 65/65 [00:00<00:00, 754.49it/s]  
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [b2, b1]
Sampling 2 chains: 100%|██████████| 5000/5000 [00:27<00:00, 181.93draws/s]
pm.summary(trace)
mean sd mc_error hpd_2.5 hpd_97.5 n_eff Rhat
b1__0 0.914821 0.674350 0.012534 -0.332461 2.305058 2608.604989 0.999940
b1__1 1.579079 0.415617 0.006595 0.779755 2.401231 4781.578735 0.999810
b1__2 3.258817 0.788188 0.014349 1.781250 4.799758 2837.311993 1.000395
b1__3 -0.356855 0.179263 0.003336 -0.716908 -0.018116 2592.259435 1.000172
b2__0 1.450101 0.134314 0.002736 1.175606 1.699740 2785.911268 1.000095
b2__1 -0.745450 0.079422 0.001370 -0.893933 -0.588722 3471.970062 0.999781
b2__2 -0.161931 0.073769 0.001336 -0.306756 -0.023078 3011.972673 0.999774
b2__3 0.198165 0.033652 0.000672 0.131085 0.262932 2823.546198 1.000473
pm.traceplot(trace)

f:id:gaiasky:20180911220116p:plain

推測したパラメータでデータを生成し、分布を確認。

pred = pm.sample_ppc(trace, samples=100, model=model)
100%|██████████| 100/100 [00:00<00:00, 682.51it/s]
plt.hist(pred['y_pred'].reshape(-1,))

f:id:gaiasky:20180911220213p:plain

実際のデータに近い分布が得られた。

以上