Basic usage of sparsesurv

[20]:
import pandas as pd
from sparsesurv.utils import transform_survival
from sklearn.decomposition import PCA
from sparsesurv._base import KDSurv
from sparsesurv.cv import KDPHElasticNetCV, KDEHMultiTaskLassoCV, KDAFTElasticNetCV
from sklearn.pipeline import make_pipeline
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sparsesurv.aft import AFT
from sparsesurv.eh import EH
from sklearn.preprocessing import StandardScaler

df = pd.read_csv("https://zenodo.org/records/10027434/files/OV_data_preprocessed.csv?download=1")
X = df.iloc[:, 3:].to_numpy()
y = transform_survival(time=df.OS_days.values, event=df.OS.values)

X_train = X[:200]
X_test = X[200:]
y_train = y[:200]
y_test = y[200:]
[10]:
pipe = KDSurv(
            teacher=make_pipeline(
                StandardScaler(),
                PCA(n_components=16),
                CoxPHSurvivalAnalysis(ties="efron"),
            ),
            student=make_pipeline(
                StandardScaler(),
                KDPHElasticNetCV(
                    tie_correction="efron",
                    l1_ratio=0.9,
                    eps=0.01,
                    n_alphas=100,
                    cv=5,
                    stratify_cv=True,
                    seed=None,
                    shuffle_cv=False,
                    cv_score_method="linear_predictor",
                    n_jobs=1,
                    alpha_type="min",
                ),
            ),
        )

Above, we set up an example sparsesurv model. We now go through some of the parameters of the classes involved in fitting a model with sparsesurv.

KDSurv

KDSurv is essentially “just” a wrapper class that holds the teacher and the student and wraps their training and prediction in a convenient sklearn API. Its teacher and student parameters take the teacher model and the student model, respectively. Both teacher and student must be compatible with the sklearn API in order to work with the KDSurv class.
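
To make the wrapper idea concrete, here is a minimal, dependency-free sketch of a teacher/student wrapper. It is not the sparsesurv implementation: `ToyDistiller`, `DenseTeacher`, and `OneFeatureStudent` are hypothetical names, the target is a plain number rather than a survival array, and the sketch assumes the usual distillation pattern in which the student is fit to the teacher's predictions.

```python
from statistics import mean


def centered(values):
    """Return the values with their mean subtracted."""
    m = mean(values)
    return [v - m for v in values]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


class DenseTeacher:
    """Hypothetical dense model: one univariate least-squares slope per feature."""

    def fit(self, X, y):
        t = centered(y)
        columns = [centered(col) for col in zip(*X)]
        self.coef_ = [dot(col, t) / dot(col, col) for col in columns]
        return self

    def predict(self, X):
        return [dot(self.coef_, row) for row in X]


class OneFeatureStudent:
    """Hypothetical sparse model: keeps only the single most informative feature."""

    def fit(self, X, target):
        t = centered(target)
        columns = [centered(col) for col in zip(*X)]
        scores = [abs(dot(col, t)) for col in columns]
        best = scores.index(max(scores))
        self.coef_ = [0.0] * len(columns)
        denom = dot(columns[best], columns[best])
        if denom:
            self.coef_[best] = dot(columns[best], t) / denom
        return self

    def predict(self, X):
        return [dot(self.coef_, row) for row in X]


class ToyDistiller:
    """Hypothetical wrapper: fit the teacher, then fit the student to its output."""

    def __init__(self, teacher, student):
        self.teacher = teacher
        self.student = student

    def fit(self, X, y):
        self.teacher.fit(X, y)
        # Assumption: the student is trained on the teacher's predictions.
        self.student.fit(X, self.teacher.predict(X))
        return self

    def predict(self, X):
        # Predictions come from the sparse student only.
        return self.student.predict(X)


X = [[1, 0], [2, 1], [3, 0], [4, 1]]
y = [1.0, 2.0, 3.0, 4.0]
model = ToyDistiller(DenseTeacher(), OneFeatureStudent()).fit(X, y)
print(model.teacher.coef_, model.student.coef_)
```

The teacher uses both features, while the student reproduces the teacher's scores as well as it can with a single non-zero coefficient. This is the same division of labour as between the dense teacher pipeline and the sparse student pipeline above.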

Student classes

While there is not much more to say about the teacher, the student requires additional consideration, in particular because the student models are fully implemented in sparsesurv and cannot (easily) be replaced by external models.

There are three student types in sparsesurv, each corresponding to one of the three model types implemented:

1. sparsesurv.cv.KDPHElasticNetCV

2. sparsesurv.cv.KDAFTElasticNetCV

3. sparsesurv.cv.KDEHMultiTaskLassoCV

Because the Extended Hazards (EH) model has two linear predictors, its student is implemented as a multi-task Lasso rather than a Lasso or Elastic Net.

Please refer to our full API documentation for further details on the parameters of these classes. Below, we show how each of these models may be fit.

[23]:
pipe_cox_efron = KDSurv(
            teacher=make_pipeline(
                StandardScaler(),
                PCA(n_components=16),
                CoxPHSurvivalAnalysis(ties="efron"),
            ),
            student=make_pipeline(
                StandardScaler(),
                KDPHElasticNetCV(
                    tie_correction="efron",
                    l1_ratio=0.9,
                    eps=0.01,
                    n_alphas=100,
                    cv=5,
                    stratify_cv=True,
                    seed=None,
                    shuffle_cv=False,
                    cv_score_method="linear_predictor",
                    n_jobs=1,
                    alpha_type="min",
                ),
            ),
        )

pipe_cox_breslow = KDSurv(
            teacher=make_pipeline(
                StandardScaler(),
                PCA(n_components=16),
                CoxPHSurvivalAnalysis(ties="breslow"),
            ),
            student=make_pipeline(
                StandardScaler(),
                KDPHElasticNetCV(
                    tie_correction="breslow",
                    l1_ratio=0.9,
                    eps=0.01,
                    n_alphas=100,
                    cv=5,
                    stratify_cv=True,
                    seed=None,
                    shuffle_cv=False,
                    cv_score_method="linear_predictor",
                    n_jobs=1,
                    alpha_type="min",
                ),
            ),
        )

# Note: despite the "cox" in its name, this pipeline uses an AFT teacher and an AFT student.
pipe_cox_aft = KDSurv(
            teacher=make_pipeline(
                StandardScaler(),
                PCA(n_components=16),
                AFT()
            ),
            student=make_pipeline(
                StandardScaler(),
                KDAFTElasticNetCV(
                    l1_ratio=0.9,
                    eps=0.01,
                    n_alphas=100,
                    cv=5,
                    stratify_cv=True,
                    seed=None,
                    shuffle_cv=False,
                    cv_score_method="linear_predictor",
                    n_jobs=1,
                    alpha_type="min",
                ),
            ),
        )

# Note: despite the "cox" in its name, this pipeline uses an EH teacher and an EH student.
pipe_cox_eh = KDSurv(
            teacher=make_pipeline(
                StandardScaler(),
                PCA(n_components=16),
                EH()
            ),
            student=make_pipeline(
                StandardScaler(),
                KDEHMultiTaskLassoCV(
                    eps=0.01,
                    n_alphas=100,
                    cv=5,
                    stratify_cv=True,
                    seed=None,
                    shuffle_cv=False,
                    cv_score_method="linear_predictor",
                    n_jobs=1,
                    alpha_type="min",
                ),
            ),
        )
[24]:
pipe_cox_efron.fit(X_train, y_train)
pipe_cox_breslow.fit(X_train, y_train)
pipe_cox_aft.fit(X_train, y_train)
pipe_cox_eh.fit(X_train, y_train)
[26]:
import numpy as np
np.sum(pipe_cox_efron.student[1].coef_ != 0.0)
[26]:
176
[27]:
np.sum(pipe_cox_breslow.student[1].coef_ != 0.0)
[27]:
173
[28]:
np.sum(pipe_cox_aft.student[1].coef_ != 0.0)
[28]:
185
[31]:
int(np.sum(pipe_cox_eh.student[1].coef_ != 0.0) / 2)
[31]:
319
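
The division by two in the EH cell above can be illustrated with a small sketch. It assumes that the EH student stores one coefficient per feature for each of its two linear predictors, and that the multi-task Lasso penalty keeps or drops a feature for both predictors at once (the group-sparsity behaviour of a standard multi-task Lasso). The matrix below and its layout are hypothetical; plain Python lists are used to keep the sketch dependency-free.

```python
# Hypothetical coefficients: 5 features (rows) x 2 linear predictors (columns).
# Under a multi-task Lasso penalty, a row is typically either all zero or all non-zero.
coef = [
    [0.8, -0.3],
    [0.0, 0.0],
    [0.1, 0.4],
    [0.0, 0.0],
    [-0.5, 0.2],
]

# Count every non-zero entry, as np.sum(coef_ != 0.0) does.
n_nonzero_entries = sum(value != 0.0 for row in coef for value in row)

# Count features that are active in at least one linear predictor.
n_selected_features = sum(any(value != 0.0 for value in row) for row in coef)

print(n_nonzero_entries, n_selected_features)
```

With a shared sparsity pattern, the number of non-zero entries is twice the number of selected features, so halving the entry count gives the number of features the student retains.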
[35]:
from sksurv.metrics import concordance_index_censored
concordance_index_censored(y_test["event"], y_test["time"], pipe_cox_efron.predict(X_test))[0]
[35]:
0.5215973920130399
[36]:
concordance_index_censored(y_test["event"], y_test["time"], pipe_cox_breslow.predict(X_test))[0]
[36]:
0.530562347188264
[37]:
concordance_index_censored(y_test["event"], y_test["time"], pipe_cox_aft.predict(X_test))[0]
[37]:
0.5872045639771801

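We can see that both the sparsity and the discriminative performance of the distilled models can depend significantly on the teacher model (class). On this split, the Cox students retain 176 (Efron) and 173 (Breslow) non-zero coefficients, the AFT student retains 185, and the EH student selects 319 features. The test concordance is about 0.52 and 0.53 for the two Cox students and about 0.59 for the AFT student.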
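
For readers unfamiliar with the concordance index used above, the following is a simplified, dependency-free sketch of Harrell's concordance index. It is not the scikit-survival implementation: `harrell_c` is a hypothetical helper, the data are made up, and pairs with tied event times are simply ignored. The argument order mirrors the calls above (event indicator, time, predicted risk).

```python
def harrell_c(event, time, risk):
    """Simplified Harrell's C: fraction of comparable pairs ranked correctly.

    A pair (i, j) is comparable if subject i had an observed event and
    a strictly shorter time than subject j. It is concordant if the
    subject with the shorter time also has the higher predicted risk.
    Ties in predicted risk count as half-concordant.
    """
    concordant = 0.0
    comparable = 0
    n = len(time)
    for i in range(n):
        if not event[i]:
            continue  # a censored subject cannot be the earlier member of a pair
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Made-up example: four subjects, the third one is censored.
event = [True, True, False, True]
time = [2, 4, 5, 7]
risk = [0.9, 0.5, 0.7, 0.1]
c_index = harrell_c(event, time, risk)
print(c_index)  # 4 of 5 comparable pairs are concordant -> 0.8
```

A value of 0.5 corresponds to random ranking and 1.0 to perfect ranking, which puts the values of roughly 0.52 to 0.59 reported above into context.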