import logging
import pathlib as pl
import warnings
from argparse import ArgumentParser, Namespace
from os import PathLike
from typing import List, Literal, Optional, Tuple

import anndata
import numpy as np
import scanpy as sc
from scipy.sparse import csr_matrix, diags, issparse
from sklearn.utils import sparsefuncs
# Suppress every FutureWarning for the rest of the process (module-level side effect).
warnings.simplefilter(action='ignore', category=FutureWarning)
# Sequencing-technology identifiers. determine_seq_technology() lower-cases the
# value found in adata.obs["technology"], so these are matched in lower case.
SEQ_10x = "10x"
MICROWELL = 'microwell array-based platform'
MICROWELL_SEQ = "microwell-seq"
SEQ_SS2 = "smartseq2"
# NOTE(review): SMARTSEQ2 is not referenced in the visible code and, being
# mixed-case, could never equal the lower-cased technology string.
SMARTSEQ2 = "SmartSeq2"
SEQWELL = "seqwell"
# Column of adata.obs holding the sample each cell belongs to.
SAMPLE = "sample"
# NOTE(review): not referenced in the visible code.
NORMALIZED_KEY = "normalized"
# Key in adata.uns describing what adata.X holds (e.g. "Exp_data_UMIcounts",
# "normalized_Exp_data_TPM"); consumed and deleted by set_tpm_counts().
COUNT_TYPE = "count_type"
_LOGGER = logging.getLogger(__name__)
def get_args() -> Namespace:
    """Parse the command line of the preprocessing pipeline.

    Returns:
        Namespace: Parsed arguments with the attributes
            - input: path to the input h5ad file (required)
            - output: path of the output h5ad file (required)
            - excluded_sample: list of sample names to drop, or None
            - min_genes: minimum number of expressed genes per cell, or None
            - min_counts: minimum number of counts per cell, or None
            - max_pct_mt: maximum allowed mitochondrial percentage (required)
    """
    parser = ArgumentParser()
    # Input / output locations.
    parser.add_argument("-i", "--input", required=True, type=str)
    parser.add_argument("-o", "--output", required=True, type=str)
    # Optional filters: they stay None when the flag is not given.
    parser.add_argument("--excluded-sample", nargs="+", type=str)
    parser.add_argument("--min-genes", type=int)
    parser.add_argument("--min-counts", type=int)
    # Mitochondrial filter threshold is mandatory.
    parser.add_argument("--max-pct-mt", required=True, type=float)
    parsed = parser.parse_args()
    return parsed
def get_counts_per_cell(x: csr_matrix):
    """Return the total expression (row sum) of every cell.

    Args:
        x (csr_matrix): Cells-by-genes expression matrix (sparse or dense).

    Returns:
        numpy.ndarray: 1-D array with one total per row/cell.
    """
    # Sparse sums come back as a (n, 1) np.matrix; flatten to a plain 1-D array.
    row_totals = x.sum(axis=1)
    return np.asarray(row_totals).reshape(-1)
