空間統計モデル(CAR model)をpymc3で。(緑本11章)
CAR model(条件つき自己回帰モデル)
いわゆる久保先生の緑本(データ解析のための統計モデリング入門)11章で紹介されている空間構造を考慮したモデル。
ベイズ統計モデルに興味を持ったきっかけの本。正規分布を仮定する線形モデル(LM)から始まり、正規分布に限定されない一般化線形モデル(GLM)、個体差を考慮する一般化線形混合モデル(GLMM)と発展し、複雑なモデルのパラメータを解析的に求めることの難しさから、ベイズ+MCMCによる推定への話がつながっており非常に分かりやすかった。
書籍では、モデルはBUGSで記載されており、Webを検索すればstanコードもたくさん出てくる。一方、pythonコードは少なく、特に11章の空間統計モデル(CAR model)の実装は見当たらなかったため、pymc3版を書くことにした。
BUGSではCARモデルの関数が用意されているようだが、pymc3ではそういった関数はない模様。一方で、pymc3の公式HPにはexampleがあったため、これを参考にした。
pymc3によるCARモデル
Intrinsic CARモデルは下式で表される。 wは重みで、場所iにおける場所jから受ける影響の大きさを表す。 例えば、場所iとjが隣接している場合、w_ij=1、隣接していない場合、w_ij=0などとする。 この場合、平均や分散の分母Σw_ijは隣接する場所の数となる。
実装方法としては、公式HPに、CARモデルの例題が紹介されている。①theano.scanを使う方法、②matrix trickを使う方法などが紹介されているが、①は実行速度が遅いと書かれているため、②を使うことにした。②の方法では、CAR2クラスという独自の分布モデルを作成する(ちなみに、CARクラスも例には載っており、これは①に対応したクラス)。パラメータのaが隣接情報(何番目の場所が隣接するか)の行列で、wが重み情報の行列になる。
PyMC3 Modeling tips and heuristic — PyMC3 3.5 documentation
class CAR2(distribution.Continuous): """ Conditional Autoregressive (CAR) distribution Parameters ---------- a : adjacency matrix w : weight matrix tau : precision at each location """ def __init__(self, w, a, tau, *args, **kwargs): super(CAR2, self).__init__(*args, **kwargs) self.a = a = tt.as_tensor_variable(a) self.w = w = tt.as_tensor_variable(w) self.tau = tau*tt.sum(w, axis=1) self.mode = 0. def logp(self, x): tau = self.tau w = self.w a = self.a mu_w = tt.sum(x*a, axis=1)/tt.sum(w, axis=1) return tt.sum(continuous.Normal.dist(mu=mu_w, tau=tau).logp(x))
まず、データを用意する。
Y = np.array([ 0., 3., 2., 5., 6., 16., 8., 14., 11., 10., 17., 19., 14., 19., 19., 18., 15., 13., 13., 9., 11., 15., 18., 12., 11., 17., 14., 16., 15., 9., 6., 15., 10., 11., 14., 7., 14., 14., 13., 17., 8., 7., 10., 4., 5., 5., 7., 4., 3., 1.], np.int64)
次に、隣接情報と重み情報を作成する。これは、自身の両隣で、重みはどちらも同じ大きさ。
# 隣接情報 adj = np.array( [[1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7], [6, 8], [7, 9], [8, 10], [9, 11], [10, 12], [11, 13], [12, 14], [13, 15], [14, 16], [15, 17], [16, 18], [17, 19], [18, 20], [19, 21], [20, 22], [21, 23], [22, 24], [23, 25], [24, 26], [25, 27], [26, 28], [27, 29], [28, 30], [29, 31], [30, 32], [31, 33], [32, 34], [33, 35], [34, 36], [35, 37], [36, 38], [37, 39], [38, 40], [39, 41], [40, 42], [41, 43], [42, 44], [43, 45], [44, 46], [45, 47], [46, 48], [47, 49], [49]], dtype=object) # 重み weights = np.array( [[1.0], [1, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0]], dtype=object)
これらを行列化する。
N = len(weights) wmat2 = np.zeros((N,N)) amat2 = np.zeros((N,N)) for i, a in enumerate(adj): amat2[i,a] = 1 wmat2[i,a] = weights[i]
wmat2
array([[0., 1., 0., ..., 0., 0., 0.],
[1., 0., 1., ..., 0., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 1., 0.],
[0., 0., 0., ..., 1., 0., 1.],
[0., 0., 0., ..., 0., 0., 1.]])
amat2
array([[0., 1., 0., ..., 0., 0., 0.],
[1., 0., 1., ..., 0., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 1., 0.],
[0., 0., 0., ..., 1., 0., 1.],
[0., 0., 0., ..., 0., 0., 1.]])
モデルの作成。
with pm.Model() as model: # hyper s = pm.Uniform('s', lower=0, upper=100) # prior beta = pm.Normal('beta', mu=0, sd=10) tau = 1/(s*s) r =CAR2('r', w=wmat2, a=amat2, tau=tau, shape=N) # y = pm.Poisson('y', mu=np.exp(beta + r[np.arange(50)]), observed=Y)
モデルのグラフィカル表示。
pm.model_to_graphviz(model)
MCMCの実行。NUTSでサンプリングすると、収束に時間が掛かったため、ADVI(自動微分変分法)を使ってみた。
%%time with model: inference = pm.ADVI() approx = pm.fit(n=100000, method=inference,random_seed=123, start=pm.find_MAP(method='Powell')) trace = approx.sample(draws=5000)
logp = -29.053, ||grad|| = 50.912: : 5001it [00:05, 926.82it/s]
Average Loss = 143.95: 100%|██████████| 100000/100000 [02:09<00:00, 771.11it/s]
Finished [100%]: Average Loss = 143.95
CPU times: user 2min 12s, sys: 20.1 s, total: 2min 32s
Wall time: 2min 29s
結果の出力。
pm.traceplot(trace)
データと推定結果をグラフ表示する。青丸がデータ。黒線が推定結果(中央値)、グレー帯が95%HPD区間。
b = np.median(trace.get_values('beta')) r = np.median(trace.get_values('r'),axis=0) xx = np.arange(N) yy = np.exp(b+r) plt.plot(xx,yy, 'k-') hpd = pm.hpd(trace.get_values('r'), alpha=0.05) plt.fill_between(xx, np.exp(b+hpd.T[0]), np.exp(b+hpd.T[1]), color='k', alpha=0.2) plt.plot(range(len(Y)),Y, 'p') plt.xlabel('position $j$') plt.ylabel('number of individuals $y_j$') plt.ylim(-2.5,30)
推定したパラメータの詳細は下表。 betaもsも書籍とは少し異なった推定値が得られているが、 上図のグラフを見ると、そこそこ良い推定結果が得られていると思う。
pm.summary(trace)
mean | sd | mc_error | hpd_2.5 | hpd_97.5 | |
---|---|---|---|---|---|
beta | 1.158679 | 0.047799 | 0.000675 | 1.068305 | 1.254242 |
r__0 | -0.888382 | 0.301022 | 0.004542 | -1.492071 | -0.318896 |
r__1 | -0.653750 | 0.205648 | 0.003100 | -1.063285 | -0.269927 |
r__2 | -0.218527 | 0.211978 | 0.003377 | -0.651866 | 0.181730 |
r__3 | 0.315092 | 0.205860 | 0.003063 | -0.103261 | 0.703068 |
r__4 | 0.832849 | 0.195722 | 0.002816 | 0.461069 | 1.226940 |
r__5 | 1.281606 | 0.176526 | 0.002431 | 0.932831 | 1.627272 |
r__6 | 1.260697 | 0.181086 | 0.002758 | 0.902225 | 1.618805 |
r__7 | 1.321297 | 0.183583 | 0.002787 | 0.948448 | 1.661514 |
r__8 | 1.264942 | 0.180075 | 0.002479 | 0.901088 | 1.611691 |
r__9 | 1.315550 | 0.178816 | 0.002678 | 0.956416 | 1.655632 |
r__10 | 1.570310 | 0.170607 | 0.002199 | 1.230681 | 1.900745 |
r__11 | 1.695244 | 0.165743 | 0.002195 | 1.351274 | 2.006621 |
r__12 | 1.612901 | 0.170895 | 0.002235 | 1.284063 | 1.965944 |
r__13 | 1.727920 | 0.164937 | 0.002407 | 1.416415 | 2.059092 |
r__14 | 1.752215 | 0.168899 | 0.002160 | 1.412582 | 2.071013 |
r__15 | 1.705700 | 0.166617 | 0.002456 | 1.397315 | 2.053247 |
r__16 | 1.547344 | 0.172276 | 0.002583 | 1.210396 | 1.883085 |
r__17 | 1.415188 | 0.177532 | 0.002456 | 1.081155 | 1.770285 |
r__18 | 1.296239 | 0.186362 | 0.002506 | 0.938413 | 1.665728 |
r__19 | 1.165170 | 0.184071 | 0.002463 | 0.817927 | 1.541284 |
r__20 | 1.274738 | 0.179306 | 0.002607 | 0.923394 | 1.621070 |
r__21 | 1.502048 | 0.172138 | 0.002584 | 1.168851 | 1.827205 |
r__22 | 1.587426 | 0.170370 | 0.002340 | 1.247124 | 1.908877 |
r__23 | 1.415244 | 0.171736 | 0.002149 | 1.096096 | 1.761579 |
r__24 | 1.373549 | 0.178739 | 0.002443 | 1.022405 | 1.719808 |
r__25 | 1.541359 | 0.171126 | 0.002365 | 1.187936 | 1.854710 |
r__26 | 1.557773 | 0.173536 | 0.002299 | 1.220501 | 1.894539 |
r__27 | 1.557409 | 0.176749 | 0.002737 | 1.210693 | 1.904542 |
r__28 | 1.417466 | 0.176512 | 0.002367 | 1.068018 | 1.756552 |
r__29 | 1.136758 | 0.184823 | 0.002607 | 0.789616 | 1.508527 |
r__30 | 1.020712 | 0.185264 | 0.002552 | 0.671178 | 1.397461 |
r__31 | 1.254099 | 0.176998 | 0.002652 | 0.921095 | 1.610488 |
r__32 | 1.247103 | 0.179981 | 0.002625 | 0.918825 | 1.624593 |
r__33 | 1.256230 | 0.179677 | 0.002617 | 0.903487 | 1.614004 |
r__34 | 1.286939 | 0.178641 | 0.002481 | 0.924141 | 1.616439 |
r__35 | 1.160147 | 0.188822 | 0.002600 | 0.771362 | 1.514623 |
r__36 | 1.340738 | 0.171982 | 0.002378 | 0.998691 | 1.665439 |
r__37 | 1.451340 | 0.173466 | 0.002303 | 1.102961 | 1.783128 |
r__38 | 1.484233 | 0.174277 | 0.002155 | 1.135390 | 1.825727 |
r__39 | 1.461851 | 0.171625 | 0.002388 | 1.121778 | 1.793317 |
r__40 | 1.129038 | 0.183425 | 0.002540 | 0.745838 | 1.467143 |
r__41 | 0.915550 | 0.190102 | 0.002935 | 0.541487 | 1.276502 |
r__42 | 0.825596 | 0.197854 | 0.002697 | 0.431916 | 1.193682 |
r__43 | 0.568216 | 0.201373 | 0.002649 | 0.165221 | 0.949261 |
r__44 | 0.480365 | 0.195278 | 0.002611 | 0.092885 | 0.859808 |
r__45 | 0.523366 | 0.192335 | 0.002415 | 0.158896 | 0.904041 |
r__46 | 0.526652 | 0.199353 | 0.002596 | 0.126299 | 0.898693 |
r__47 | 0.250715 | 0.199579 | 0.002248 | -0.161358 | 0.618004 |
r__48 | -0.212703 | 0.229811 | 0.003280 | -0.637011 | 0.258963 |
r__49 | -0.801395 | 0.463723 | 0.007071 | -1.727559 | 0.096066 |
s | 0.347891 | 0.040211 | 0.000537 | 0.269606 | 0.424428 |
以上