gaiaskyの技術メモ

ベイズ、機械学習などのデータサイエンスがテーマ。言語はpythonで。

ベイズ生存時間分析(Weibull分布)

データ分析では正規分布を仮定することが多いが、生存時間分析・信頼性工学では、ワイブル分布を仮定することが多い。これはワイブル分布が、形状パラメータ・尺度パラメータによって、所謂バスタブカーブの3要素(初期故障、偶発故障、摩耗)を表現可能であるからと思う。一方、データには打ち切りが存在することが一般的で、そのために、ノンパラメトリックモデルのカプランマイヤー法で推定することが多いが、ベイズ推定を用いるとパラメトリックモデルで打ち切りデータもうまく扱えるようだ。

Bayesian Inference With Stan ~062~ | |

上記の記事はstanで実装されているため、今回はこの例題をpymcで実装してみることにした。

# データセットの読み込み。
df = load_rossi()

rossiデータの概要。

  • 再逮捕までの時間(週単位)
  • 右側打ち切り(52週)
  • 共変量あり(fin=1, 経済的支援あり)
df.head()
week arrest fin age race wexp mar paro prio
0 20 1 0 27 1 0 0 1 3
1 17 1 0 18 1 0 0 1 8
2 25 1 0 19 0 1 0 1 13
3 52 0 1 23 1 1 1 1 1
4 52 0 0 19 0 1 0 1 3
df_obs = df[df.arrest==1].week.values # 打ち切りなしのデータ
df_fin_obs = df[df.arrest==1].fin.values
df_cens = df[df.arrest==0].week.values # 右側打ち切りのデータ
df_fin_cens = df[df.arrest==0].fin.values

ワイブル分布の関数定義

# pdf
def weibProbDist(x, a, b):
    return (a / b) * (x / b) ** (a - 1) * np.exp(-(x / b) ** a)
# cdf
def weibCumDist(x, a, b):
    return 1 - np.exp(-(x / b) ** a)

ワイブル分布への当てはめ

pymc3の公式HPのExampleを参考にした。

Reparameterizing the Weibull Accelerated Failure Time Model — PyMC3 3.5 documentation

打ち切りなしデータと、打ち切りありデータのそれぞれを分けてサンプリングする。

def weibull_lccdf(x, shape, scale):
    ''' Log complementary cdf of Weibull distribution. '''
    return -(x / scale)**shape
with pm.Model() as model_1:
    # 形状パラメータ
    alpha_sd = 10.0
    alpha_raw = pm.Normal('a0', mu=0, sd=0.1)
    shape = pm.Deterministic('shape', tt.exp(alpha_sd * alpha_raw))
    # 尺度パラメータ
    mu = pm.Normal('mu', mu=0, sd=100)
    scale = pm.Deterministic('scale', tt.exp(mu / shape))

    y_obs = pm.Weibull('y_obs', alpha=shape, beta=scale, observed=df_obs)
    y_cens = pm.Potential('y_cens', weibull_lccdf(df_cens, shape, scale))
with model_1:
    trace_1 = pm.sample(5000, tune=2000, 
                        nuts_kwargs={'target_accept': 0.95}, 
                        init='adapt_diag')
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu, a0]
Sampling 2 chains: 100%|██████████| 14000/14000 [00:50<00:00, 277.78draws/s]
The number of effective samples is smaller than 25% for some parameters.
pm.traceplot(trace_1)

f:id:gaiasky:20181111232851p:plain

pm.summary(trace_1).round(2)
mean sd mc_error hpd_2.5 hpd_97.5 n_eff Rhat
a0 0.03 0.01 0.00 0.01 0.05 1680.58 1.0
mu 6.56 0.49 0.01 5.60 7.50 1671.80 1.0
shape 1.36 0.12 0.00 1.12 1.60 1671.94 1.0
scale 126.44 14.47 0.26 101.34 156.59 2817.21 1.0

推定したパラメータによる生存時間関数をグラフ表示する。

qlist = [2.5, 50, 97.5]
shapes = pm.quantiles(trace_1.get_values('shape'), qlist=qlist)
scales = pm.quantiles(trace_1.get_values('scale'), qlist=qlist)
# 生存時間関数表示関数
def plot_survival_function(ax, x, shape, scale, color='blue', qlist=[2.5,50,97.5]):
    low=qlist[0]
    mid=qlist[1]
    hig=qlist[2]
    ax.plot(x, 1-weibCumDist(x, shape[mid], scale[mid]), color=color)
    ax.fill_between(x=x, 
                 y1=1-weibCumDist(x, shape[low], scale[low]),
                 y2=1-weibCumDist(x, shape[hig], scale[hig]),
                 alpha=0.2, color=color)
    ax.set_xlabel('週')
    ax.set_ylabel('生存率')
    return ax