def get_counts_from_tpm(tpm: csr_matrix, technology: str) -> csr_matrix:
    """Convert TPM (Transcripts Per Million) values to estimated count data.

    For UMI-based technologies (10x, microwell platforms, Seq-Well) every cell is
    rescaled so that its smallest positive TPM becomes exactly one count, i.e.
    the cell's library size is estimated as ``mean_cell_total / min_positive_tpm``.
    For SmartSeq2 the TPM values are simply rounded.

    Args:
        tpm (csr_matrix): Cells-by-genes TPM matrix. Dense input or other sparse
            formats are converted to CSR; the caller's matrix is never modified.
        technology (str): Sequencing technology used (lower-case identifier).

    Returns:
        csr_matrix: Estimated counts, rounded to whole numbers. Cells without
            any positive TPM get an all-zero row.

    Raises:
        NotImplementedError: If the sequencing technology is not recognized.
    """
    # Private CSR copy: lets us slice the raw data/indptr arrays per row and
    # guarantees the input is left untouched.
    tpm = csr_matrix(tpm, copy=True)
    # Duplicate entries would otherwise be inspected separately instead of summed.
    tpm.sum_duplicates()

    if technology in (SEQ_10x, MICROWELL, MICROWELL_SEQ, SEQWELL):
        n_cells = tpm.shape[0]
        const = get_counts_per_cell(tpm).mean()
        # Per-cell factor turning TPMs into counts (1 / smallest positive TPM).
        # Cells with no positive value keep a factor of 0 instead of crashing
        # on the min() of an empty selection.
        scale = np.zeros(n_cells)
        n_empty = 0
        for n_row in range(n_cells):
            row = tpm.data[tpm.indptr[n_row]:tpm.indptr[n_row + 1]]
            positive = row[row > 0.0]
            if positive.size == 0:
                n_empty += 1
                continue
            scale[n_row] = 1.0 / positive.min()
        library_sizes = const * scale
        if n_empty:
            _LOGGER.warning(f"{n_empty} cells have no positive TPM values; "
                            f"their estimated counts are set to zero.")
        # Row scaling via a diagonal matrix; `@` is a matrix product for both
        # sparse matrices and sparse arrays (where `*` would be element-wise).
        counts = csr_matrix(diags(scale) @ tpm)
        counts.data = np.round(counts.data)
        _LOGGER.info("Converting TPMs to counts by estimating library size. Estimated library size:\n"
                     f" Mean: {library_sizes.mean()}\n Max: {library_sizes.max()}\n Min: {library_sizes.min()}")
    elif technology == SEQ_SS2:
        _LOGGER.info("Converting TPMs to counts by rounding.")
        counts = tpm  # already a private copy
        counts.data = np.round(counts.data)
    else:
        raise NotImplementedError(f"Unknown technology {technology}.")
    return counts
def normalize(counts: csr_matrix, target_counts=1e5) -> csr_matrix:
    """Normalize counts so that every cell sums to ``target_counts``.

    Args:
        counts (csr_matrix): Cells-by-genes count matrix (sparse or dense).
            It is not modified.
        target_counts (float): Target sum for each cell after normalization.

    Returns:
        csr_matrix: New floating-point matrix with normalized values. Cells with
            a total of zero are left as all-zero rows.
    """
    # Work on a floating-point CSR copy: in-place scaling of an integer matrix
    # would fail, and the caller's matrix must not be changed behind its back.
    dtype = np.result_type(counts.dtype, np.float32)
    normalized = csr_matrix(counts, dtype=dtype, copy=True)
    library_size = get_counts_per_cell(normalized)
    # Avoid division by zero for empty cells; their factor stays 0.
    scale = np.zeros(library_size.shape, dtype=np.float64)
    np.divide(target_counts, library_size, out=scale, where=library_size > 0)
    sparsefuncs.inplace_row_scale(normalized, scale)
    return normalized
def get_tpm_counts(input, count_type: str, technology: str) -> Tuple[csr_matrix, csr_matrix]:
    """Derive both a normalized (TPM-like) matrix and a count matrix from ``input``.

    Args:
        input: Cells-by-genes count or TPM data (parameter name kept for
            backward compatibility even though it shadows the builtin).
        count_type (str): Type of data provided; ``"Exp_data_UMIcounts"`` means
            raw UMI counts, anything else is treated as TPM values.
        technology (str): Sequencing technology used, needed to estimate counts
            from TPMs.

    Returns:
        Tuple[csr_matrix, csr_matrix]: Normalized values and count data.
    """
    _LOGGER.info(f"Found {count_type} count type.")
    if count_type == 'Exp_data_UMIcounts':
        counts = input.copy()
    else:
        counts = get_counts_from_tpm(input, technology)
    # Counts are derived first because normalize() may rescale its argument.
    tpm = normalize(input)
    return tpm, counts
