# Source code for scgpt.utils.util

import functools
import json
import logging
import os
from pathlib import Path
import random
import subprocess
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import pandas as pd
from anndata import AnnData
import scib
from matplotlib import pyplot as plt
from matplotlib import axes
from IPython import get_ipython

from .. import logger


def gene_vocabulary():
    """
    Generate the gene name2id and id2name dictionaries.

    Note:
        This is currently a placeholder: it performs no work and returns
        ``None``.
    """
def set_seed(seed):
    """
    Seed the Python, NumPy and PyTorch random number generators.

    Also switches cuDNN to deterministic mode and disables its benchmark
    auto-tuner.

    Args:
        seed: The seed value passed to every generator.
    """
    # cuDNN settings are independent of the seeding calls below.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # NOTE: explicit per-GPU seeding was left disabled by the original author:
    # if n_gpu > 0:
    #     torch.cuda.manual_seed_all(seed)
def add_file_handler(logger: logging.Logger, log_file_path: Path):
    """
    Attach a file handler to ``logger``.

    The handler writes to ``log_file_path``, inherits the logger's current
    level, and formats records as time, logger name, level, function name
    and message.

    Args:
        logger (logging.Logger): The logger to extend.
        log_file_path (Path): The file the handler writes to.
    """
    record_format = "%(asctime)s-%(name)s-%(levelname)s-%(funcName)s: %(message)s"

    file_handler = logging.FileHandler(log_file_path)
    file_handler.setLevel(logger.level)
    file_handler.setFormatter(logging.Formatter(record_format, datefmt="%H:%M:%S"))
    logger.addHandler(file_handler)
def category_str2int(category_strs: List[str]) -> List[int]:
    """
    Encode category names as integer ids.

    Each distinct name receives one id in ``range(n_unique)``. Ids are
    assigned in sorted name order, so the mapping is reproducible across
    interpreter runs and independent of the order of the input.

    Args:
        category_strs (List[str]): The category name of every item.

    Returns:
        List[int]: The id of every item, aligned with ``category_strs``.
    """
    # Enumerate the sorted unique names rather than the raw set: set iteration
    # order for strings depends on per-process hash randomization.
    name2id = {name: i for i, name in enumerate(sorted(set(category_strs)))}
    return [name2id[name] for name in category_strs]
def isnotebook() -> bool:
    """
    Check whether the code is executing inside an IPython shell.

    Returns:
        bool: True for a Jupyter notebook / qtconsole kernel
        (``ZMQInteractiveShell``) and for terminal IPython
        (``TerminalInteractiveShell``); False for any other shell type or
        when ``get_ipython`` is not defined.
    """
    interactive_shells = ("ZMQInteractiveShell", "TerminalInteractiveShell")
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # Probably the standard Python interpreter.
        return False
    return shell_name in interactive_shells
def get_free_gpu():
    """
    Find the GPU with the most free memory as reported by ``nvidia-smi``.

    Runs ``nvidia-smi`` with a CSV query for used and free memory, prints the
    resulting table and the selected GPU, and returns the selected row index.

    Returns:
        The index label of the row with the largest ``memory.free`` value.
    """
    import subprocess
    from io import StringIO

    import pandas as pd

    query_cmd = [
        "nvidia-smi",
        "--format=csv",
        "--query-gpu=memory.used,memory.free",
    ]
    raw_stats = subprocess.check_output(query_cmd).decode("utf-8")

    # skiprows=1 drops the CSV header row emitted by nvidia-smi.
    column_names = ["memory.used", "memory.free"]
    gpu_df = pd.read_csv(StringIO(raw_stats), names=column_names, skiprows=1)
    print("GPU usage:\n{}".format(gpu_df))

    def _free_mib(cell):
        # Strip the trailing unit characters so the number can be parsed.
        return int(cell.rstrip(" [MiB]"))

    gpu_df["memory.free"] = gpu_df["memory.free"].map(_free_mib)
    idx = gpu_df["memory.free"].idxmax()
    most_free = gpu_df.iloc[idx]["memory.free"]
    print("Find free GPU{} with {} free MiB".format(idx, most_free))
    return idx
