{ "cells": [ { "cell_type": "markdown", "id": "b3cb023c", "metadata": {}, "source": [ "## Survival analysis" ] }, { "cell_type": "markdown", "id": "aedae0d7", "metadata": {}, "source": [ "*sparsesurv* [1] operates on survival analysis data. Below, we quote the notation from the supplementary section of our manuscript to ensure we are on the same page in terms of notation and language.\n", "\n", "> In particular, survival analysis concerns the analysis and modeling of a non-negative random variable $T > 0$ that is used to model the time until an event of interest occurs. In observational survival datasets, we let $T_i$ and $C_i$ denote the event and right-censoring times of patient $i$. In right-censored survival analysis, we observe triplets $(x_i, \\delta_i, O_i)$, where $O_i = \\min(T_i, C_i)$ and $\\delta_i = \\mathbb{1}(T_i \\leq C_i)$. Throughout, we assume conditionally independent censoring and non-informative censoring. That is, $T \\perp\\!\\!\\!\\!\\perp C \\mid X$ and $C$ may not be a function of any of the parameters of $T$ (Kalbfleisch and Prentice, 2011). Further, let $\\lambda$ denote the hazard function, $\\Lambda$ be the cumulative hazard function, and $S(t) = 1 - F(t)$ be the survival function, where $F(t)$ denotes the cumulative distribution function. We let $\\tilde T$ be the set of unique, ascending-ordered death times. $R_i$ is the risk set at time $O_i$, that is, $R_i = \\{j: O_j \\geq O_i\\}$. $D_i$ denotes the death set at time $O_i$, that is, $D_i = \\{j: O_j = O_i \\land \\delta_j = 1\\}$.\n", "\n", "For now, *sparsesurv* operates solely on right-censored data, although we may consider an extension to other censoring and truncation schemes if there is interest. We now briefly show an example right-censored survival dataset available in *scikit-survival* [4], another Python package for survival analysis." 
] }, { "cell_type": "code", "execution_count": 1, "id": "abecf90f", "metadata": {}, "outputs": [], "source": [ "from sksurv.datasets import load_flchain\n", "X, y = load_flchain()" ] }, { "cell_type": "code", "execution_count": 2, "id": "e92dd399", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age chapter creatinine flc.grp kappa lambda mgus sample.yr sex\n", "0 97.0 Circulatory 1.7 10 5.700 4.860 no 1997 F\n", "1 92.0 Neoplasms 0.9 1 0.870 0.683 no 2000 F\n", "2 94.0 Circulatory 1.4 10 4.360 3.850 no 1997 F\n", "3 92.0 Circulatory 1.0 9 2.420 2.220 no 1996 F\n", "4 93.0 Circulatory 1.1 6 1.320 1.690 no 1996 F\n", "... ... ... ... ... ... ... ... ... ..\n", "7869 52.0 NaN 1.0 6 1.210 1.610 no 1995 F\n", "7870 52.0 NaN 0.8 1 0.858 0.581 no 1999 F\n", "7871 54.0 NaN NaN 8 1.700 1.720 no 2002 F\n", "7872 53.0 NaN NaN 9 1.710 2.690 no 1995 F\n", "7873 50.0 NaN 0.7 4 1.190 1.250 no 1998 F\n", "\n", "[7874 rows x 9 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X" ] }, { "cell_type": "markdown", "id": "4fb9c403", "metadata": {}, "source": [ "The design matrix $X$ looks the same as it would in other modeling settings, such as regression or classification, and thus does not require any special treatment from a modeling point of view." ] }, { "cell_type": "code", "execution_count": 3, "id": "e02a6556", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([( True, 85.), ( True, 1281.), ( True, 69.), ...,\n", " (False, 2507.), (False, 4982.), (False, 3995.)],\n", " dtype=[('death', '?'), ('futime', '> n$). This is exactly the setting that *sparsesurv* is designed for. *sparsesurv* is based on knowledge distillation [5], which is also referred to as preconditioning [2] or reference models in statistics [3]. Thus, we briefly introduce the idea of knowledge distillation next." 
] }, { "cell_type": "code", "execution_count": 6, "id": "a680cb51", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(7874, 9)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X.shape" ] }, { "cell_type": "markdown", "id": "65210e89", "metadata": {}, "source": [ "## Knowledge distillation" ] }, { "cell_type": "markdown", "id": "6493105e", "metadata": {}, "source": [ "The original idea of knowledge distillation was not directly related to interpretability or feature selection. We note, however, that something akin to knowledge distillation was proposed and used in statistics before knowledge distillation itself, under the name of preconditioning. We will continue to use the name knowledge distillation since it may be more familiar to readers in the machine learning community." ] }, { "cell_type": "markdown", "id": "720f749a", "metadata": {}, "source": [ "The actual process of knowledge distillation proceeds in two steps:\n", "\n", " 1. Fit a teacher model that can approximate the target of interest (very) well\n", " \n", " 2. Fit a student model on the predictions of the teacher model, hoping that the teacher model acts as a kind of noise filter" ] }, { "cell_type": "markdown", "id": "22ee2d78", "metadata": {}, "source": [ "We now illustrate how this process can be adapted to survival analysis. Please note that this running example continues to use `scikit-survival`; we move on to high-dimensional data afterwards." ] }, { "cell_type": "code", "execution_count": 7, "id": "f2e5b94e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('standardscaler', StandardScaler()),\n",
       "                ('coxphsurvivalanalysis', CoxPHSurvivalAnalysis())])
" ], "text/plain": [ "Pipeline(steps=[('standardscaler', StandardScaler()),\n", " ('coxphsurvivalanalysis', CoxPHSurvivalAnalysis())])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sksurv.ensemble import RandomSurvivalForest\n", "from sksurv.linear_model import CoxPHSurvivalAnalysis\n", "from sklearn.linear_model import LinearRegression\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.pipeline import make_pipeline\n", "from sksurv.metrics import concordance_index_censored\n", "\n", "X_train = X.iloc[:5000, [0, 3, 4]]\n", "X_test = X.iloc[5000:, [0, 3, 4]]\n", "y_train = y[:5000]\n", "y_test = y[5000:]\n", "\n", "teacher_pipe = make_pipeline(StandardScaler(), RandomSurvivalForest())\n", "student_pipe = make_pipeline(StandardScaler(), LinearRegression())\n", "baseline_pipe = make_pipeline(StandardScaler(), CoxPHSurvivalAnalysis())\n", "\n", "teacher_pipe.fit(X_train, y_train)\n", "student_pipe.fit(X_train, teacher_pipe.predict(X_train))\n", "baseline_pipe.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 8, "id": "63f20333", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.5677982023489915" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "concordance_index_censored(y_test[\"death\"], y_test[\"futime\"], teacher_pipe.predict(X_test))[0]" ] }, { "cell_type": "code", "execution_count": 9, "id": "9d17ab57", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.6479734724443875" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "concordance_index_censored(y_test[\"death\"], y_test[\"futime\"], student_pipe.predict(X_test))[0]" ] }, { "cell_type": "code", "execution_count": 10, "id": "63bd79fd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.6444986225266496" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"concordance_index_censored(y_test[\"death\"], y_test[\"futime\"], baseline_pipe.predict(X_test))[0]" ] }, { "cell_type": "markdown", "id": "e8bfd176", "metadata": {}, "source": [ "Interestingly, in this example, the student's performance as measured by Harrell's concordance (0.648) was slightly higher than that of the baseline (0.644), even though the teacher performed considerably worse (0.568). Why a student can outperform its teacher is a topic of ongoing research in the ML community. For us, however, all that matters is that knowledge distillation can be applied to survival analysis." ] }, { "cell_type": "markdown", "id": "a83a03bd", "metadata": {}, "source": [ "## Minimal example of *sparsesurv*" ] }, { "cell_type": "markdown", "id": "8a0fe82c", "metadata": {}, "source": [ "Lastly, we give a brief usage example of *sparsesurv*. If you are interested in using *sparsesurv* on your own data, please consult the documentation or the more specific user guides linked above." ] }, { "cell_type": "code", "execution_count": 11, "id": "2567a022", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from sparsesurv.utils import transform_survival\n", "df = pd.read_csv(\"https://zenodo.org/records/10027434/files/OV_data_preprocessed.csv?download=1\")\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "c62c8211", "metadata": {}, "outputs": [], "source": [ "X = df.iloc[:, 3:].to_numpy()" ] }, { "cell_type": "code", "execution_count": 15, "id": "58b8ff48", "metadata": {}, "outputs": [], "source": [ "y = transform_survival(time=df.OS_days.values, event=df.OS.values)" ] }, { "cell_type": "code", "execution_count": 16, "id": "a87a3470", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(302, 19076)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X.shape" ] }, { "cell_type": "code", "execution_count": 17, "id": "3aa29fb9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([( True, 304.), ( True, 24.), (False, 576.), (False, 1207.),\n", " ( True, 676.)], 
dtype=[('event', '?'), ('time', '