gaiaskyの技術メモ

ベイズ、機械学習などのデータサイエンスがテーマ。言語はpythonで。

空間統計モデル(CAR model)をpymc3で。(緑本11章)

CAR model(条件つき自己回帰モデル

いわゆる久保先生の緑本(データ解析のための統計モデリング入門)11章で紹介されている空間構造を考慮したモデル。

ベイズ統計モデルに興味を持ったきっかけの本。正規分布を仮定する線形モデル(LM)から始まり、正規分布に限定されない一般化線形モデル(GLM)、個体差を考慮する一般化線形混合モデル(GLMM)と発展し、複雑なモデルのパラメータを解析的に求めることの難しさから、ベイズMCMCによる推定への話がつながっており非常に分かりやすかった。

書籍では、モデルはBUGSで記載されており、Webを検索すればstanコードもたくさん出てくる。一方、pythonコードは少なく、特に11章の空間統計モデル(CAR model)の実装は見当たらなかったため、pymc3版を書くことにした。

BUGSではCARモデルの関数が用意されているようだが、pymc3ではそういった関数はない模様。一方で、pymc3の公式HPにはexampleがあったため、これを参考にした。

pymc3によるCARモデル

Intrinsic CARモデルは下式で表される。 wは重みで、場所iにおける場所jから受ける影響の大きさを表す。 例えば、場所iとjが隣接している場合、w_ij=1、隣接していない場合、w_ij=0などとする。 この場合、平均や分散の分母Σw_ijは隣接する場所の数となる。

 \displaystyle
\begin{align*}
y_i &\sim Poisson(\exp( \beta + r_i)) \\
r_i | r_j, j \neq i &\sim Normal( \frac{\sum_j w_{ij} r_j}{\sum_j w_{ij}}, \frac{s^2}{\sum_j w_{ij}})
\end{align*}

実装方法としては、公式HPに、CARモデルの例題が紹介されている。①theano.scanを使う方法、②matrix trickを使う方法などが紹介されているが、①は実行速度が遅いと書かれているため、②を使うことにした。②の方法では、CAR2クラスという独自の分布モデルを作成する(ちなみに、CARクラスも例には載っており、これは①に対応したクラス)。パラメータのaが隣接情報(何番目の場所が隣接するか)の行列で、wが重み情報の行列になる。

PyMC3 Modeling tips and heuristic — PyMC3 3.5 documentation

class CAR2(distribution.Continuous):
    """
    Conditional Autoregressive (CAR) distribution

    Parameters
    ----------
    a : adjacency matrix
    w : weight matrix
    tau : precision at each location
    """

    def __init__(self, w, a, tau, *args, **kwargs):
        super(CAR2, self).__init__(*args, **kwargs)
        self.a = a = tt.as_tensor_variable(a)
        self.w = w = tt.as_tensor_variable(w)
        self.tau = tau*tt.sum(w, axis=1)
        self.mode = 0.

    def logp(self, x):
        tau = self.tau
        w = self.w
        a = self.a

        mu_w = tt.sum(x*a, axis=1)/tt.sum(w, axis=1)
        return tt.sum(continuous.Normal.dist(mu=mu_w, tau=tau).logp(x))

まず、データを用意する。

Y = np.array([  0.,   3.,   2.,   5.,   6.,  16.,   8.,  14.,  11.,  10.,  17.,
                19.,  14.,  19.,  19.,  18.,  15.,  13.,  13.,   9.,  11.,  15.,
                18.,  12.,  11.,  17.,  14.,  16.,  15.,   9.,   6.,  15.,  10.,
                11.,  14.,   7.,  14.,  14.,  13.,  17.,   8.,   7.,  10.,   4.,
                 5.,   5.,   7.,   4.,   3.,   1.], np.int64)

次に、隣接情報と重み情報を作成する。これは、自身の両隣で、重みはどちらも同じ大きさ。

# 隣接情報
adj = np.array(
       [[1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7], [6, 8], [7, 9],
       [8, 10], [9, 11], [10, 12], [11, 13], [12, 14], [13, 15], [14, 16],
       [15, 17], [16, 18], [17, 19], [18, 20], [19, 21], [20, 22],
       [21, 23], [22, 24], [23, 25], [24, 26], [25, 27], [26, 28],
       [27, 29], [28, 30], [29, 31], [30, 32], [31, 33], [32, 34],
       [33, 35], [34, 36], [35, 37], [36, 38], [37, 39], [38, 40],
       [39, 41], [40, 42], [41, 43], [42, 44], [43, 45], [44, 46],
       [45, 47], [46, 48], [47, 49], [49]], dtype=object)

# 重み
weights = np.array(
       [[1.0], [1, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0],
       [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0],
       [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0],
       [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0],
       [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0],
       [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0],
       [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0],
       [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0],
       [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0],
       [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0]], dtype=object)

これらを行列化する。

N = len(weights)
wmat2 = np.zeros((N,N))
amat2 = np.zeros((N,N))
for i, a in enumerate(adj):
    amat2[i,a] = 1
    wmat2[i,a] = weights[i]
wmat2
array([[0., 1., 0., ..., 0., 0., 0.],
       [1., 0., 1., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 1., 0., 1.],
       [0., 0., 0., ..., 0., 0., 1.]])
amat2
array([[0., 1., 0., ..., 0., 0., 0.],
       [1., 0., 1., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 1., 0., 1.],
       [0., 0., 0., ..., 0., 0., 1.]])

モデルの作成。

with pm.Model() as model:
    # hyper
    s = pm.Uniform('s', lower=0, upper=100)
    # prior
    beta = pm.Normal('beta', mu=0, sd=10)
    tau = 1/(s*s)
    r =CAR2('r', w=wmat2, a=amat2, tau=tau, shape=N)
    #
    y = pm.Poisson('y', mu=np.exp(beta + r[np.arange(50)]), observed=Y)

モデルのグラフィカル表示。

pm.model_to_graphviz(model)

f:id:gaiasky:20180815110140p:plain

MCMCの実行。NUTSでサンプリングすると、収束に時間が掛かったため、ADVI(自動微分変分法)を使ってみた。

%%time
with model:
    inference = pm.ADVI()
    approx = pm.fit(n=100000, method=inference,random_seed=123, start=pm.find_MAP(method='Powell'))
    trace = approx.sample(draws=5000)
logp = -29.053, ||grad|| = 50.912: : 5001it [00:05, 926.82it/s]                          

Average Loss = 143.95: 100%|██████████| 100000/100000 [02:09<00:00, 771.11it/s]
Finished [100%]: Average Loss = 143.95


CPU times: user 2min 12s, sys: 20.1 s, total: 2min 32s
Wall time: 2min 29s

結果の出力。

pm.traceplot(trace)

f:id:gaiasky:20180815203334p:plain

データと推定結果をグラフ表示する。青丸がデータ。黒線が推定結果(中央値)、グレー帯が95%HPD区間

b = np.median(trace.get_values('beta'))
r = np.median(trace.get_values('r'),axis=0)
xx = np.arange(N)
yy = np.exp(b+r)
plt.plot(xx,yy, 'k-')
hpd = pm.hpd(trace.get_values('r'), alpha=0.05)
plt.fill_between(xx, np.exp(b+hpd.T[0]), np.exp(b+hpd.T[1]), color='k', alpha=0.2)

plt.plot(range(len(Y)),Y, 'p')
plt.xlabel('position $j$')
plt.ylabel('number of individuals $y_j$')
plt.ylim(-2.5,30)

f:id:gaiasky:20180815203351p:plain

推定したパラメータの詳細は下表。 betaもsも書籍とは少し異なった推定値が得られているが、 上図のグラフを見ると、そこそこ良い推定結果が得られていると思う。

pm.summary(trace)
mean sd mc_error hpd_2.5 hpd_97.5
beta 1.158679 0.047799 0.000675 1.068305 1.254242
r__0 -0.888382 0.301022 0.004542 -1.492071 -0.318896
r__1 -0.653750 0.205648 0.003100 -1.063285 -0.269927
r__2 -0.218527 0.211978 0.003377 -0.651866 0.181730
r__3 0.315092 0.205860 0.003063 -0.103261 0.703068
r__4 0.832849 0.195722 0.002816 0.461069 1.226940
r__5 1.281606 0.176526 0.002431 0.932831 1.627272
r__6 1.260697 0.181086 0.002758 0.902225 1.618805
r__7 1.321297 0.183583 0.002787 0.948448 1.661514
r__8 1.264942 0.180075 0.002479 0.901088 1.611691
r__9 1.315550 0.178816 0.002678 0.956416 1.655632
r__10 1.570310 0.170607 0.002199 1.230681 1.900745
r__11 1.695244 0.165743 0.002195 1.351274 2.006621
r__12 1.612901 0.170895 0.002235 1.284063 1.965944
r__13 1.727920 0.164937 0.002407 1.416415 2.059092
r__14 1.752215 0.168899 0.002160 1.412582 2.071013
r__15 1.705700 0.166617 0.002456 1.397315 2.053247
r__16 1.547344 0.172276 0.002583 1.210396 1.883085
r__17 1.415188 0.177532 0.002456 1.081155 1.770285
r__18 1.296239 0.186362 0.002506 0.938413 1.665728
r__19 1.165170 0.184071 0.002463 0.817927 1.541284
r__20 1.274738 0.179306 0.002607 0.923394 1.621070
r__21 1.502048 0.172138 0.002584 1.168851 1.827205
r__22 1.587426 0.170370 0.002340 1.247124 1.908877
r__23 1.415244 0.171736 0.002149 1.096096 1.761579
r__24 1.373549 0.178739 0.002443 1.022405 1.719808
r__25 1.541359 0.171126 0.002365 1.187936 1.854710
r__26 1.557773 0.173536 0.002299 1.220501 1.894539
r__27 1.557409 0.176749 0.002737 1.210693 1.904542
r__28 1.417466 0.176512 0.002367 1.068018 1.756552
r__29 1.136758 0.184823 0.002607 0.789616 1.508527
r__30 1.020712 0.185264 0.002552 0.671178 1.397461
r__31 1.254099 0.176998 0.002652 0.921095 1.610488
r__32 1.247103 0.179981 0.002625 0.918825 1.624593
r__33 1.256230 0.179677 0.002617 0.903487 1.614004
r__34 1.286939 0.178641 0.002481 0.924141 1.616439
r__35 1.160147 0.188822 0.002600 0.771362 1.514623
r__36 1.340738 0.171982 0.002378 0.998691 1.665439
r__37 1.451340 0.173466 0.002303 1.102961 1.783128
r__38 1.484233 0.174277 0.002155 1.135390 1.825727
r__39 1.461851 0.171625 0.002388 1.121778 1.793317
r__40 1.129038 0.183425 0.002540 0.745838 1.467143
r__41 0.915550 0.190102 0.002935 0.541487 1.276502
r__42 0.825596 0.197854 0.002697 0.431916 1.193682
r__43 0.568216 0.201373 0.002649 0.165221 0.949261
r__44 0.480365 0.195278 0.002611 0.092885 0.859808
r__45 0.523366 0.192335 0.002415 0.158896 0.904041
r__46 0.526652 0.199353 0.002596 0.126299 0.898693
r__47 0.250715 0.199579 0.002248 -0.161358 0.618004
r__48 -0.212703 0.229811 0.003280 -0.637011 0.258963
r__49 -0.801395 0.463723 0.007071 -1.727559 0.096066
s 0.347891 0.040211 0.000537 0.269606 0.424428

以上