{ "cells": [ { "cell_type": "markdown", "id": "1af32a7e", "metadata": {}, "source": [ "# Basic usage of `sparsesurv`" ] }, { "cell_type": "code", "execution_count": 20, "id": "7d812823", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from sparsesurv.utils import transform_survival\n", "from sklearn.decomposition import PCA\n", "from sparsesurv._base import KDSurv\n", "from sparsesurv.cv import KDPHElasticNetCV, KDEHMultiTaskLassoCV, KDAFTElasticNetCV\n", "from sklearn.pipeline import make_pipeline\n", "from sksurv.linear_model import CoxPHSurvivalAnalysis\n", "from sparsesurv.aft import AFT\n", "from sparsesurv.eh import EH\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "# Load the preprocessed data set (downloaded from Zenodo)\n", "df = pd.read_csv(\"https://zenodo.org/records/10027434/files/OV_data_preprocessed.csv?download=1\")\n", "# Features start at the fourth column; the outcome is built from OS_days (time) and OS (event)\n", "X = df.iloc[:, 3:].to_numpy()\n", "y = transform_survival(time=df.OS_days.values, event=df.OS.values)\n", "\n", "# Simple hold-out split: the first 200 samples for training, the rest for testing\n", "X_train = X[:200]\n", "X_test = X[200:]\n", "y_train = y[:200]\n", "y_test = y[200:]" ] }, { "cell_type": "code", "execution_count": 10, "id": "8dc8f485", "metadata": {}, "outputs": [], "source": [ "pipe = KDSurv(\n", " teacher=make_pipeline(\n", " StandardScaler(),\n", " PCA(n_components=16),\n", " CoxPHSurvivalAnalysis(ties=\"efron\"),\n", " ),\n", " student=make_pipeline(\n", " StandardScaler(),\n", " KDPHElasticNetCV(\n", " tie_correction=\"efron\",\n", " l1_ratio=0.9,\n", " eps=0.01,\n", " n_alphas=100,\n", " cv=5,\n", " stratify_cv=True,\n", " seed=None,\n", " shuffle_cv=False,\n", " cv_score_method=\"linear_predictor\",\n", " n_jobs=1,\n", " alpha_type=\"min\",\n", " ),\n", " ),\n", " )" ] }, { "cell_type": "markdown", "id": "f5d807a8", "metadata": {}, "source": [ "The cell above sets up an example `KDSurv` object. We will now go through some of the parameters of the classes involved in fitting a model with *sparsesurv*." 
] }, { "cell_type": "markdown", "id": "832c239f", "metadata": {}, "source": [ "## `KDSurv`" ] }, { "cell_type": "markdown", "id": "a3761950", "metadata": {}, "source": [ "`KDSurv` is essentially \"just\" a wrapper class that holds the teacher and the student and wraps their training and prediction in a convenient `sklearn`-style API. The `teacher` and `student` parameters take the teacher and student estimators, respectively. Both teacher and student must be `sklearn` API compatible in order to work with the `KDSurv` class." ] }, { "cell_type": "markdown", "id": "04d19012", "metadata": {}, "source": [ "## Student classes" ] }, { "cell_type": "markdown", "id": "9223e8e5", "metadata": {}, "source": [ "While there is not much more to say about the teacher, the students require additional consideration, in particular since they are fully implemented in `sparsesurv` and cannot (easily) be replaced by external models." ] }, { "cell_type": "markdown", "id": "be0f4aa2", "metadata": {}, "source": [ "There are three student types in `sparsesurv`, each corresponding to one of the three model types implemented:\n", "\n", " 1. `sparsesurv.cv.KDPHElasticNetCV`\n", " \n", " 2. `sparsesurv.cv.KDAFTElasticNetCV`\n", " \n", " 3. `sparsesurv.cv.KDEHMultiTaskLassoCV`" ] }, { "cell_type": "markdown", "id": "e27bbabe", "metadata": {}, "source": [ "Since the Extended Hazards (EH) model has two linear predictors, its student is implemented via a multi-task Lasso instead of a Lasso or Elastic Net." ] }, { "cell_type": "markdown", "id": "910ed5d0", "metadata": {}, "source": [ "Please refer to our full API documentation for further details on the parameters of these classes. Below, we show how each of these models may be fit." 
] }, { "cell_type": "code", "execution_count": 23, "id": "b3ad2764", "metadata": {}, "outputs": [], "source": [ "pipe_cox_efron = KDSurv(\n", " teacher=make_pipeline(\n", " StandardScaler(),\n", " PCA(n_components=16),\n", " CoxPHSurvivalAnalysis(ties=\"efron\"),\n", " ),\n", " student=make_pipeline(\n", " StandardScaler(),\n", " KDPHElasticNetCV(\n", " tie_correction=\"efron\",\n", " l1_ratio=0.9,\n", " eps=0.01,\n", " n_alphas=100,\n", " cv=5,\n", " stratify_cv=True,\n", " seed=None,\n", " shuffle_cv=False,\n", " cv_score_method=\"linear_predictor\",\n", " n_jobs=1,\n", " alpha_type=\"min\",\n", " ),\n", " ),\n", " )\n", "\n", "pipe_cox_breslow = KDSurv(\n", " teacher=make_pipeline(\n", " StandardScaler(),\n", " PCA(n_components=16),\n", " CoxPHSurvivalAnalysis(ties=\"breslow\"),\n", " ),\n", " student=make_pipeline(\n", " StandardScaler(),\n", " KDPHElasticNetCV(\n", " tie_correction=\"breslow\",\n", " l1_ratio=0.9,\n", " eps=0.01,\n", " n_alphas=100,\n", " cv=5,\n", " stratify_cv=True,\n", " seed=None,\n", " shuffle_cv=False,\n", " cv_score_method=\"linear_predictor\",\n", " n_jobs=1,\n", " alpha_type=\"min\",\n", " ),\n", " ),\n", " )\n", "\n", "pipe_cox_aft = KDSurv(\n", " teacher=make_pipeline(\n", " StandardScaler(),\n", " PCA(n_components=16),\n", " AFT()\n", " ),\n", " student=make_pipeline(\n", " StandardScaler(),\n", " KDAFTElasticNetCV(\n", " l1_ratio=0.9,\n", " eps=0.01,\n", " n_alphas=100,\n", " cv=5,\n", " stratify_cv=True,\n", " seed=None,\n", " shuffle_cv=False,\n", " cv_score_method=\"linear_predictor\",\n", " n_jobs=1,\n", " alpha_type=\"min\",\n", " ),\n", " ),\n", " )\n", "\n", "pipe_cox_eh = KDSurv(\n", " teacher=make_pipeline(\n", " StandardScaler(),\n", " PCA(n_components=16),\n", " EH()\n", " ),\n", " student=make_pipeline(\n", " StandardScaler(),\n", " KDEHMultiTaskLassoCV(\n", " eps=0.01,\n", " n_alphas=100,\n", " cv=5,\n", " stratify_cv=True,\n", " seed=None,\n", " shuffle_cv=False,\n", " cv_score_method=\"linear_predictor\",\n", " 
n_jobs=1,\n", " alpha_type=\"min\",\n", " ),\n", " ),\n", " )" ] }, { "cell_type": "code", "execution_count": 24, "id": "1b44d1e4", "metadata": {}, "outputs": [], "source": [ "pipe_cox_efron.fit(X_train, y_train)\n", "pipe_cox_breslow.fit(X_train, y_train)\n", "pipe_cox_aft.fit(X_train, y_train)\n", "pipe_cox_eh.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 26, "id": "c469f8f3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "176" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "np.sum(pipe_cox_efron.student[1].coef_ != 0.0)" ] }, { "cell_type": "code", "execution_count": 27, "id": "c90095a2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "173" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.sum(pipe_cox_breslow.student[1].coef_ != 0.0)" ] }, { "cell_type": "code", "execution_count": 28, "id": "a1aaddfb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "185" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.sum(pipe_cox_aft.student[1].coef_ != 0.0)" ] }, { "cell_type": "code", "execution_count": 31, "id": "a6a8aafa", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "319" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "int(np.sum(pipe_cox_eh.student[1].coef_ != 0.0) / 2)" ] }, { "cell_type": "code", "execution_count": 35, "id": "da39e41b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.5215973920130399" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sksurv.metrics import concordance_index_censored\n", "concordance_index_censored(y_test[\"event\"], y_test[\"time\"], pipe_cox_efron.predict(X_test))[0]" ] }, { "cell_type": "code", "execution_count": 36, "id": "06238c0d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"0.530562347188264" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "concordance_index_censored(y_test[\"event\"], y_test[\"time\"], pipe_cox_breslow.predict(X_test))[0]" ] }, { "cell_type": "code", "execution_count": 37, "id": "92eb2aea", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.5872045639771801" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "concordance_index_censored(y_test[\"event\"], y_test[\"time\"], pipe_cox_aft.predict(X_test))[0]" ] }, { "cell_type": "markdown", "id": "d0fc1063", "metadata": {}, "source": [ "On this single train/test split, both the sparsity and the discriminative performance of the distilled models depend noticeably on the teacher model (class): the non-zero coefficient counts above differ between students (for example, 173 for Cox with Breslow ties versus 185 for AFT), and the test concordance ranges from roughly 0.52 (Cox with Efron ties) to roughly 0.59 (AFT)." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 5 }