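Standardized survival
See section 17.5 ("Parametric g-formula") of the book Causal Inference.
In parametric standardization, also known as the "parametric g-formula", the survival at time step k is a weighted average of the survival conditional on levels of the covariates X and the treatment assignment a, with weights equal to the proportion of individuals in each stratum. In other words, similar to standardization with a simple outcome model (S-Learner) in the non-survival setting, we fit a hazard model that includes the baseline covariates, and then use this hazard model to compute the survival curves.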
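To make the weighted average concrete, here is a minimal numpy sketch of the discrete-time parametric g-formula, S_a(k) = (1/n) Σ_i Π_{j≤k} (1 − h(j | a, x_i)). Each individual's predicted hazards, with treatment set to a for everyone, are turned into a survival curve through the cumulative product of (1 − hazard), and the individual curves are averaged over the sample. Giving every individual equal weight is the same as weighting each covariate stratum by its share of individuals. The function toy_hazard and its coefficients are hypothetical stand-ins for a fitted hazard model; this is an illustration of the idea, not causallib's implementation.

```python
import numpy as np


def standardized_survival_curve(hazard_fn, X, a_value, times):
    """Discrete-time parametric g-formula (minimal sketch).

    S_a(k) = (1/n) * sum_i prod_{j <= k} (1 - h(j | a, x_i))

    hazard_fn(X, a, t) must return, for every row of X, the probability of an
    event at time t given survival up to t (the discrete-time hazard).
    """
    n = X.shape[0]
    # hazards[i, j]: hazard of individual i at times[j], treatment set to a_value for everyone
    hazards = np.column_stack(
        [hazard_fn(X, np.full(n, a_value), np.full(n, t)) for t in times]
    )
    # Individual survival curves: cumulative product of (1 - hazard) over time
    individual_survival = np.cumprod(1.0 - hazards, axis=1)
    # Standardize: average over the empirical covariate distribution of the sample
    return individual_survival.mean(axis=0)


def toy_hazard(X, a, t):
    """Hypothetical logistic hazard; coefficients are made up for illustration."""
    logit = -3.0 + 0.5 * X[:, 0] - 0.7 * a + 0.05 * t
    return 1.0 / (1.0 + np.exp(-logit))


rng = np.random.default_rng(0)
X_toy = rng.normal(size=(500, 1))  # hypothetical covariates, not the notebook's X
times_toy = np.arange(1, 11)
curves = {
    a_val: standardized_survival_curve(toy_hazard, X_toy, a_val, times_toy)
    for a_val in (0, 1)
}
```

Because the hypothetical treatment coefficient is negative, curves[1] lies above curves[0] at every time step.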
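from causallib.survival.standardized_survival import StandardizedSurvival
from sklearn.linear_model import LogisticRegression

# Logistic regression as the (discrete-time) hazard model
standardized_survival = StandardizedSurvival(survival_model=LogisticRegression(max_iter=4000))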
standardized_survival.fit(X, a, t, y)
population_averaged_survival_curves = standardized_survival.estimate_population_outcome(X, a, t)
plot_survival_curves(
population_averaged_survival_curves,
labels=["non-quitters", "quitters"],
title="Standardized survival of smoke quitters vs. non-quitters in a 10 years observation period",
)
Alternatively, we can use a RegressionFitter class from the lifelines package, such as the Cox proportional hazards fitter:
import lifelines

# Use lifelines Cox Proportional Hazards Fitter as a survival model for standardization
standardized_survival_cox = StandardizedSurvival(survival_model=lifelines.CoxPHFitter())
standardized_survival_cox.fit(X, a, t, y)
population_averaged_survival_curves = standardized_survival_cox.estimate_population_outcome(X, a, t)
plot_survival_curves(
population_averaged_survival_curves,
labels=["non-quitters", "quitters"],
title="Standardized survival of smoke quitters vs. non-quitters in a 10 years observation period (Cox PH)",
)
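With a Cox model, each individual's survival curve follows the proportional-hazards form S(t | a, x) = S0(t) ** exp(β_a·a + x·β), and standardization averages these curves over the sample with the treatment fixed to a. The following minimal numpy sketch shows only that averaging step; the baseline survival S0_toy and the coefficients are hypothetical, S0 is assumed to be the baseline at covariate value zero, and this is not how lifelines or causallib compute it internally.

```python
import numpy as np


def cox_standardized_survival(baseline_survival, beta_x, beta_a, X, a_value):
    """Standardized survival under a Cox PH model (minimal sketch).

    Uses S(t | a, x) = S0(t) ** exp(beta_a * a + x @ beta_x) and averages over
    the sample's covariates with treatment set to a_value for everyone.
    Assumes baseline_survival is S0 evaluated at covariate value zero.
    """
    # Linear predictor (log partial hazard) per individual, treatment forced to a_value
    log_partial_hazard = X @ beta_x + beta_a * a_value
    # Individual curves: rows = individuals, columns = time points of S0
    individual = baseline_survival[np.newaxis, :] ** np.exp(log_partial_hazard)[:, np.newaxis]
    return individual.mean(axis=0)


# Hypothetical inputs for illustration only
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(200, 1))
S0_toy = np.exp(-0.02 * np.arange(1, 11))  # exponential baseline survival
curves_cox = {
    a_val: cox_standardized_survival(S0_toy, np.array([0.5]), -0.4, X_toy, a_val)
    for a_val in (0, 1)
}
```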
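Since standardization models the hazard at each time point conditional on time and covariates, it is important that this model be well specified. An overly simple model, such as one where time enters only as a linear term, can produce "stiff", over-simplified survival curves. Here we add extra time features through a custom scikit-learn transformer to obtain smoother curves. Compare the result with the first plot in the Standardized survival section (cell 13).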
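Modeling "the hazard at each time point" with an ordinary classifier usually means expanding the data to a person-time (long) layout: one row per subject per time step at risk, with a label that is 1 only in the step where the event occurs. The sketch below assumes that layout, as is usual for pooled logistic regression, and assumes integer follow-up durations of at least one step; the helper names to_person_time and time_features are hypothetical, with time_features mirroring the extra terms (t², t³, √t) that the TimeTransform below adds.

```python
import numpy as np


def to_person_time(X, durations, events):
    """Expand one row per subject into one row per subject per time step at risk.

    Assumes integer follow-up durations >= 1 measured in discrete steps.
    Returns (X_long, t_long, y_long); y_long is 1 only in the step in which the
    subject's event occurs, so censored subjects contribute only zeros.
    """
    durations = np.asarray(durations, dtype=int)
    events = np.asarray(events, dtype=int)
    # Subject index repeated once per time step the subject is at risk
    rows = np.repeat(np.arange(len(durations)), durations)
    # Time index 1..d_i within each subject
    t_long = np.concatenate([np.arange(1, d + 1) for d in durations])
    # Event indicator: last at-risk step of subjects who had the event
    y_long = ((t_long == durations[rows]) & (events[rows] == 1)).astype(int)
    return X[rows], t_long, y_long


def time_features(t):
    """Time plus non-linear time terms: t, t^2, t^3, sqrt(t)."""
    t = np.asarray(t, dtype=float)
    return np.column_stack([t, t ** 2, t ** 3, np.sqrt(t)])


# Two hypothetical subjects: one has the event at step 2, one is censored at step 3
X_toy = np.array([[0.0], [1.0]])
X_long, t_long, y_long = to_person_time(X_toy, durations=[2, 3], events=[1, 0])
design = np.column_stack([X_long, time_features(t_long)])
```

A classifier fit on (design, y_long) then estimates the discrete-time hazard P(event at t | survived to t, x), and the extra time columns let a linear model bend that hazard over time.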
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline


class TimeTransform(BaseEstimator, TransformerMixin):
    """
    Simple transformer that adds non-linear transformations of the time column
    """

    def __init__(self, time_col_name):
        super().__init__()
        self.time_col_name = time_col_name

    def fit(self, X, y=None):
        # Stateless transformer: nothing to learn
        return self

    def transform(self, X, y=None):
        X_ = X.copy()
        # Non-linear functions of time let a linear model fit a curved hazard over time
        X_[self.time_col_name + "^2"] = X_[self.time_col_name] ** 2
        X_[self.time_col_name + "^3"] = X_[self.time_col_name] ** 3
        X_[self.time_col_name + "_sqrt"] = np.sqrt(X_[self.time_col_name])
        return X_
time_transform_pipeline = Pipeline(
[("transform", TimeTransform(time_col_name=t.name)), ("LR", LogisticRegression(max_iter=2000))]
)
standardized_survival = StandardizedSurvival(survival_model=time_transform_pipeline)
standardized_survival.fit(X, a, t, y)
population_averaged_survival_curves = standardized_survival.estimate_population_outcome(X, a, t)
plot_survival_curves(
population_averaged_survival_curves,
labels=["non-quitters", "quitters"],
title="Standardized survival of smoke quitters vs. non-quitters in a 10 years observation period",
)
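A short numerical check of the shape constraint behind "stiff" curves: in a logistic hazard model without time interactions, if time enters the logit only linearly, the hazard can only rise (or only fall) over the whole follow-up, whereas adding a non-linear time term, here just t², lets it rise and then fall. The time grid and coefficients below are hypothetical and only illustrate the shape, not the fitted NHEFS model.

```python
import numpy as np


def logistic(z):
    return 1.0 / (1.0 + np.exp(-z))


# Hypothetical grid of 120 discrete time steps (not the notebook's t)
t_grid = np.arange(1, 121, dtype=float)

# Time enters the logit linearly: the hazard is monotone in time
hazard_linear = logistic(-4.0 + 0.01 * t_grid)

# Adding a squared time term: the logit peaks at t = 0.04 / (2 * 0.0004) = 50
hazard_expanded = logistic(-4.0 + 0.04 * t_grid - 0.0004 * t_grid ** 2)
```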