def get_git_commit():
    """
    Return the hash of the current git ``HEAD`` commit.

    Returns:
        str: The output of ``git rev-parse HEAD`` with surrounding whitespace
        removed.
    """
    raw_output = subprocess.check_output(["git", "rev-parse", "HEAD"])
    commit_hash = raw_output.decode("utf-8")
    return commit_hash.strip()
def histogram(
    *data: np.ndarray,
    label: Optional[List[str]] = None,
    color: Optional[List[str]] = None,
    figsize: Tuple[int, int] = (9, 4),
    title: Optional[str] = None,
    show: bool = False,
    save: Optional[str] = None,
) -> axes.Axes:
    """
    Plot a grouped, density-normalized histogram of one or more arrays.

    Args:
        *data (np.ndarray): The arrays to plot, one histogram series each.
            Every array is flattened before plotting.
        label (Optional[List[str]]): One legend label per array. Defaults to
            ``["train", "valid"]``.
        color (Optional[List[str]]): One color per array. Defaults to
            ``["blue", "red"]`` when exactly two arrays are given; otherwise
            matplotlib chooses the colors.
        figsize (Tuple[int, int]): The size of the figure.
        title (Optional[str]): The title of the figure.
        show (bool): Whether to show the figure.
        save (Optional[str]): The path to save the figure.

    Returns:
        axes.Axes: The axes of the figure.

    Raises:
        AssertionError: If the number of arrays and labels differ.
    """
    # Resolve defaults here instead of using mutable default arguments.
    if label is None:
        label = ["train", "valid"]
    assert len(data) == len(label), "The number of data and labels must be equal."
    if color is None:
        # Keep the historical two-series palette; for any other number of
        # series a fixed two-color list would be rejected by matplotlib.
        color = ["blue", "red"] if len(data) == 2 else None

    fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=150)
    max_value = max(np.max(d) for d in data)
    # Small ranges get unit-width bins with edges at 0.5, 1.5, ..., max + 0.5;
    # larger ranges fall back to 60 equal-width bins.
    if max_value < 60:
        bins = np.arange(0, max_value + 1, 1) + 0.5
    else:
        bins = 60
    ax.hist(
        [d.flatten() for d in data],
        bins=bins,
        label=label,
        density=True,
        histtype="bar",
        linewidth=2,
        rwidth=0.85,
        color=color,
    )
    ax.legend()
    ax.set_xlabel("counts")
    ax.set_ylabel("density")

    if title is not None:
        ax.set_title(title)

    if show:
        plt.show()

    if save is not None:
        fig.savefig(save, bbox_inches="tight")

    return ax
def _indicate_col_name(adata: AnnData, promt_str: str) -> Optional[str]:
    """
    Interactively ask the user for a column name of the data.

    The prompt is repeated until the user enters either an empty string or a
    name that is a column of ``adata.var`` or ``adata.obs``.

    Args:
        adata (AnnData): The AnnData object.
        promt_str (str): The prompt string.

    Returns:
        Optional[str]: The entered column name, or None if the input was
        empty.
    """
    while True:
        answer = input(promt_str)
        if answer == "":
            return None
        if answer in adata.var.columns or answer in adata.obs.columns:
            return answer
        print(f"The column {answer} is not in the data. " f"Please input again.")
