ベイズ生存時間分析(Weibull分布)
データ分析では正規分布を仮定することが多いが、生存時間分析・信頼性工学では、ワイブル分布を仮定することが多い。これはワイブル分布が、形状パラメータ・尺度パラメータによって、所謂バスタブカーブの3要素(初期故障、偶発故障、摩耗)を表現可能であるからと思う。一方、データには打ち切りが存在することが一般的で、そのために、ノンパラメトリックモデルのカプランマイヤー法で推定することが多いが、ベイズ推定を用いるとパラメトリックモデルで打ち切りデータもうまく扱えるようだ。
Bayesian Inference With Stan ~062~ | |
上記の記事はstanで実装されているため、今回はこの例題をpymcで実装してみることにした。
# データセットの読み込み。
df = load_rossi()
rossiデータの概要。
- 再逮捕までの時間(週単位)
- 右側打ち切り(52週)
- 共変量あり(fin=1, 経済的支援あり)
df.head()
week | arrest | fin | age | race | wexp | mar | paro | prio | |
---|---|---|---|---|---|---|---|---|---|
0 | 20 | 1 | 0 | 27 | 1 | 0 | 0 | 1 | 3 |
1 | 17 | 1 | 0 | 18 | 1 | 0 | 0 | 1 | 8 |
2 | 25 | 1 | 0 | 19 | 0 | 1 | 0 | 1 | 13 |
3 | 52 | 0 | 1 | 23 | 1 | 1 | 1 | 1 | 1 |
4 | 52 | 0 | 0 | 19 | 0 | 1 | 0 | 1 | 3 |
df_obs = df[df.arrest==1].week.values # 打ち切りなしのデータ df_fin_obs = df[df.arrest==1].fin.values df_cens = df[df.arrest==0].week.values # 右側打ち切りのデータ df_fin_cens = df[df.arrest==0].fin.values
ワイブル分布の関数定義
# pdf def weibProbDist(x, a, b): return (a / b) * (x / b) ** (a - 1) * np.exp(-(x / b) ** a) # cdf def weibCumDist(x, a, b): return 1 - np.exp(-(x / b) ** a)
ワイブル分布への当てはめ
pymc3の公式HPのExampleを参考にした。
Reparameterizing the Weibull Accelerated Failure Time Model — PyMC3 3.5 documentation
打ち切りなしデータと、打ち切りありデータのそれぞれを分けてサンプリングする。
def weibull_lccdf(x, shape, scale): ''' Log complementary cdf of Weibull distribution. ''' return -(x / scale)**shape
with pm.Model() as model_1: # 形状パラメータ alpha_sd = 10.0 alpha_raw = pm.Normal('a0', mu=0, sd=0.1) shape = pm.Deterministic('shape', tt.exp(alpha_sd * alpha_raw)) # 尺度パラメータ mu = pm.Normal('mu', mu=0, sd=100) scale = pm.Deterministic('scale', tt.exp(mu / shape)) y_obs = pm.Weibull('y_obs', alpha=shape, beta=scale, observed=df_obs) y_cens = pm.Potential('y_cens', weibull_lccdf(df_cens, shape, scale))
with model_1: trace_1 = pm.sample(5000, tune=2000, nuts_kwargs={'target_accept': 0.95}, init='adapt_diag')
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu, a0]
Sampling 2 chains: 100%|██████████| 14000/14000 [00:50<00:00, 277.78draws/s]
The number of effective samples is smaller than 25% for some parameters.
pm.traceplot(trace_1)
pm.summary(trace_1).round(2)
mean | sd | mc_error | hpd_2.5 | hpd_97.5 | n_eff | Rhat | |
---|---|---|---|---|---|---|---|
a0 | 0.03 | 0.01 | 0.00 | 0.01 | 0.05 | 1680.58 | 1.0 |
mu | 6.56 | 0.49 | 0.01 | 5.60 | 7.50 | 1671.80 | 1.0 |
shape | 1.36 | 0.12 | 0.00 | 1.12 | 1.60 | 1671.94 | 1.0 |
scale | 126.44 | 14.47 | 0.26 | 101.34 | 156.59 | 2817.21 | 1.0 |
推定したパラメータによる生存時間関数をグラフ表示する。
qlist = [2.5, 50, 97.5] shapes = pm.quantiles(trace_1.get_values('shape'), qlist=qlist) scales = pm.quantiles(trace_1.get_values('scale'), qlist=qlist)
# 生存時間関数表示関数 def plot_survival_function(ax, x, shape, scale, color='blue', qlist=[2.5,50,97.5]): low=qlist[0] mid=qlist[1] hig=qlist[2] ax.plot(x, 1-weibCumDist(x, shape[mid], scale[mid]), color=color) ax.fill_between(x=x, y1=1-weibCumDist(x, shape[low], scale[low]), y2=1-weibCumDist(x, shape[hig], scale[hig]), alpha=0.2, color=color) ax.set_xlabel('週') ax.set_ylabel('生存率') return ax
x = np.linspace(0,52,100) fig = plt.figure(figsize=(9,4)) ax = fig.add_subplot(111) ax = plot_survival_function(ax, x, shapes, scales, qlist=qlist)
カプランマイヤー法
- ノンパラメトリックモデルであるカプランマイヤー法による生存率推定結果と比較する。
from lifelines import KaplanMeierFitter kmf = KaplanMeierFitter() kmf.fit(df.week.values, df.arrest.astype(np.bool), alpha=0.975)
x = np.linspace(0,52,100) fig = plt.figure(figsize=(9,4)) ax = fig.add_subplot(111) ax = plot_survival_function(ax, x, shapes, scales, qlist=qlist) kmf.plot(ax=ax)
ベイズ推定はカプランマイヤー法よりも広めに推定されている。 元の記事の推定結果よりも広いため、うまく推定できていないのかも知れない。
比例ハザードモデル
- 形状パラメータではなく、尺度パラメータに共変量の効果を入れる模様。
with pm.Model() as model_2: # 形状パラメータ alpha_sd = 10.0 alpha_raw = pm.Normal('a0', mu=0, sd=0.1) shape = pm.Deterministic('shape', tt.exp(alpha_sd * alpha_raw)) # 尺度パラメータ beta = pm.Normal('beta', mu=0, sd=10**2, shape=2) # 経済的支援ありの場合の尺度パラメータ scale1 = pm.Deterministic('scale1', tt.exp(-(beta[0]+beta[1])/shape)) # 経済的支援なしの場合の尺度パラメータ scale2 = pm.Deterministic('scale2', tt.exp(-(beta[0])/shape)) y_obs = pm.Weibull('y_obs', alpha=shape, beta=tt.exp(-(beta[0]+df_fin_obs*beta[1])/shape), observed=df_obs) y_cens = pm.Potential('y_cens', weibull_lccdf(df_cens, shape, tt.exp(-(beta[0]+df_fin_cens*beta[1])/shape)))
with model_2: trace_2 = pm.sample(5000, tune=2000, nuts_kwargs={'target_accept': 0.9}, init='adapt_diag')
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [beta, a0]
Sampling 2 chains: 100%|██████████| 14000/14000 [01:30<00:00, 154.14draws/s]
pm.summary(trace_2).round(2)
mean | sd | mc_error | hpd_2.5 | hpd_97.5 | n_eff | Rhat | |
---|---|---|---|---|---|---|---|
a0 | 0.03 | 0.01 | 0.00 | 0.01 | 0.05 | 2917.32 | 1.0 |
beta__0 | -6.42 | 0.49 | 0.01 | -7.38 | -5.49 | 2859.43 | 1.0 |
beta__1 | -0.37 | 0.19 | 0.00 | -0.73 | 0.01 | 4608.48 | 1.0 |
shape | 1.37 | 0.12 | 0.00 | 1.14 | 1.61 | 2877.35 | 1.0 |
scale1 | 146.66 | 21.97 | 0.33 | 108.67 | 190.92 | 4099.34 | 1.0 |
scale2 | 111.09 | 13.30 | 0.18 | 87.85 | 138.03 | 5590.24 | 1.0 |
pm.traceplot(trace_2)
shapes = pm.quantiles(trace_2.get_values('shape'), qlist=qlist) scales1 = pm.quantiles(trace_2.get_values('scale1'), qlist=qlist) scales2 = pm.quantiles(trace_2.get_values('scale2'), qlist=qlist)
fig = plt.figure(figsize=(9,4)) ax = fig.add_subplot(111) plot_survival_function(ax, x, shapes, scales1) plot_survival_function(ax, x, shapes, scales2, color='red')
Cox比例ハザードモデルとの比較
cph = lifelines.CoxPHFitter() cph.fit(df[['week','arrest','fin']], duration_col='week', event_col='arrest', show_progress=False) cph.print_summary() # access the results using cph.summary
n=432, number of events=114
coef exp(coef) se(coef) z p lower 0.95 upper 0.95
fin -0.3691 0.6914 0.1897 -1.9453 0.0517 -0.7409 0.0028 .
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Concordance = 0.546
Likelihood ratio test = 3.837 on 1 df, p=0.05013
fig = plt.figure(figsize=(9,4)) ax = fig.add_subplot(111) plot_survival_function(ax, x, shapes, scales1) plot_survival_function(ax, x, shapes, scales2, color='red') cph.plot_covariate_groups('fin',[0,1], ax=ax)
こちらもやはり、ベイズ信頼区間の範囲が元の記事よりも広い。 元の記事(stan)では事前分布が明示されていないため、そこが違いを生んでいるのかも知れない。 とはいえ、pymcでの打ち切りデータを含む場合のweibull分布への当てはめ方法が分かった。
以上