def set_tpm_counts(adata: anndata.AnnData, technology: str) -> anndata.AnnData:
    """Store counts and normalized data in ``adata`` and log2(x + 1)-transform X.

    Args:
        adata (anndata.AnnData): Input data object; ``adata.uns[COUNT_TYPE]``
            must describe what ``adata.X`` holds.
        technology (str): Sequencing technology used.

    Returns:
        anndata.AnnData: The same object with ``layers["counts"]`` set, X replaced
            by the log2(x + 1) of the normalized data, and the count-type entry
            removed from ``uns``.
    """
    count_type = adata.uns[COUNT_TYPE]
    tpm, counts = get_tpm_counts(adata.X, count_type, technology)
    # X is about to be replaced, so its old count-type description is dropped.
    del adata.uns[COUNT_TYPE]
    adata.layers["counts"] = counts
    adata.X = tpm
    _LOGGER.info("Log plus one transforming the normalized data (adata.X). ")
    sc.pp.log1p(adata, base=2)
    return adata
def remove_low_count_cells(adata: anndata.AnnData, min_counts: Optional[int],
                           min_genes: Optional[int]) -> anndata.AnnData:
    """Filter cells based on minimum count and gene expression thresholds.

    A threshold that is None (or 0) disables the corresponding filter.

    Args:
        adata (anndata.AnnData): Input data object; filtered in place by scanpy.
        min_counts (Optional[int]): Minimum counts required per cell.
        min_genes (Optional[int]): Minimum genes required per cell.

    Returns:
        anndata.AnnData: The filtered data object.
    """
    if min_counts:
        _LOGGER.info(f"Removing cells with less than {min_counts} counts.")
        sc.pp.filter_cells(adata, min_counts=min_counts)
    if min_genes:
        _LOGGER.info(f"Removing cells with less than {min_genes} genes expressed.")
        sc.pp.filter_cells(adata, min_genes=min_genes)
    _LOGGER.info(f"Kept {adata.n_obs} cells.")
    return adata
def subset_malignant(adata: anndata.AnnData) -> anndata.AnnData:
    """Subset data to include only malignant cells.

    Cells are selected by ``adata.obs["malignant"] == "yes"`` when that column
    exists, otherwise by a case-insensitive ``"malignant"`` label in
    ``adata.obs["cell_type"]``.

    Args:
        adata (anndata.AnnData): Input data object.

    Returns:
        anndata.AnnData: Copy containing only malignant cells.

    Raises:
        ValueError: If malignant cells cannot be identified.
    """
    MALIGNANT = "malignant"
    obs = adata.obs
    if MALIGNANT in obs.columns:
        _LOGGER.info("Used adata.obs['malignant'] to find malignant cells.")
        idx = obs[MALIGNANT] == "yes"
    # Guard the column lookup so a missing "cell_type" raises the documented
    # ValueError below rather than a KeyError.
    elif "cell_type" in obs.columns and MALIGNANT in obs["cell_type"].str.lower().values:
        _LOGGER.info(
            f"Used cell_types {adata.obs.cell_type.unique().tolist()} to find malignant cells.")
        idx = obs["cell_type"].str.lower() == MALIGNANT
    else:
        raise ValueError("No way of determining which cells are malignant.")
    adata = adata[idx].copy()
    _LOGGER.info(f"Found {adata.n_obs} malignant cells.")
    return adata
def determine_seq_technology(adata: anndata.AnnData):
    """Determine the sequencing technology from ``adata.obs["technology"]``.

    Args:
        adata (anndata.AnnData): Input data object.

    Returns:
        str: Lower-cased sequencing technology identifier.

    Raises:
        KeyError: If the technology column is missing.
        ValueError: If the column holds no value or more than one distinct value.
    """
    TECHNOLOGY = "technology"
    if TECHNOLOGY not in adata.obs.columns:
        raise KeyError(f"Key {TECHNOLOGY} not in adata.obs.columns.")
    technologies = adata.obs[TECHNOLOGY].unique()
    if len(technologies) == 0:
        raise ValueError("No technology found.")
    if len(technologies) > 1:
        raise ValueError(f"More than one technology found: {list(technologies)}.")
    technology = technologies[0]
    _LOGGER.info(f"Determined {technology} as sequencing technology.")
    # str() so non-string values fail later with a clear "Unknown technology"
    # message instead of an AttributeError here.
    return str(technology).lower()
