アヒル本11章ゼロ過剰ポアソン分布をpymcでやってみる。
名前のとおり、ゼロが多いポアソン分布ということで、ベルヌーイ分布とポアソン分布を組み合わせたZIP分布を使った例題。
stanではなく、pymcで実装する場合、以下が注意点として挙げられる。
- 切片のデータを追加しておく。
- pymcのZIP分布におけるポアソン分布のパラメータλは、リンク関数(exp)を掛けて値を入れる。 (stanの実装では、poisson(exp(x))と等価な関数でZIP関数が作られているため、リンク関数を掛けていない。 これに気づかず、中々うまく推定できなかった。)
データを読み込み、可視化する。ゼロが多いカウントデータであることが分かる。
df = pd.read_csv('./data/data-ZIP.txt')
df.Y.plot.hist()
データの整形。切片も加える。
df.Age = df.Age/10 df['Intercept'] = 1 Y = df.Y.values X = df[['Intercept', 'Sex', 'Sake', 'Age']].values n_shape = X.shape[1]
モデルの実装。
with pm.Model() as model: b1 = pm.Normal('b1', mu=0, sd=5, shape=n_shape) b2 = pm.Normal('b2', mu=0, sd=5, shape=n_shape) q_x = pm.math.dot(X,b1) q = pm.invlogit(q_x) lam = pm.math.exp(pm.math.dot(X,b2)) # リンク関数(exp)をかけておく y_pred = pm.ZeroInflatedPoisson('y_pred', q, lam, observed=Y)
with model: trace = pm.sample(2000, start=pm.find_MAP(), step=pm.NUTS())
logp = -416.06, ||grad|| = 0.11749: 100%|██████████| 65/65 [00:00<00:00, 754.49it/s]
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [b2, b1]
Sampling 2 chains: 100%|██████████| 5000/5000 [00:27<00:00, 181.93draws/s]
pm.summary(trace)
mean | sd | mc_error | hpd_2.5 | hpd_97.5 | n_eff | Rhat | |
---|---|---|---|---|---|---|---|
b1__0 | 0.914821 | 0.674350 | 0.012534 | -0.332461 | 2.305058 | 2608.604989 | 0.999940 |
b1__1 | 1.579079 | 0.415617 | 0.006595 | 0.779755 | 2.401231 | 4781.578735 | 0.999810 |
b1__2 | 3.258817 | 0.788188 | 0.014349 | 1.781250 | 4.799758 | 2837.311993 | 1.000395 |
b1__3 | -0.356855 | 0.179263 | 0.003336 | -0.716908 | -0.018116 | 2592.259435 | 1.000172 |
b2__0 | 1.450101 | 0.134314 | 0.002736 | 1.175606 | 1.699740 | 2785.911268 | 1.000095 |
b2__1 | -0.745450 | 0.079422 | 0.001370 | -0.893933 | -0.588722 | 3471.970062 | 0.999781 |
b2__2 | -0.161931 | 0.073769 | 0.001336 | -0.306756 | -0.023078 | 3011.972673 | 0.999774 |
b2__3 | 0.198165 | 0.033652 | 0.000672 | 0.131085 | 0.262932 | 2823.546198 | 1.000473 |
pm.traceplot(trace)
推測したパラメータでデータを生成し、分布を確認。
pred = pm.sample_ppc(trace, samples=100, model=model)
100%|██████████| 100/100 [00:00<00:00, 682.51it/s]
plt.hist(pred['y_pred'].reshape(-1,))
実際のデータに近い分布が得られた。
以上