gaiaskyの技術メモ

ベイズ、機械学習などのデータサイエンスがテーマ。言語はpythonで。

アヒル本8.3章非線形モデルの階層モデルをpymcでやってみる。

StanとRでベイズ統計モデリングの8.3「非線形モデルの階層モデル」をpymcで実装してみる、というお話。 この本は実践的な例題が多く、非常に参考になる。ただし、タイトルの通りだが、RとStanの実装。

  • この節の例題は、経過時間と薬の血中濃度非線形の関係でモデル化している。
  • 各患者のモデルパラメータを階層化してモデル化しており、パラメータの対数値が正規分布に従う、と仮定されている。
  • pymcでモデルを実装する際、時系列やパラメータの対数変換はどうするのか?という素朴な疑問からpymcで実装してみた。

実装してみた結果は以下。

  • 患者と時間という2軸あるので、少し分かりにくい印象。(Stanの方がコード化しやすい気がする)
  • NUTSによるMCMCはエラーで実行できなかった。代わりにADVIで推定した。得られた推定結果は書籍と大体同じであった。
df = pd.read_csv('data/data-conc-2.txt')
df.head()
PersonID Time1 Time2 Time4 Time8 Time12 Time24
0 1 2.4 5.0 7.5 11.9 12.5 12.7
1 2 1.4 3.9 4.4 7.7 6.4 8.3
2 3 5.2 9.4 19.4 20.2 22.7 24.9
3 4 6.7 12.6 19.1 23.4 25.8 26.1
4 5 0.3 4.7 7.0 10.2 12.9 14.8
N = len(df.PersonID.unique())

Times = np.array([1,2,4,8,12,24])
Times = np.array([Times]).T # 列ベクトルに変換する。
with pm.Model() as model:
    # 階層化部分
    a0 = pm.Normal('a0', mu=0, sd=10**2)
    s_a = pm.HalfCauchy('s_a', beta=10**2)
    b0 = pm.Normal('b0', mu=0, sd=10**2)
    s_b = pm.HalfCauchy('s_b', beta=10**2) 
    # 各患者のモデルパラメータ
    log_a = pm.Normal('log_a', mu=a0, sd=s_a, shape=N)
    a = pm.Deterministic('a', pm.math.exp(log_a))
    log_b = pm.Normal('log_b', mu=b0, sd=s_b, shape=N)
    b = pm.Deterministic('b', pm.math.exp(log_b))
    # 誤差
    eps = pm.Normal('eps', mu=0, sd=10**2)
    
    # 各時刻における各患者の平均(matrixになるので、少しややこしい)
    mu = pm.Deterministic('mu', a*(1-pm.math.exp(-b*Times)))
    y = pm.Normal('y', mu=mu, sd=eps, observed=df.iloc[:,1:].values.T)

パラメータの推定。MCMC(NUTS)はエラーで実行できなかった。

with model:
    inference = pm.ADVI()
    approx = pm.fit(100000, method=inference)
Average Loss = 265.72: 100%|██████████| 100000/100000 [02:42<00:00, 614.61it/s]
Finished [100%]: Average Loss = 265.69
trace = approx.sample(draws=3000)

推定結果。書籍の結果と大体同じ値が得られている。

pm.summary(trace, varnames=['a0','b0','s_a','s_b','eps'])
mean sd mc_error hpd_2.5 hpd_97.5
a0 2.859904 0.110053 0.001996 2.636606 3.068902
b0 -1.165525 0.089034 0.001613 -1.344461 -0.993055
s_a 0.420661 0.083319 0.001390 0.275964 0.590157
s_b 0.344515 0.071129 0.001492 0.210082 0.483927
eps 1.808791 0.139476 0.002373 1.540795 2.083439

グラフ化して確認。 各患者の平均値の推定値(事後分布の中央値)と、95%予測区間をグラフ化する。

times = Times.T[0]
# 事後分布の中央値を使う。
mu = pm.quantiles(trace)['mu'][50]
eps = pm.quantiles(trace)['eps'][50]

_, axes = plt.subplots(4, 4)
for (row, col), ax in np.ndenumerate(axes):
    id = row * 4+ col
    # 観測値
    ax.scatter(times, df.loc[id,'Time1':].values, color='k')
    # 平均の推定値(事後分布の中央値)
    ax.plot(times, mu[:,id], color='k')
    # 平均の推定値(事後分布の95%最高密度区間)
    y1,y2 = map(list, zip(*pm.hpd(trace, alpha=0.05)['mu'][:,id]))
    #ax.fill_between(times, y1=y1, y2=y2, alpha=0.2, color='b')
    # 平均値の推定値(事後分布の中央値)を使った場合の95%予測区間
    y1,y2 = map(list, zip(*norm.ppf(q=[0.025, 0.975], loc=np.array([mu[:,id]]).T, scale=eps)))
    ax.fill_between(times, y1=y1, y2=y2, alpha=0.2, color='r')    
    # グラフの調整
    if row < 3:
        plt.setp(ax.get_xticklabels(), visible=False)
    else:
        plt.setp(ax, xlabel='Time (hour)')
    if col > 0:
        plt.setp(ax.get_yticklabels(), visible=False)
    else:
        plt.setp(ax, ylabel='Y')    
    plt.setp(ax, title='id:%d'%(id+1), xticks=times, xlim=(0, 24), yticks=np.arange(0, 40, 10), ylim=(-3, 37))
plt.tight_layout()

f:id:gaiasky:20180905223022p:plain

グラフも大体同じ結果。

以上