アヒル本8.3章非線形モデルの階層モデルをpymcでやってみる。
StanとRでベイズ統計モデリングの8.3「非線形モデルの階層モデル」をpymcで実装してみる、というお話。 この本は実践的な例題が多く、非常に参考になる。ただし、タイトルの通りだが、RとStanの実装。
- この節の例題は、経過時間と薬の血中濃度を非線形の関係でモデル化している。
- 各患者のモデルパラメータを階層化してモデル化しており、パラメータの対数値が正規分布に従う、と仮定されている。
- pymcでモデルを実装する際、時系列やパラメータの対数変換はどうするのか?という素朴な疑問からpymcで実装してみた。
実装してみた結果は以下。
- 患者と時間という2軸あるので、少し分かりにくい印象。(Stanの方がコード化しやすい気がする)
- NUTSによるMCMCはエラーで実行できなかった。代わりにADVIで推定した。得られた推定結果は書籍と大体同じであった。
df = pd.read_csv('data/data-conc-2.txt')
df.head()
PersonID | Time1 | Time2 | Time4 | Time8 | Time12 | Time24 | |
---|---|---|---|---|---|---|---|
0 | 1 | 2.4 | 5.0 | 7.5 | 11.9 | 12.5 | 12.7 |
1 | 2 | 1.4 | 3.9 | 4.4 | 7.7 | 6.4 | 8.3 |
2 | 3 | 5.2 | 9.4 | 19.4 | 20.2 | 22.7 | 24.9 |
3 | 4 | 6.7 | 12.6 | 19.1 | 23.4 | 25.8 | 26.1 |
4 | 5 | 0.3 | 4.7 | 7.0 | 10.2 | 12.9 | 14.8 |
N = len(df.PersonID.unique()) Times = np.array([1,2,4,8,12,24]) Times = np.array([Times]).T # 列ベクトルに変換する。
with pm.Model() as model: # 階層化部分 a0 = pm.Normal('a0', mu=0, sd=10**2) s_a = pm.HalfCauchy('s_a', beta=10**2) b0 = pm.Normal('b0', mu=0, sd=10**2) s_b = pm.HalfCauchy('s_b', beta=10**2) # 各患者のモデルパラメータ log_a = pm.Normal('log_a', mu=a0, sd=s_a, shape=N) a = pm.Deterministic('a', pm.math.exp(log_a)) log_b = pm.Normal('log_b', mu=b0, sd=s_b, shape=N) b = pm.Deterministic('b', pm.math.exp(log_b)) # 誤差 eps = pm.Normal('eps', mu=0, sd=10**2) # 各時刻における各患者の平均(matrixになるので、少しややこしい) mu = pm.Deterministic('mu', a*(1-pm.math.exp(-b*Times))) y = pm.Normal('y', mu=mu, sd=eps, observed=df.iloc[:,1:].values.T)
パラメータの推定。MCMC(NUTS)はエラーで実行できなかった。
with model: inference = pm.ADVI() approx = pm.fit(100000, method=inference)
Average Loss = 265.72: 100%|██████████| 100000/100000 [02:42<00:00, 614.61it/s]
Finished [100%]: Average Loss = 265.69
trace = approx.sample(draws=3000)
推定結果。書籍の結果と大体同じ値が得られている。
pm.summary(trace, varnames=['a0','b0','s_a','s_b','eps'])
mean | sd | mc_error | hpd_2.5 | hpd_97.5 | |
---|---|---|---|---|---|
a0 | 2.859904 | 0.110053 | 0.001996 | 2.636606 | 3.068902 |
b0 | -1.165525 | 0.089034 | 0.001613 | -1.344461 | -0.993055 |
s_a | 0.420661 | 0.083319 | 0.001390 | 0.275964 | 0.590157 |
s_b | 0.344515 | 0.071129 | 0.001492 | 0.210082 | 0.483927 |
eps | 1.808791 | 0.139476 | 0.002373 | 1.540795 | 2.083439 |
グラフ化して確認。 各患者の平均値の推定値(事後分布の中央値)と、95%予測区間をグラフ化する。
times = Times.T[0] # 事後分布の中央値を使う。 mu = pm.quantiles(trace)['mu'][50] eps = pm.quantiles(trace)['eps'][50] _, axes = plt.subplots(4, 4) for (row, col), ax in np.ndenumerate(axes): id = row * 4+ col # 観測値 ax.scatter(times, df.loc[id,'Time1':].values, color='k') # 平均の推定値(事後分布の中央値) ax.plot(times, mu[:,id], color='k') # 平均の推定値(事後分布の95%最高密度区間) y1,y2 = map(list, zip(*pm.hpd(trace, alpha=0.05)['mu'][:,id])) #ax.fill_between(times, y1=y1, y2=y2, alpha=0.2, color='b') # 平均値の推定値(事後分布の中央値)を使った場合の95%予測区間 y1,y2 = map(list, zip(*norm.ppf(q=[0.025, 0.975], loc=np.array([mu[:,id]]).T, scale=eps))) ax.fill_between(times, y1=y1, y2=y2, alpha=0.2, color='r') # グラフの調整 if row < 3: plt.setp(ax.get_xticklabels(), visible=False) else: plt.setp(ax, xlabel='Time (hour)') if col > 0: plt.setp(ax.get_yticklabels(), visible=False) else: plt.setp(ax, ylabel='Y') plt.setp(ax, title='id:%d'%(id+1), xticks=times, xlim=(0, 24), yticks=np.arange(0, 40, 10), ylim=(-3, 37)) plt.tight_layout()
グラフも大体同じ結果。
以上