def remove_high_mt_cells(adata: anndata.AnnData, max_pct_mt: float,
                         mt_prefix: str = "MT-") -> anndata.AnnData:
    """Remove cells with high mitochondrial gene expression.

    Genes whose name starts with ``mt_prefix`` are flagged in ``adata.var["mt"]``
    and scanpy's QC metrics are computed in place before filtering.

    Args:
        adata (anndata.AnnData): Input data object (``var``/``obs`` are
            extended in place with the QC columns).
        max_pct_mt (float): Maximum percentage of mitochondrial counts allowed.
        mt_prefix (str): Gene-name prefix identifying mitochondrial genes.

    Returns:
        anndata.AnnData: Filtered copy.
    """
    adata.var["mt"] = adata.var_names.str.startswith(mt_prefix)
    if not adata.var["mt"].any():
        # Every cell would get 0 % and nothing would be filtered.
        _LOGGER.warning(f"No genes start with {mt_prefix!r}; mitochondrial filter has no effect.")
    sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], percent_top=None, log1p=False,
                               inplace=True)
    adata = adata[adata.obs["pct_counts_mt"] <= max_pct_mt].copy()
    _LOGGER.info(f"Kept {adata.n_obs} cells after removing high MT cells.")
    return adata
def remove_samples(adata: anndata.AnnData, min_g1_cells: int = 50) -> anndata.AnnData:
    """Remove samples that have ``min_g1_cells`` (default 50) or fewer G1 cells.

    Requires the ``phase`` column produced by cell-cycle scoring.

    Args:
        adata (anndata.AnnData): Input data object.
        min_g1_cells (int): A sample is kept only if it has strictly more G1
            cells than this.

    Returns:
        anndata.AnnData: Filtered copy.
    """
    n_samples = adata.obs[SAMPLE].nunique()
    g1_obs = adata.obs.loc[adata.obs["phase"] == "G1"]
    g1_per_sample = g1_obs[SAMPLE].value_counts()
    samples_to_keep = g1_per_sample[g1_per_sample > min_g1_cells].index.tolist()
    adata = adata[adata.obs[SAMPLE].isin(samples_to_keep)].copy()
    # Count against all samples, not only those with at least one G1 cell.
    _LOGGER.info(f"Removed {n_samples - len(samples_to_keep)} of {n_samples} samples.")
    return adata
def r_names(adata: anndata.AnnData) -> anndata.AnnData:
    """Convert variable and observation names to R-compatible format.

    Underscores are replaced by dashes in ``var_names`` and ``obs_names``.

    Args:
        adata (anndata.AnnData): Input data object (modified in place).

    Returns:
        anndata.AnnData: Data object with R-compatible names.
    """
    _LOGGER.info("Changing `obs_names` and `var_names` to R names.")
    adata.var_names = adata.var_names.str.replace("_", "-")
    adata.obs_names = adata.obs_names.str.replace("_", "-")
    # "A_B" and "A-B" collapse to the same name; surface that instead of
    # silently producing non-unique indices.
    if not adata.var_names.is_unique:
        _LOGGER.warning("Renaming produced duplicated var_names.")
    if not adata.obs_names.is_unique:
        _LOGGER.warning("Renaming produced duplicated obs_names.")
    return adata
