import logging
import pathlib as pl
from argparse import ArgumentParser
import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
# Configure the root logger once at import time so the script emits INFO
# messages by default when run from the command line.
logging.basicConfig()
logging.root.setLevel(logging.INFO)
_LOGGER = logging.getLogger(__name__)
# Spacing between random states across retries in get_cluster: attempt i
# uses random_state = _OFFSET * i + random_seed, keeping seeds well apart.
_OFFSET: int = 10_000
# Key under which the latent representation is stored in adata.obsm.
_LATENT = "latent"
def read_anndata(path: str):
    """Load an AnnData object from *path* via scanpy."""
    return sc.read(path)
def get_args():
    """Parse the command-line options for this script."""
    parser = ArgumentParser()
    # The three path-like options share the same settings.
    required_str = dict(type=str, required=True)
    parser.add_argument("-i", "--input", **required_str)
    parser.add_argument("-o", "--output", **required_str)
    parser.add_argument("-l", "--latent", **required_str)
    parser.add_argument("-n", "--n-cluster", type=int, required=True)
    parser.add_argument("-r", "--random-seed", type=int, default=0)
    return parser.parse_args()
def add_latent(path: str, adata: sc.AnnData) -> sc.AnnData:
    """Attach a precomputed latent space or knn-graph to ``adata``.

    Parameters
    ----------
    path:
        Either a ``.npz`` file with a precomputed neighbour graph (keys
        ``index``, ``connectivities``, ``distances``) or a ``.csv`` file
        with one latent vector per cell, indexed by cell name.
    adata:
        Annotated data matrix; it is subset (copied) to the cells listed
        in the file at ``path``.

    Returns
    -------
    sc.AnnData
        The subset copy with the graph stored in ``.obsp``/``.uns`` (npz
        case) or the embedding stored in ``.obsm[_LATENT]`` plus a freshly
        computed knn-graph (csv case).

    Raises
    ------
    ValueError
        If the file extension is neither ``.npz`` nor ``.csv``, or the
        latent index does not line up with the subset ``adata``.
    """
    if path.endswith(".npz"):
        _LOGGER.info(f"Loading graph from {path}")
        data = np.load(path, allow_pickle=True)
        adata = adata[data["index"]].copy()
        # Numpy saves everything as arrays. .item retrieves the original object.
        adata.obsp["connectivities"] = data["connectivities"].item()
        adata.obsp["distances"] = data["distances"].item()
        adata.uns["neighbors"] = {'connectivities_key': 'connectivities',
                                  'distances_key': 'distances',
                                  'params': {'n_neighbors': 15, 'method': 'umap', 'random_state': 0,
                                             'metric': 'euclidean'}}
    elif path.endswith(".csv"):
        _LOGGER.info(f"Loading latents from {path}")
        latent = pd.read_csv(path, index_col=0)
        adata = adata[latent.index].copy()
        # Subsetting by label should align obs_names with the latent index;
        # a mismatch here indicates duplicated or missing cell names.
        if np.any(latent.index != adata.obs_names):
            raise ValueError("Mismatch in index of the latent space and the adata. Something went wrong.")
        adata.obsm[_LATENT] = latent.values
        _LOGGER.info("Computing knn-graphs")
        sc.pp.neighbors(adata, use_rep=_LATENT, n_neighbors=15)
    else:
        raise ValueError(f"Unknown file type, {path}.")
    return adata
def get_cluster(adata, n_cluster, random_seed, start: float = 1e-4, end: float = 2., epsilon: float = 1e-8):
    """Binary-search the Leiden resolution until exactly ``n_cluster`` clusters appear.

    Runs up to 10 attempts with different random states; the resulting
    labels are written in place to ``adata.obs["leiden"]``.

    Parameters
    ----------
    adata:
        Annotated data matrix with a precomputed neighbour graph.
    n_cluster:
        Target number of Leiden clusters.
    random_seed:
        Base random seed; attempt ``i`` uses ``_OFFSET * i + random_seed``.
    start, end:
        Initial resolution search interval.
    epsilon:
        Interval width below which an attempt is considered failed.
    """
    for i in range(10):
        # BUG FIX: reset the search interval for every attempt. Previously
        # `start`/`end` were narrowed in place, so once the first attempt's
        # interval collapsed, every retry skipped the while loop and failed
        # immediately, making the 10-attempt retry loop useless.
        lo, hi = start, end
        try:
            while hi - lo > epsilon:
                mid = (hi + lo) / 2.0
                sc.tl.leiden(adata, resolution=mid, random_state=_OFFSET * i + random_seed)
                n_tmp = adata.obs["leiden"].nunique()
                if n_tmp == n_cluster:
                    _LOGGER.info(f"Found {n_tmp} cluster.")
                    break
                if n_tmp > n_cluster:
                    hi = mid
                if n_tmp < n_cluster:
                    lo = mid
            else:
                # The interval collapsed without hitting the target count.
                raise ValueError("Number of clusters doesn't match.")
        except ValueError:
            _LOGGER.info(f"Unsuccessful for random state {i}. Trying next random state.")
        else:
            # Success: the inner while exited via break.
            break
    else:
        # All attempts failed; previously this was silent.
        _LOGGER.warning("Could not find a resolution yielding %d clusters in 10 attempts.", n_cluster)
def main():
    """Script entry point: load data, compute meta-signatures, write a CSV.

    Creates the output directory if needed, reads the input AnnData file,
    and writes the resulting table to ``args.output`` without an index.
    """
    args = get_args()
    output = pl.Path(args.output)
    _LOGGER.info(f"Making output dir {str(output.parents[0])}")
    output.parents[0].mkdir(exist_ok=True, parents=True)
    _LOGGER.info(f"Loading adata from {args.input}")
    adata = read_anndata(args.input)
    _LOGGER.info(f"Loading latent codes from {args.latent}")
    # NOTE(review): `get_metasigs` is neither defined nor imported in this
    # module, so this call raises NameError at runtime. Confirm the intended
    # helper — presumably a pipeline over add_latent + get_cluster — and
    # import or define it here.
    metasigs = get_metasigs(adata, args.latent, n_cluster=args.n_cluster, random_seed=args.random_seed)
    metasigs.to_csv(output, index=False)
if __name__ == '__main__':
main()