Source code for scripts.preprocessing.subsample


import logging
from argparse import ArgumentParser, Namespace
from typing import Optional, Sequence

import numpy as np
import scanpy as sc

# Name of the ``adata.obs`` column that holds the sample label of each observation.
_SAMPLE = "sample"
# Value of ``--n-samples`` that requests the full, un-subsampled dataset.
_ALL = "all"

# Install the default root handler and let INFO-level (and above) messages through.
logging.basicConfig()
logging.root.setLevel(logging.INFO)
# Module-level logger used for all progress messages of this script.
_LOGGER = logging.getLogger(__name__)

def get_args(argv: Optional[Sequence[str]] = None) -> Namespace:
    """
    Parse command line arguments for the sample subsetting tool.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            ``argparse`` falls back to ``sys.argv[1:]``, so calling
            ``get_args()`` without arguments behaves like a regular CLI
            entry point. Passing an explicit sequence allows programmatic use.

    Returns:
        Namespace: Parsed command line arguments including:
            - input (str): Path to input h5ad file
            - output (str): Path for output h5ad file
            - n_samples (str): Number of samples to keep or 'all'
            - random_seed (int): Seed for random number generator
    """
    parser = ArgumentParser(
        description="Randomly keep a subset of the samples of an h5ad file.",
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="Path to the input h5ad file.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="Path the output h5ad file is written to.",
    )
    # Kept as a string on purpose: the value is either an integer or the
    # keyword 'all', and the caller decides how to interpret it.
    parser.add_argument(
        "-n",
        "--n-samples",
        type=str,
        required=True,
        help="Number of samples to keep, or 'all' to keep every sample.",
    )
    parser.add_argument(
        "-r",
        "--random-seed",
        type=int,
        required=True,
        help="Seed for the random number generator that picks the samples.",
    )
    return parser.parse_args(argv)
def main() -> None:
    """
    Main function to execute the sample subsetting workflow.

    The function performs the following steps:
    1. Parses the command line and validates the requested number of samples
       before any data is read, so a malformed request fails immediately
    2. Loads an AnnData object from the specified input file
    3. If n_samples is 'all', saves the complete dataset unchanged
    4. Otherwise, randomly selects the specified number of samples
    5. Saves the subsetted data to the specified output file

    Raises:
        ValueError: If n_samples is neither 'all' nor an integer, or if it is
            negative
        AssertionError: If requested number of samples exceeds available samples
            or if the subsetting operation fails to produce expected results
    """
    args = get_args()

    # ``None`` stands for "keep everything"; converting here means a bad value
    # is rejected before the (potentially large) file is loaded.
    n_samples = None if args.n_samples == _ALL else int(args.n_samples)
    if n_samples is not None and n_samples < 0:
        raise ValueError(
            f"Number of samples must be non-negative, got {n_samples}."
        )

    _LOGGER.info(f"Loading data from {args.input}.")
    adata = sc.read_h5ad(args.input)

    if n_samples is None:
        adata.write_h5ad(args.output)
        # Logged after the write so the message is only emitted on success.
        _LOGGER.info(f"Didn't subsample the data and wrote the full data to {args.output}.")
        return

    # Missing labels are dropped so the candidate list agrees with
    # ``nunique()``, which ignores NaN. Otherwise NaN could be drawn as a
    # "sample" and the sanity check below would fail.
    samples = adata.obs[_SAMPLE].dropna().unique().tolist()

    # Raised explicitly (instead of ``assert``) so the check survives
    # ``python -O``; the exception type stays AssertionError as documented.
    if n_samples > len(samples):
        raise AssertionError("Too few samples in the dataset")

    _LOGGER.info(f"Initialized generator with random seed {args.random_seed}.")
    rng = np.random.default_rng(seed=args.random_seed)
    samples_to_keep = rng.choice(samples, size=n_samples, replace=False)

    # ``.copy()`` materialises the view so the written object is independent
    # of the full dataset.
    bdata = adata[adata.obs[_SAMPLE].isin(samples_to_keep)].copy()

    # Sanity check: the subset must contain exactly the requested number of
    # distinct samples before anything is written to disk.
    if bdata.obs[_SAMPLE].nunique() != n_samples:
        raise AssertionError("Something went wrong, too few samples sampled.")

    bdata.write_h5ad(args.output)
    _LOGGER.info(f"Wrote {n_samples} subsampled samples to {args.output}.")
# Run the workflow only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()