def remove_excluded_samples(adata: sc.AnnData, excluded_samples: List[str]) -> sc.AnnData:
    """Remove specified samples from the dataset.

    Args:
        adata (sc.AnnData): Input data object.
        excluded_samples (List[str]): Sample names to exclude; None or an empty
            list leaves the data unchanged.

    Returns:
        sc.AnnData: Filtered copy (or the input itself when nothing is excluded).
    """
    if not excluded_samples:
        return adata
    samples = adata.obs[SAMPLE].unique().tolist()
    _LOGGER.info(f"Anndata contains {samples} as samples.")
    missing = [name for name in excluded_samples if name not in samples]
    if missing:
        # Likely a typo on the command line; nothing is removed for these.
        _LOGGER.warning(f"Excluded samples not found in the data: {missing}.")
    adata = adata[~adata.obs[SAMPLE].isin(excluded_samples)].copy()
    _LOGGER.info(
        f"After removing excluded samples anndata contains {adata.obs[SAMPLE].unique().tolist()} as samples.")
    return adata
def preprocessing(adata: anndata.AnnData, excluded_samples: Optional[List[str]], min_genes: Optional[int],
                  min_counts: Optional[int], max_pct_mt: float) -> anndata.AnnData:
    """Run the complete preprocessing pipeline on single-cell RNA-seq data.

    Steps: undo a log2(x + 1) transform of normalized TPMs if present, determine
    the technology, keep malignant cells, filter low-quality and high-MT cells,
    score the cell cycle, drop samples with too few G1 cells and user-excluded
    samples, store counts / log-normalized data, make names R-compatible and
    drop genes expressed in fewer than 0.1 % of cells.

    Args:
        adata (anndata.AnnData): Input data object.
        excluded_samples (Optional[List[str]]): Sample names to exclude.
        min_genes (Optional[int]): Minimum number of genes required per cell.
        min_counts (Optional[int]): Minimum number of counts required per cell.
        max_pct_mt (float): Maximum percentage of mitochondrial counts allowed.

    Returns:
        anndata.AnnData: Fully preprocessed data object.
    """
    _LOGGER.info(
        f"Started preprocessing with {adata.n_obs} cells, {adata.n_vars} genes and "
        f"{adata.obs[SAMPLE].nunique()} samples.")
    if adata.uns[COUNT_TYPE] == "normalized_Exp_data_TPM":
        # Apply 2**x - 1, the inverse of log2(x + 1): X is assumed to hold
        # log2(TPM + 1). Because 2**0 - 1 == 0, only the stored entries of a
        # sparse matrix need to be transformed.
        if issparse(adata.X):
            adata.X.data = (2 ** adata.X.data) - 1
        else:
            adata.X = (2 ** adata.X) - 1
    technology = determine_seq_technology(adata)
    adata = subset_malignant(adata)
    adata = remove_low_count_cells(adata, min_genes=min_genes, min_counts=min_counts)
    adata = remove_high_mt_cells(adata, max_pct_mt)
    adata = score_cell_cycle(adata)
    adata = remove_samples(adata)
    if excluded_samples:
        adata = remove_excluded_samples(adata, excluded_samples)
    adata = set_tpm_counts(adata, technology)
    adata = r_names(adata)
    sc.pp.filter_genes(adata, min_cells=int(0.001 * adata.n_obs))
    _LOGGER.info(
        f"Finished preprocessing with {adata.n_obs} cells, {adata.n_vars} genes and "
        f"{adata.obs[SAMPLE].nunique()} samples.")
    return adata
def read_anndata(path: PathLike) -> anndata.AnnData:
    """Read an h5ad file, making purely numeric cell names unique strings.

    Args:
        path (PathLike): Location of the h5ad file.

    Returns:
        anndata.AnnData: The loaded data object.
    """
    _LOGGER.info(f"Reading Anndata from {path}.")
    adata = anndata.read_h5ad(path)
    if adata.obs_names.str.isdigit().any():
        _LOGGER.info("Found integers in adata.obs_names appending sample name "
                     "to ensure casting to strings.")
        # Use the raw values so the Index is concatenated element-wise rather
        # than being combined with an index-aligned Series.
        sample_names = adata.obs[SAMPLE].astype(str).to_numpy()
        adata.obs_names = adata.obs_names.astype(str) + "_" + sample_names
    return adata