def find_required_colums(
    adata: AnnData,
    id: str,
    configs_dir: Union[str, Path],
    update: bool = False,
) -> List[Optional[str]]:
    """
    Find the required columns in AnnData, including celltype column, str_celltype
    column, the gene name column, and the experimental batch key.

    This function asks the user to input the required column names if the first
    time loading the data. The names are saved in the config file and will be
    automatically loaded next time.

    Args:
        adata (AnnData): The AnnData object.
        id (str): The id of the AnnData object, will be used as the file name for
            saving the config file.
        configs_dir (Union[str, Path]): The directory of saved config files. It
            is created, including missing parent directories, if it does not
            exist.
        update (bool): Whether to update the config file.

    Returns:
        List[Optional[str]]: The required columns, including celltype_col,
            str_celltype_col, gene_col, and batch_col.
    """
    configs_dir = Path(configs_dir)
    # parents=True supports nested directories; exist_ok=True avoids the race
    # between checking for the directory and creating it.
    configs_dir.mkdir(parents=True, exist_ok=True)

    config_file = configs_dir / f"{id}.json"
    config_exists = config_file.exists()

    if update or not config_exists:
        if config_exists:
            # An update of an existing file was requested explicitly.
            print(
                "Updating the existing config file. \n"
                "Please input the required column names."
            )
        else:
            print(
                "The config file does not exist, this may be the first time "
                "loading the data. \nPlease input the required column names."
            )
        print(adata)
        celltype_col = _indicate_col_name(
            adata,
            "Please input the celltype column name (skip if not applicable): ",
        )
        str_celltype_col = _indicate_col_name(
            adata, "Please input the str_celltype column name: "
        )
        gene_col = _indicate_col_name(adata, "Please input the gene column name: ")
        batch_col = _indicate_col_name(adata, "Please input the batch column name: ")
        config = {
            "celltype_col": celltype_col,
            "str_celltype_col": str_celltype_col,
            "gene_col": gene_col,
            "batch_col": batch_col,
        }
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(config, f)
    else:
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)

    return [
        config["celltype_col"],
        config["str_celltype_col"],
        config["gene_col"],
        config["batch_col"],
    ]
def tensorlist2tensor(tensorlist, pad_value):
    """
    Stack variable-length tensors into one padded 2-D tensor.

    Args:
        tensorlist: A non-empty sequence of tensors. The dtype and device of
            the first tensor are used for the result.
        pad_value: The value used to fill positions past the end of each
            tensor.

    Returns:
        torch.Tensor: A tensor with ``len(tensorlist)`` rows and as many
        columns as the longest input, where row ``i`` starts with
        ``tensorlist[i]`` and is padded with ``pad_value``.
    """
    lengths = [len(t) for t in tensorlist]
    width = max(lengths)
    reference = tensorlist[0]
    padded = torch.empty(
        len(tensorlist), width, dtype=reference.dtype, device=reference.device
    ).fill_(pad_value)
    for row, (item, length) in enumerate(zip(tensorlist, lengths)):
        padded[row, :length] = item
    return padded
def map_raw_id_to_vocab_id(
    raw_ids: Union[np.ndarray, torch.Tensor],
    gene_ids: np.ndarray,
) -> Union[np.ndarray, torch.Tensor]:
    """
    Map raw ids, which are indices into ``gene_ids``, to the ids stored there.

    The result has the same container type, dtype and (for tensors) device as
    ``raw_ids``.

    Args:
        raw_ids: the raw ids to map, a 1-d tensor or array
        gene_ids: the gene ids to map to, a 1-d array indexed by raw id

    Returns:
        The mapped ids, with the same shape as ``raw_ids``.

    Raises:
        ValueError: if ``raw_ids`` is neither a tensor nor an array, or if
            either input is not 1-d.
    """
    is_tensor = isinstance(raw_ids, torch.Tensor)
    if not is_tensor and not isinstance(raw_ids, np.ndarray):
        raise ValueError("raw_ids must be either torch.Tensor or np.ndarray.")

    original_dtype = raw_ids.dtype
    if is_tensor:
        # Indexing happens in numpy; remember where the result must go back.
        original_device = raw_ids.device
        raw_ids = raw_ids.cpu().numpy()

    if raw_ids.ndim != 1:
        raise ValueError(f"raw_ids must be 1d, got {raw_ids.ndim}d.")
    if gene_ids.ndim != 1:
        raise ValueError(f"gene_ids must be 1d, got {gene_ids.ndim}d.")

    mapped_ids: np.ndarray = gene_ids[raw_ids]
    assert mapped_ids.shape == raw_ids.shape

    if not is_tensor:
        return mapped_ids.astype(original_dtype)
    return torch.from_numpy(mapped_ids).type(original_dtype).to(original_device)
