Survival analysis
sparsesurv [1] operates on survival analysis data. Below, we restate the notation from the supplementary section of our manuscript so that notation and terminology are consistent throughout.
In particular, survival analysis concerns the analysis and modeling of a positive random variable \(T > 0\) that is used to model the time until an event of interest occurs. In observational survival datasets, we let \(T_i\) and \(C_i\) denote the event and right-censoring times of patient \(i\). In right-censored survival analysis, we observe triplets \((x_i, \delta_i, O_i)\), where \(O_i = \min(T_i, C_i)\) and \(\delta_i = \mathbb{1}(T_i \leq C_i)\). Throughout, we assume conditionally independent and non-informative censoring. That is, \(T \perp\!\!\!\perp C \mid X\), and the distribution of \(C\) may not depend on any of the parameters of \(T\) (Kalbfleisch and Prentice, 2011). Further, let \(\lambda\) denote the hazard function, \(\Lambda\) the cumulative hazard function, and \(S(t) = 1 - F(t)\) the survival function, where \(F(t)\) denotes the cumulative distribution function. We let \(\tilde T\) be the set of unique, ascending-ordered death times. \(R(i)\) is the risk set at time \(O_i\), that is, \(R(i) = \{j : O_j \geq O_i\}\). \(D(i)\) denotes the death set at time \(O_i\), that is, \(D(i) = \{j : O_j = O_i \land \delta_j = 1\}\).
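To make the risk-set and death-set definitions concrete, the following small sketch (with made-up observed times and event indicators; the helper names are ours and not part of sparsesurv) computes \(R(i)\) and \(D(i)\) for a toy right-censored sample:

import numpy as np

# Toy right-censored data: observed times O_i and event indicators delta_i.
observed_times = np.array([5.0, 3.0, 8.0, 3.0, 6.0])
event_indicator = np.array([1, 0, 1, 1, 0])

def risk_set(observed_times, i):
    """Indices j with O_j >= O_i, i.e. patients still at risk at time O_i."""
    return np.where(observed_times >= observed_times[i])[0]

def death_set(observed_times, event_indicator, i):
    """Indices j with O_j == O_i and delta_j == 1, i.e. deaths at time O_i."""
    return np.where((observed_times == observed_times[i]) & (event_indicator == 1))[0]

print(risk_set(observed_times, 1))                     # [0 1 2 3 4]
print(death_set(observed_times, event_indicator, 1))   # [3]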
For now, sparsesurv operates solely on right-censored data, although we may consider an extension to other censoring and truncation schemes if there is interest. We now briefly show an example right-censored survival dataset available in scikit-survival [4], another Python package for survival analysis.
[1]:
from sksurv.datasets import load_flchain
X, y = load_flchain()
[2]:
X
[2]:
|      | age  | chapter     | creatinine | flc.grp | kappa | lambda | mgus | sample.yr | sex |
|------|------|-------------|------------|---------|-------|--------|------|-----------|-----|
| 0    | 97.0 | Circulatory | 1.7        | 10      | 5.700 | 4.860  | no   | 1997      | F   |
| 1    | 92.0 | Neoplasms   | 0.9        | 1       | 0.870 | 0.683  | no   | 2000      | F   |
| 2    | 94.0 | Circulatory | 1.4        | 10      | 4.360 | 3.850  | no   | 1997      | F   |
| 3    | 92.0 | Circulatory | 1.0        | 9       | 2.420 | 2.220  | no   | 1996      | F   |
| 4    | 93.0 | Circulatory | 1.1        | 6       | 1.320 | 1.690  | no   | 1996      | F   |
| ...  | ...  | ...         | ...        | ...     | ...   | ...    | ...  | ...       | ... |
| 7869 | 52.0 | NaN         | 1.0        | 6       | 1.210 | 1.610  | no   | 1995      | F   |
| 7870 | 52.0 | NaN         | 0.8        | 1       | 0.858 | 0.581  | no   | 1999      | F   |
| 7871 | 54.0 | NaN         | NaN        | 8       | 1.700 | 1.720  | no   | 2002      | F   |
| 7872 | 53.0 | NaN         | NaN        | 9       | 1.710 | 2.690  | no   | 1995      | F   |
| 7873 | 50.0 | NaN         | 0.7        | 4       | 1.190 | 1.250  | no   | 1998      | F   |
7874 rows × 9 columns
The design matrix \(X\) looks the same as it would in other modeling settings, such as regression or classification, and thus does not require any special treatment from a modeling point of view.
[3]:
y
[3]:
array([( True, 85.), ( True, 1281.), ( True, 69.), ...,
(False, 2507.), (False, 4982.), (False, 3995.)],
dtype=[('death', '?'), ('futime', '<f8')])
The target \(y\) looks unusual at first glance, however, as it contains two elements for each sample. These correspond exactly to \(\delta_i\) and \(O_i\) from our notation section above, i.e., the event indicator and the observed time, respectively. Right-censored survival data is generally represented in a structured array, as shown here, with one element for the event indicator and one for the observed time:
[4]:
y["death"]
[4]:
array([ True, True, True, ..., False, False, False])
[5]:
(y["death"]).astype(int)
[5]:
array([1, 1, 1, ..., 0, 0, 0])
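If your own data comes as plain arrays of event indicators and observed times, it can be packed into the same structured-array layout. Below is a minimal sketch with made-up values (the field names simply mirror those of the flchain target above):

import numpy as np

# Toy event indicators and observed times, packed into a structured array
# with the same layout as y above.
event = np.array([True, False, True])
time = np.array([12.0, 340.0, 58.0])
y_custom = np.array(list(zip(event, time)), dtype=[("death", "?"), ("futime", "<f8")])
# scikit-survival also provides sksurv.util.Surv.from_arrays(event, time) for this purpose.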
While the number of covariates relative to the number of samples is quite favorable here (only 9 variables for 7,874 samples), a very common setting in (cancer) survival analysis is one where the number of available covariates is much larger than the number of available samples (i.e., \(p \gg n\)). This is exactly the setting that sparsesurv is designed for. sparsesurv is based on knowledge distillation [5], which in statistics is also referred to as preconditioning [2] or the use of reference models [3]. Thus, we briefly introduce the idea of knowledge distillation next.
[6]:
X.shape
[6]:
(7874, 9)
Knowledge distillation
The original idea of knowledge distillation was not directly related to interpretability or feature selection. We note, however, that an idea akin to knowledge distillation was proposed and used in statistics before knowledge distillation itself, under the name of preconditioning. We will continue to use the term knowledge distillation since it may be more familiar to readers in the machine learning community.
The actual process of knowledge distillation proceeds in two steps:
1. Fit a teacher model that can approximate the target of interest (very) well
2. Fit a student model on the predictions of the teacher model, hoping that the teacher model acts as a kind of noise filter
We illustrate how this process can be adapted to survival analysis. Please note that this running example will continue to use scikit-survival, before we move on to high-dimensional data.
[7]:
from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sksurv.metrics import concordance_index_censored
# Use three numeric covariates (age, flc.grp, kappa) and a simple train/test split.
X_train = X.iloc[:5000, [0, 3, 4]]
X_test = X.iloc[5000:, [0, 3, 4]]
y_train = y[:5000]
y_test = y[5000:]
# Teacher: random survival forest; student: linear model fit on the teacher's
# risk predictions; baseline: Cox model fit directly on the survival data.
teacher_pipe = make_pipeline(StandardScaler(), RandomSurvivalForest())
student_pipe = make_pipeline(StandardScaler(), LinearRegression())
baseline_pipe = make_pipeline(StandardScaler(), CoxPHSurvivalAnalysis())
teacher_pipe.fit(X_train, y_train)
student_pipe.fit(X_train, teacher_pipe.predict(X_train))
baseline_pipe.fit(X_train, y_train)
[7]:
Pipeline(steps=[('standardscaler', StandardScaler()), ('coxphsurvivalanalysis', CoxPHSurvivalAnalysis())])
[8]:
concordance_index_censored(y_test["death"], y_test["futime"], teacher_pipe.predict(X_test))[0]
[8]:
0.5677982023489915
[9]:
concordance_index_censored(y_test["death"], y_test["futime"], student_pipe.predict(X_test))[0]
[9]:
0.6479734724443875
[10]:
concordance_index_censored(y_test["death"], y_test["futime"], baseline_pipe.predict(X_test))[0]
[10]:
0.6444986225266496
Interestingly, in this example, the student performance (as measured by Harrell's concordance) was slightly higher than that of the baseline, despite the teacher performing quite badly. There is ongoing research in the ML community on why and when distilled students can outperform their teachers. For us, however, all that matters is that knowledge distillation works for survival analysis.
Minimal example of sparsesurv
Lastly, we give a brief example of how to use sparsesurv. If you are interested in applying sparsesurv to your own data, please consult the documentation or the more specific user guides linked above.
[11]:
import pandas as pd
from sparsesurv.utils import transform_survival
df = pd.read_csv("https://zenodo.org/records/10027434/files/OV_data_preprocessed.csv?download=1")
[14]:
X = df.iloc[:, 3:].to_numpy()
[15]:
y = transform_survival(time=df.OS_days.values, event=df.OS.values)
[16]:
X.shape
[16]:
(302, 19076)
[17]:
y[: 5]
[17]:
array([( True, 304.), ( True, 24.), (False, 576.), (False, 1207.),
( True, 676.)], dtype=[('event', '?'), ('time', '<f8')])
[18]:
from sklearn.decomposition import PCA
from sparsesurv._base import KDSurv
from sparsesurv.cv import KDPHElasticNetCV
from sparsesurv.utils import transform_survival
pipe = KDSurv(
    teacher=make_pipeline(
        StandardScaler(),
        PCA(n_components=16),
        CoxPHSurvivalAnalysis(ties="efron"),
    ),
    student=make_pipeline(
        StandardScaler(),
        KDPHElasticNetCV(
            tie_correction="efron",
            l1_ratio=0.9,
            eps=0.01,
            n_alphas=100,
            cv=5,
            stratify_cv=True,
            seed=None,
            shuffle_cv=False,
            cv_score_method="linear_predictor",
            n_jobs=1,
            alpha_type="min",
        ),
    ),
)
[19]:
pipe.fit(X, y)
Now, we can easily check how many non-zero coefficients the fitted model has, or get predictions on the training set.
[25]:
import numpy as np
np.sum(pipe.student[1].coef_ != 0.0)
[25]:
260
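To see which features were actually selected, the non-zero coefficients can be mapped back to the column names of the original dataframe. Below is a minimal sketch, assuming (as above) that X was taken from df.iloc[:, 3:] and that the student's coef_ contains one entry per column of X:

# Column names corresponding to X = df.iloc[:, 3:].
feature_names = df.columns[3:]
mask = np.ravel(pipe.student[1].coef_) != 0.0
selected_features = feature_names[mask]
print(selected_features[:10])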
[26]:
pipe.predict(X)
[26]:
array([ 2.82819899e-01, 4.66079463e-01, 1.15606183e-01, 5.02952325e-01,
4.17393777e-02, -4.11573164e-01, 7.60555538e-01, 4.34577104e-01,
1.14479065e-01, 3.75259090e-02, -1.34671115e-02, 1.61175017e-01,
1.00036148e-01, 2.40373951e-01, -3.25340021e-01, -8.67875954e-01,
7.96211926e-01, -1.83116380e-01, 1.07281533e-02, -4.77575349e-02,
-1.43752928e-01, 4.76747822e-01, -1.28454611e-02, -2.77780752e-01,
1.18303704e-01, 7.13485684e-01, -9.06133817e-02, -2.28805873e-01,
-1.85240257e-01, -1.88895164e-01, 1.79877864e-02, -4.40709840e-01,
-3.33191668e-02, -3.47169497e-01, -5.41839603e-02, -7.00115696e-01,
-4.83739189e-01, 1.49624486e-01, 2.72062364e-01, 7.45285449e-01,
-7.67869829e-02, -1.82008113e-01, -6.69662687e-02, 8.93221182e-02,
2.10881649e-01, -5.40481656e-01, -6.76554743e-02, 4.59392323e-02,
-1.14925641e-01, 1.08297960e-01, -1.48445361e-01, 3.92129238e-01,
3.46609042e-04, 1.68584073e-01, 1.64706486e-01, 2.00876547e-01,
-4.59086265e-01, -8.81797220e-02, -3.05503196e-01, -1.06486223e+00,
-8.44201513e-01, 3.24547683e-01, -1.86346579e-01, -1.51631340e-01,
1.79803120e-01, 1.06390383e-01, -6.11439908e-01, 6.71797157e-02,
-6.46251752e-01, 3.26988764e-01, 1.07586534e-01, 3.68125671e-02,
-3.73012562e-01, 5.57874420e-02, 3.79215839e-01, -2.26237176e-01,
-1.29248351e-02, 1.22046107e-01, 3.74760349e-01, 5.23166206e-01,
3.03795116e-01, -7.36812630e-01, 3.94627049e-01, 4.25870853e-02,
2.98564610e-02, 2.95674650e-01, -2.19203239e-01, -5.11737562e-01,
-5.24482896e-02, -4.64111839e-03, -3.21799909e-01, -7.14669688e-01,
5.84417412e-01, -2.90476660e-01, 1.78423635e-01, -2.39300019e-01,
1.80786722e-01, -3.34470179e-02, -2.56003735e-01, -5.37086393e-02,
-4.43376994e-01, 3.19931551e-01, 1.97498674e-01, -2.61924192e-02,
1.33033208e-01, -6.95678071e-02, -7.24221455e-02, 5.40122505e-01,
2.65075056e-01, -9.89361408e-01, 6.19007678e-02, 2.07408942e-01,
6.82609874e-02, -6.37693151e-01, -3.20701807e-01, -6.57601954e-01,
1.24781707e-01, 1.89385766e-01, 2.85797247e-01, -5.92343166e-02,
3.80218635e-01, 5.44106400e-01, 6.82875840e-01, -3.90579509e-01,
-8.42389835e-02, 7.58434789e-01, -8.61379695e-02, 7.01730243e-01,
-1.48692814e-01, 7.07359015e-02, -6.76859547e-02, -4.46045375e-02,
4.81015185e-01, -1.98355534e-01, -1.59107820e-01, -3.26734493e-01,
4.56446108e-01, -4.22370601e-01, 8.02473240e-01, 1.50128880e-01,
6.83395951e-01, 2.06511496e-01, 2.67441747e-01, 8.38858830e-02,
3.20384092e-01, 6.08116886e-01, 3.70467301e-01, -1.47024656e-02,
2.73821126e-01, -2.22213948e-01, 3.45943407e-01, 2.92928823e-01,
-5.43679120e-01, 1.20502523e-01, 5.61094405e-01, -9.07648816e-02,
-9.08701304e-02, 5.10690412e-01, -1.53761912e-01, 4.23909767e-01,
-4.37683251e-01, 3.16901267e-01, 3.95289983e-01, -2.98683737e-01,
2.21080367e-01, 8.72769946e-02, 1.29061883e-01, 1.20706128e-01,
-1.15802828e-01, -7.05581525e-02, -3.29695109e-01, 2.26276519e-01,
8.84738248e-01, 1.01021425e-01, -1.40474023e-01, -2.09845104e-01,
-7.54611778e-02, -5.61162007e-02, -1.97142810e-01, 7.53374053e-02,
5.01884547e-01, -3.95690048e-01, -1.22748153e-01, 4.00627789e-01,
-2.17081618e-01, -2.00329194e-01, -1.83593487e-01, -1.00495352e-01,
-1.27141121e-01, 4.24449987e-02, -9.90337386e-03, 7.94286450e-02,
-5.22270616e-01, 1.33174475e-01, 9.51870865e-02, 5.27839777e-01,
5.82771390e-01, 3.33622667e-01, -5.25593641e-01, -4.71537211e-01,
-4.27592018e-01, 1.73267690e-01, 3.51133339e-01, -1.63715494e-01,
2.90653134e-01, -2.67936877e-01, 1.74063236e-01, 1.31693212e-01,
2.69151375e-01, -1.73910919e-01, 3.56312451e-01, -2.95944572e-01,
-6.76934153e-01, 1.15260613e-01, 7.46456877e-01, -3.96934560e-01,
6.12222822e-01, -2.83372799e-02, -3.51497255e-01, 6.33327038e-01,
2.71835304e-01, 2.94155228e-01, 5.16169771e-02, 1.08403732e-02,
6.67114061e-01, -1.82332792e-01, -1.21011324e-01, -4.54893674e-01,
-7.96243813e-01, 3.81974552e-01, -1.07648224e-01, -1.71062783e-02,
2.23721369e-01, 1.36297632e-01, 3.11243091e-01, -7.24278109e-03,
6.13225615e-02, -3.55728676e-01, 1.71853952e-01, 8.08492454e-01,
1.08213626e-01, 5.32913239e-02, -9.11289836e-02, -7.65856673e-02,
1.58455383e-01, -5.23370219e-01, -3.94924071e-01, -1.09364829e-01,
-3.78914862e-01, -1.89939895e-01, -1.36739936e-01, -2.63573754e-01,
1.01121780e-01, 7.21476287e-02, -9.03630706e-02, 3.04653321e-01,
1.10769228e-01, -6.11767145e-01, 2.23789775e-01, 1.68579657e-01,
5.42411777e-01, 4.75852309e-01, 5.58213127e-01, -4.18131792e-02,
-4.55969471e-01, -5.50211628e-01, -6.02525807e-01, -1.94368141e-01,
-3.66442961e-01, -1.26307742e-01, 1.02601869e-01, -5.58109635e-01,
-2.63724315e-01, -4.75675963e-01, 2.74982443e-01, -4.08772709e-01,
-7.00331777e-02, -4.36649556e-01, -6.36961440e-02, -3.30293833e-01,
-2.75192094e-01, -3.93795722e-01, -3.22032629e-01, 4.70539660e-01,
2.46172417e-01, -1.01038986e-01, -3.22538044e-01, -4.07975457e-01,
7.72229576e-02, -2.91750927e-01, -2.78907089e-01, -7.91286409e-02,
-1.04626150e-01, 3.41475254e-01, -1.85888436e-01, -2.36925667e-01,
-1.33571157e-01, -1.26144809e-01, -4.56658067e-01, 4.44035026e-01,
1.60457468e-02, -2.75158258e-01])
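As a quick sanity check, the same concordance metric used earlier can be applied to these predictions. Note that this is computed on the training data and is therefore optimistic; it is only meant to illustrate how the predictions plug into scikit-survival's metrics (the structured array produced by transform_survival uses the field names event and time, as shown above):

# Harrell's concordance on the training set (optimistic, illustration only).
concordance_index_censored(y["event"], y["time"], pipe.predict(X))[0]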
References
[1] David Wissel, Nikita Janakarajan, Daniel Rowson, Julius Schulte, Xintian Yuan, Valentina Boeva. “sparsesurv: Sparse survival models via knowledge distillation.” (2023, under review).
[2] Paul, Debashis, et al. "'Preconditioning' for feature selection and regression in high-dimensional problems." The Annals of Statistics 36.4 (2008): 1595-1618.
[3] Pavone, Federico, et al. “Using reference models in variable selection.” Computational Statistics 38.1 (2023): 349-371.
[4] Pölsterl, Sebastian. “scikit-survival: A Library for Time-to-Event Analysis Built on Top of scikit-learn.” The Journal of Machine Learning Research 21.1 (2020): 8747-8752.
[5] Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. “Distilling the knowledge in a neural network.” arXiv preprint arXiv:1503.02531 (2015).