def write_anndata(adata: anndata.AnnData, output_path: PathLike) -> None:
    """Write ``adata`` to ``output_path`` as h5ad, creating parent directories.

    Args:
        adata (anndata.AnnData): Data object to save; string columns are
            converted to categoricals in place before writing.
        output_path (PathLike): Destination file.
    """
    output_path = pl.Path(output_path)
    output_path.parent.mkdir(exist_ok=True, parents=True)
    adata.strings_to_categoricals()
    _LOGGER.info(f"Writing AnnData to {output_path}.")
    # The original body stopped after logging and never saved the file.
    adata.write_h5ad(output_path)
def score_cell_cycle(adata: anndata.AnnData) -> anndata.AnnData:
    """Score S and G2/M gene sets and assign a cell-cycle phase to every cell.

    Delegates to ``sc.tl.score_genes_cell_cycle``, which annotates ``adata.obs``
    in place; the resulting ``phase`` column is what remove_samples() filters on.

    Args:
        adata (anndata.AnnData): Input data object.

    Returns:
        anndata.AnnData: The same, annotated object.
    """
    s_genes = ['MCM5', 'PCNA', 'TYMS', 'FEN1', 'MCM2', 'MCM4', 'RRM1', 'UNG',
               'GINS2', 'MCM6', 'CDCA7', 'DTL', 'PRIM1', 'UHRF1', 'MLF1IP',
               'HELLS', 'RFC2', 'RPA2', 'NASP', 'RAD51AP1', 'GMNN', 'WDR76',
               'SLBP', 'CCNE2', 'UBR7', 'POLD3', 'MSH2', 'ATAD2', 'RAD51', 'RRM2',
               'CDC45', 'CDC6', 'EXO1', 'TIPIN', 'DSCC1', 'BLM', 'CASP8AP2',
               'USP1', 'CLSPN', 'POLA1', 'CHAF1B', 'BRIP1', 'E2F8']
    g2m_genes = ['HMGB2', 'CDK1', 'NUSAP1', 'UBE2C', 'BIRC5', 'TPX2', 'TOP2A',
                 'NDC80', 'CKS2', 'NUF2', 'CKS1B', 'MKI67', 'TMPO', 'CENPF',
                 'TACC3', 'FAM64A', 'SMC4', 'CCNB2', 'CKAP2L', 'CKAP2', 'AURKB',
                 'BUB1', 'KIF11', 'ANP32E', 'TUBB4B', 'GTSE1', 'KIF20B', 'HJURP',
                 'CDCA3', 'HN1', 'CDC20', 'TTK', 'CDC25C', 'KIF2C', 'RANGAP1',
                 'NCAPD2', 'DLGAP5', 'CDCA2', 'CDCA8', 'ECT2', 'KIF23', 'HMMR',
                 'AURKA', 'PSRC1', 'ANLN', 'LBR', 'CKAP5', 'CENPE', 'CTCF', 'NEK2',
                 'G2E3', 'GAS2L3', 'CBX5', 'CENPA']
    sc.tl.score_genes_cell_cycle(adata, s_genes=s_genes, g2m_genes=g2m_genes)
    return adata
def make_datadir(output):
    """Create the directory that will contain ``output``.

    Missing ancestors are created as well; an already existing directory is
    accepted silently. ``output`` itself is not created.
    """
    parent_dir = pl.Path(output).parents[0]
    parent_dir.mkdir(parents=True, exist_ok=True)
def main():
    """Command-line entry point: read, preprocess and write an h5ad file."""
    # Without a configured handler/level every INFO progress message sent via
    # _LOGGER would be dropped. basicConfig does nothing if the root logger
    # already has handlers, so an embedding application keeps its own setup.
    logging.basicConfig(level=logging.INFO)
    args = get_args()
    adata = read_anndata(args.input)
    adata = preprocessing(adata, excluded_samples=args.excluded_sample, min_genes=args.min_genes,
                          max_pct_mt=args.max_pct_mt, min_counts=args.min_counts)
    write_anndata(adata, args.output)
if __name__ == '__main__':