x = np.linspace(0,52,100)
fig = plt.figure(figsize=(9,4))
ax = fig.add_subplot(111)
ax = plot_survival_function(ax, x, shapes, scales, qlist=qlist)

f:id:gaiasky:20181111233138p:plain

カプランマイヤー法

  • ノンパラメトリックモデルであるカプランマイヤー法による生存率推定結果と比較する。
from lifelines import KaplanMeierFitter
kmf = KaplanMeierFitter()
kmf.fit(df.week.values, df.arrest.astype(np.bool), alpha=0.975)
x = np.linspace(0,52,100)
fig = plt.figure(figsize=(9,4))
ax = fig.add_subplot(111)
ax = plot_survival_function(ax, x, shapes, scales, qlist=qlist)
kmf.plot(ax=ax)

f:id:gaiasky:20181111233210p:plain

ベイズ推定はカプランマイヤー法よりも広めに推定されている。 元の記事の推定結果よりも広いため、うまく推定できていないのかも知れない。

比例ハザードモデル

  • 形状パラメータではなく、尺度パラメータに共変量の効果を入れる模様。
with pm.Model() as model_2:
    # 形状パラメータ
    alpha_sd = 10.0
    alpha_raw = pm.Normal('a0', mu=0, sd=0.1)
    shape = pm.Deterministic('shape', tt.exp(alpha_sd * alpha_raw))
    # 尺度パラメータ
    beta = pm.Normal('beta', mu=0, sd=10**2, shape=2)
    # 経済的支援ありの場合の尺度パラメータ
    scale1 = pm.Deterministic('scale1', tt.exp(-(beta[0]+beta[1])/shape))
    # 経済的支援なしの場合の尺度パラメータ
    scale2 = pm.Deterministic('scale2', tt.exp(-(beta[0])/shape))
                              
    y_obs = pm.Weibull('y_obs', alpha=shape, beta=tt.exp(-(beta[0]+df_fin_obs*beta[1])/shape), observed=df_obs)
    y_cens = pm.Potential('y_cens', weibull_lccdf(df_cens, shape, tt.exp(-(beta[0]+df_fin_cens*beta[1])/shape)))
with model_2:
    trace_2 = pm.sample(5000, tune=2000, 
                        nuts_kwargs={'target_accept': 0.9}, 
                        init='adapt_diag')
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [beta, a0]
Sampling 2 chains: 100%|██████████| 14000/14000 [01:30<00:00, 154.14draws/s]
pm.summary(trace_2).round(2)
mean sd mc_error hpd_2.5 hpd_97.5 n_eff Rhat
a0 0.03 0.01 0.00 0.01 0.05 2917.32 1.0
beta__0 -6.42 0.49 0.01 -7.38 -5.49 2859.43 1.0
beta__1 -0.37 0.19 0.00 -0.73 0.01 4608.48 1.0
shape 1.37 0.12 0.00 1.14 1.61 2877.35 1.0
scale1 146.66 21.97 0.33 108.67 190.92 4099.34 1.0
scale2 111.09 13.30 0.18 87.85 138.03 5590.24 1.0
pm.traceplot(trace_2)

f:id:gaiasky:20181111233422p:plain

shapes = pm.quantiles(trace_2.get_values('shape'), qlist=qlist)
scales1 = pm.quantiles(trace_2.get_values('scale1'), qlist=qlist)
scales2 = pm.quantiles(trace_2.get_values('scale2'), qlist=qlist)
fig = plt.figure(figsize=(9,4))
ax = fig.add_subplot(111)
plot_survival_function(ax, x, shapes, scales1)
plot_survival_function(ax, x, shapes, scales2, color='red')

f:id:gaiasky:20181111233458p:plain

Cox比例ハザードモデルとの比較

cph = lifelines.CoxPHFitter()
cph.fit(df[['week','arrest','fin']], duration_col='week', event_col='arrest', show_progress=False)
cph.print_summary()  # access the results using cph.summary
n=432, number of events=114

       coef  exp(coef)  se(coef)       z      p  lower 0.95  upper 0.95   
fin -0.3691     0.6914    0.1897 -1.9453 0.0517     -0.7409      0.0028  .
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 

Concordance = 0.546
Likelihood ratio test = 3.837 on 1 df, p=0.05013
fig = plt.figure(figsize=(9,4))
ax = fig.add_subplot(111)
plot_survival_function(ax, x, shapes, scales1)
plot_survival_function(ax, x, shapes, scales2, color='red')
cph.plot_covariate_groups('fin',[0,1], ax=ax)

f:id:gaiasky:20181111233536p:plain

こちらもやはり、ベイズ信頼区間の範囲が元の記事よりも広い。 元の記事(stan)では事前分布が明示されていないため、そこが違いを生んでいるのかも知れない。 とはいえ、pymcでの打ち切りデータを含む場合のweibull分布への当てはめ方法が分かった。

以上