# Wrapper for all scib metrics. We leave out some metrics like hvg_score, cell_cycle,
# trajectory_conservation, because we only evaluate the latent embeddings here and
# these metrics are evaluating the reconstructed gene expressions or pseudotimes.
def eval_scib_metrics(
    adata: AnnData,
    batch_key: str = "str_batch",
    label_key: str = "celltype",
    notes: Optional[str] = None,
) -> Dict:
    """
    Evaluate the ``X_scGPT`` embedding of ``adata`` with scib metrics.

    The same object is passed to scib as both the unintegrated and the
    integrated data. The raw results and a formatted summary are logged.

    Args:
        adata (AnnData): The AnnData object holding the ``X_scGPT`` embedding.
        batch_key (str): The batch key passed to scib.
        label_key (str): The cell type label key passed to scib.
        notes (Optional[str]): An optional message logged before the results.

    Returns:
        Dict: The metric values keyed by metric name, plus ``avg_bio`` (the
        mean of NMI, ARI and label ASW), with NaN entries removed.
    """
    metric_switches = dict(
        isolated_labels_asw_=False,
        silhouette_=True,
        hvg_score_=False,
        graph_conn_=True,
        pcr_=True,
        isolated_labels_f1_=False,
        trajectory_=False,
        nmi_=True,  # use the clustering, bias to the best matching
        ari_=True,  # use the clustering, bias to the best matching
        cell_cycle_=False,
        kBET_=False,  # kBET return nan sometimes, need to examine
        ilisi_=False,
        clisi_=False,
    )
    results = scib.metrics.metrics(
        adata,
        adata_int=adata,
        batch_key=batch_key,
        label_key=label_key,
        embed="X_scGPT",
        **metric_switches,
    )

    if notes is not None:
        logger.info(f"{notes}")

    logger.info(f"{results}")

    scores = results[0].to_dict()
    bio_summary = (
        "Biological Conservation Metrics: \n"
        f"ASW (cell-type): {scores['ASW_label']:.4f}, graph cLISI: {scores['cLISI']:.4f}, "
        f"isolated label silhouette: {scores['isolated_label_silhouette']:.4f}, \n"
    )
    batch_summary = (
        "Batch Effect Removal Metrics: \n"
        f"PCR_batch: {scores['PCR_batch']:.4f}, ASW (batch): {scores['ASW_label/batch']:.4f}, "
        f"graph connectivity: {scores['graph_conn']:.4f}, graph iLISI: {scores['iLISI']:.4f}"
    )
    logger.info(bio_summary + batch_summary)

    bio_keys = ("NMI_cluster/label", "ARI_cluster/label", "ASW_label")
    scores["avg_bio"] = np.mean([scores[key] for key in bio_keys])

    # remove nan value in result_dict
    return {name: value for name, value in scores.items() if not np.isnan(value)}
# wrapper to make sure all methods are called only on the main process
def main_process_only(func):
    """
    Decorate ``func`` so that it only runs on the main process.

    The process counts as main when the ``LOCAL_RANK`` environment variable
    is unset or equal to ``"0"``; the variable is read on every call. On any
    other process the call is skipped and ``None`` is returned.

    Args:
        func: The callable to guard.

    Returns:
        The guarded callable, carrying the metadata of ``func``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        local_rank = os.environ.get("LOCAL_RANK", "0")
        if local_rank != "0":
            return None
        return func(*args, **kwargs)

    return wrapper
# class wrapper to make sure all methods are called only on the main process
class MainProcessOnly:
    """
    Proxy around an object whose callable attributes are guarded by
    ``main_process_only``.

    Attribute lookups that are not resolved on the proxy itself are forwarded
    to the wrapped object. Callable attributes are returned wrapped with
    ``main_process_only``; all other attributes are returned unchanged.

    Args:
        obj: The object to wrap.
    """

    def __init__(self, obj):
        self.obj = obj

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. Read the wrapped
        # object straight from the instance dict: going through ``self.obj``
        # would re-enter __getattr__ forever on an instance that has no
        # ``obj`` yet (e.g. the bare instance built by copy or unpickling).
        try:
            target = self.__dict__["obj"]
        except KeyError:
            raise AttributeError(name) from None
        attr = getattr(target, name)
        if callable(attr):
            attr = main_process_only(attr)
        return attr