Source code for scgpt.scbank.databank

import json
from pathlib import Path
from dataclasses import InitVar, dataclass, field
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from typing_extensions import Self, Literal

import numpy as np
from scipy.sparse import spmatrix, csr_matrix
from anndata import AnnData
from datasets import Dataset, load_dataset
from scvi import settings

from scgpt.tokenizer import GeneVocab
from .data import DataTable, MetaInfo
from .setting import Setting
from . import logger


[docs]
@dataclass
class DataBank:
    """
    The data structure for large-scale single cell data containing multiple studies.
    See https://github.com/subercui/scGPT-release#the-data-structure-for-large-scale-computing.
    """

    meta_info: MetaInfo = None
    data_tables: Dict[str, DataTable] = field(
        default_factory=dict,
        metadata={"help": "Data tables in the DataBank."},
    )
    gene_vocab: InitVar[GeneVocab] = field(
        default=None,
        # Because of the issue https://github.com/python/cpython/issues/94067, the
        # attribute can not be correctly set with a default value when using a getter
        # and setter, i.e. @property. So just explicitly set it in __post_init__.
        metadata={"help": "Gene vocabulary mapping gene tokens to integer ids."},
    )
    settings: Setting = field(
        default_factory=Setting,
        metadata={"help": "The settings for scBank, use default if not set."},
    )

    def __post_init__(self, gene_vocab) -> None:
        """
        Initialize a DataBank. When initializing a non-empty DataBank, this will
        check the meta info and import the data tables. The main data table
        indicator is required if at least one data table is provided.
        """
        # Primarily works like Datasets here: should support slicing, indexing,
        # map, etc. Also, because DataBank is designed to work with large-scale
        # data, the class instance itself works like a schema framework that
        # stores the metadata of the complete data, to easily show or print the
        # info. In the meantime, the actual data chunks are kept synchronized on
        # disk, and the DataBank will efficiently load and return the partition
        # in memory when queried, in the memory- and compute-efficient format of
        # Datasets or Apache Arrow.

        # DataBank always tracks the currently accessed and queried
        # table/partition, and the changes made. When trying to close the
        # current table or switching to another one, DataBank will explicitly
        # ask whether to save the changes to disk, and show the info of the
        # updated table fields.
        if isinstance(gene_vocab, property):
            # workaround for the issue https://github.com/python/cpython/issues/94067
            self._gene_vocab = None
        else:
            self.gene_vocab = gene_vocab

        # empty initialization:
        if self.meta_info is None:
            if len(self.data_tables) > 0:
                raise ValueError("Need to provide meta info if non-empty data tables.")
            if self.gene_vocab is not None:
                raise ValueError("Need to provide meta info if non-empty gene vocab.")
            logger.debug("Initialize an empty DataBank.")
            return

        # only-meta-info initialization:
        if len(self.data_tables) == 0 and self.gene_vocab is None:
            logger.debug("DataBank initialized with meta info only.")
            self.sync() if self.settings.immediate_save else self.track()
            return

        # full initialization:
        if self.gene_vocab is None:
            raise ValueError("Need to provide gene vocab if non-empty data tables.")

        # Validate the meta info, and the consistency between meta and data tables.
        # We assume the input meta_info is complete and correct. Usually this
        # should be handled by factory constructors.
        self._validate_data()
        self.sync() if self.settings.immediate_save else self.track()

    @property
    def gene_vocab(self) -> Optional[GeneVocab]:
        """
        The gene vocabulary mapping gene tokens to integer ids.
        """
        # if self._gene_vocab is None:
        #     self._gene_vocab = GeneVocab.load(self.meta_info.gene_vocab_path)
        return self._gene_vocab

    @gene_vocab.setter
    def gene_vocab(self, gene_vocab: Union[GeneVocab, Path, str]) -> None:
        """
        Set the gene vocabulary from a :obj:`GeneVocab` object or a file path.
        """
        if isinstance(gene_vocab, (Path, str)):
            gene_vocab = GeneVocab.from_file(gene_vocab)
        elif not isinstance(gene_vocab, GeneVocab):
            raise ValueError("gene_vocab must be a GeneVocab object or a path.")
        self._validate_vocab(gene_vocab)  # e.g. check md5 every time calling setter
        self._gene_vocab = gene_vocab
        self.sync(  # sync to disk, update md5 in meta info
            attr_keys=["gene_vocab"]
        ) if self.settings.immediate_save else self.track(["gene_vocab"])

    @property
    def main_table_key(self) -> Optional[str]:
        """
        The main data table key.
        """
        if self.meta_info is None:
            return None
        return self.meta_info.main_table_key

    @main_table_key.setter
    def main_table_key(self, table_key: str) -> None:
        """Set the main data table key."""
        if self.meta_info is None:
            raise ValueError("Need to have self.meta_info if setting main table key.")
        self.meta_info.main_table_key = table_key
        self.sync(["meta_info"]) if self.settings.immediate_save else self.track(
            ["meta_info"]
        )

    @property
    def main_data(self) -> DataTable:
        """The main data table."""
        return self.data_tables[self.main_table_key]

[docs]
    @classmethod
    def from_anndata(
        cls,
        adata: Union[AnnData, Path, str],
        vocab: Union[GeneVocab, Mapping[str, int], Path, str],
        to: Union[Path, str],
        main_table_key: str = "X",
        token_col: str = "gene name",
        immediate_save: bool = True,
    ) -> Self:
        """
        Create a DataBank from an AnnData object.

        Args:
            adata (AnnData): Annotated data object, or path to an anndata file.
            vocab (GeneVocab or Mapping[str, int]): Gene vocabulary mapping gene
                tokens to indices.
            to (Path or str): Data directory.
            main_table_key (str): This layer/obsm in the anndata will be used as
                the main data table.
            token_col (str): Column name of the gene token.
            immediate_save (bool): Whether to save the data immediately after
                creation.

        Returns:
            DataBank: DataBank instance.
        """
        if isinstance(adata, (str, Path)):
            import scanpy as sc

            adata = sc.read(adata, cache=True)
        elif not isinstance(adata, AnnData):
            raise ValueError("adata must point to an AnnData object.")

        if isinstance(vocab, (Path, str)):
            vocab = GeneVocab.from_file(vocab)
        elif isinstance(vocab, Mapping) and not isinstance(vocab, GeneVocab):
            vocab = GeneVocab.from_dict(vocab)
        elif not isinstance(vocab, GeneVocab):
            raise ValueError("vocab must be a GeneVocab object or a mapping.")

        if isinstance(to, str):
            to = Path(to)
        to.mkdir(parents=True, exist_ok=True)

        db = cls(
            meta_info=MetaInfo(on_disk_path=to),
            gene_vocab=vocab,
            settings=Setting(immediate_save=immediate_save),
        )
        # TODO: add other data tables; currently only reads the main data
        data_table = db.load_anndata(
            adata,
            data_keys=[main_table_key],
            token_col=token_col,
        )[0]

        # update and immediately save
        db.main_table_key = main_table_key
        db.update_datatables(new_tables=[data_table], immediate_save=immediate_save)
        return db

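    # Usage sketch (illustrative only, not part of the original module). It
    # assumes a local AnnData file "data.h5ad" whose ``.var`` has a
    # "gene name" column, and a vocabulary file "vocab.json"; both file names
    # are made up:
    #
    #     from scgpt.scbank.databank import DataBank
    #
    #     db = DataBank.from_anndata(
    #         "data.h5ad",            # read via scanpy under the hood
    #         vocab="vocab.json",     # str/Path is loaded with GeneVocab.from_file
    #         to="./my_study.scb",    # output directory, created if missing
    #         main_table_key="X",     # use adata.X as the main data table
    #     )
    #     print(db.main_data.name)    # -> "X"
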
[docs]
    @classmethod
    def batch_from_anndata(cls, adata: List[AnnData], to: Union[Path, str]) -> Self:
        raise NotImplementedError

[docs]
    @classmethod
    def from_path(cls, path: Union[Path, str]) -> Self:
        """
        Create a DataBank from a directory containing scBank data.

        **NOTE**: this method will automatically check whether the md5sum record
        in the :file:`manifest.json` matches the md5sum of the loaded gene
        vocabulary.

        Args:
            path (Path or str): Directory path.

        Returns:
            DataBank: DataBank instance.
        """
        if isinstance(path, str):
            path = Path(path)
        if not path.exists():
            raise ValueError(f"Path {path} does not exist.")
        if not path.is_dir():
            raise ValueError(f"Path {path} is not a directory.")
        if not path.joinpath("gene_vocab.json").exists():
            logger.warning(f"DataBank at {path} does not contain gene_vocab.json.")

        data_table_files = [f for f in path.glob("*.datatable.*") if f.is_file()]
        if len(data_table_files) == 0:
            logger.warning(f"Loading empty DataBank at {path} without datatables.")

        db = cls(meta_info=MetaInfo.from_path(path))
        db.gene_vocab = GeneVocab.from_file(path / "gene_vocab.json")
        data_format = db.meta_info.on_disk_format
        for data_table_file in data_table_files:
            logger.info(f"Loading datatable {data_table_file}.")
            data_table = DataTable(
                name=data_table_file.name.split(".")[0],
                data=load_dataset(
                    data_format,
                    data_files=str(data_table_file),
                    cache_dir=str(path),
                    split="train",
                ),
            )
            db.update_datatables(new_tables=[data_table])
        return db

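    # Usage sketch (illustrative only): reload the DataBank written above from
    # its on-disk directory; each "<name>.datatable.<format>" file becomes one
    # entry in ``db.data_tables``:
    #
    #     db = DataBank.from_path("./my_study.scb")
    #     print(list(db.data_tables.keys()))  # e.g. ["X"]
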
    def __len__(self) -> int:
        """Return the number of cells in the DataBank."""
        raise NotImplementedError

    def _load_anndata_layer(
        self,
        adata: AnnData,
        data_key: Optional[str] = "X",
        index_map: Optional[Mapping[int, int]] = None,
    ) -> Optional[Dataset]:
        """
        Load an anndata layer as a :class:`Dataset` object.

        Args:
            adata (:class:`AnnData`): Annotated data object to load.
            data_key (:obj:`str`, optional): Data key to load, default to "X".
                The data key must be in :attr:`adata.X`, :attr:`adata.layers` or
                :attr:`adata.obsm`.
            index_map (:obj:`Mapping[int, int]`, optional): A mapping from old
                index to new index. If None, the index is unchanged and the
                converted data rows do not have explicit keys.

        Returns:
            :class:`Dataset`: Dataset object loaded.
        """
        if not isinstance(adata, AnnData):
            raise ValueError("adata must be an AnnData object.")

        if index_map is None:
            # TODO: implement loading data w/ index unchanged, i.e. w/o explicit keys
            raise NotImplementedError

        if data_key == "X":
            data = adata.X
        elif data_key in adata.layers:
            data = adata.layers[data_key]
        elif data_key in adata.obsm:
            data = adata.obsm[data_key]
        else:
            logger.warning(f"Data key {data_key} not found, skip loading.")
            return None

        tokenized_data = self._tokenize(data, index_map)
        return Dataset.from_dict(tokenized_data)

    def _tokenize(
        self,
        data: Union[np.ndarray, csr_matrix],
        ind2ind: Mapping[int, int],
        new_indices: Optional[List[int]] = None,
    ) -> Dict[str, List]:
        """
        Tokenize the data with the given vocabulary.

        TODO: currently by default uses the key-value datatable scheme. Add
        support for other schemes.
        TODO: currently just counts from zero for the cell id.

        Args:
            data (np.ndarray or csr_matrix): Data to be tokenized.
            ind2ind (Mapping[int, int]): Old index to new index mapping.
            new_indices (List[int]): New indices to be used, will ignore ind2ind.
                **NOTE**: Usually this should be None. Should only be used in the
                spawned recursive calls if any.

        Returns:
            Dict[str, List]: Tokenized data.
        """
        if not isinstance(data, (np.ndarray, csr_matrix)):
            raise ValueError("data must be a numpy array or sparse matrix.")

        if isinstance(data, np.ndarray):
            zero_ratio = np.sum(data == 0) / data.size
            if zero_ratio > 0.85:
                logger.debug(
                    f"{zero_ratio*100:.0f}% of the data is zero, "
                    "auto convert to sparse matrix before tokenizing."
                )
                data = csr_matrix(data)  # processing a sparse matrix is actually faster

        # remove zero rows
        if self.settings.remove_zero_rows:
            if isinstance(data, np.ndarray) and data.size > 1e9:
                logger.warning(
                    "Going to remove zero rows from a large ndarray. This may "
                    "take a long time. If you want to disable this, set "
                    f"`remove_zero_rows` to False in {self.__class__.__name__}.settings."
                )
            if isinstance(data, csr_matrix):
                data = data[data.getnnz(axis=1) > 0]
            else:
                data = data[~np.all(data == 0, axis=1)]

        n_rows, n_cols = data.shape  # usually means n_cells, n_genes
        if new_indices is None:
            new_indices = np.array(
                [ind2ind.get(i, -100) for i in range(n_cols)], int
            )  # TODO: can be accelerated
        else:
            assert len(new_indices) == n_cols

        if isinstance(data, csr_matrix):
            indptr = data.indptr
            indices = data.indices
            non_zero_data = data.data
            tokenized_data = {"id": [], "genes": [], "expressions": []}
            tokenized_data["id"] = list(range(n_rows))
            for i in range(n_rows):  # ~2s/100k cells
                row_indices = indices[indptr[i] : indptr[i + 1]]
                row_new_indices = new_indices[row_indices]
                row_non_zero_data = non_zero_data[indptr[i] : indptr[i + 1]]
                match_mask = row_new_indices != -100
                row_new_indices = row_new_indices[match_mask]
                row_non_zero_data = row_non_zero_data[match_mask]
                tokenized_data["genes"].append(row_new_indices)
                tokenized_data["expressions"].append(row_non_zero_data)
        else:
            # remove -100 cols in data and new_indices, i.e. genes not in vocab
            mask = new_indices != -100
            new_indices = new_indices[mask]
            data = data[:, mask]
            tokenized_data = _nparray2mapped_values(
                data, new_indices, "numba"
            )  # ~7s/100k cells
        return tokenized_data

    def _validate_vocab(self, vocab: Optional[GeneVocab] = None) -> None:
        """
        Validate the vocabulary. Checks :attr:`self.gene_vocab` if no vocab is given.
        """
        if vocab is None:
            vocab = self.gene_vocab
        # TODO: implement validation

    def _validate_data(self) -> None:
        """
        Validate the current DataBank, including checking md5sum, table names, etc.
        """
        # first validate the vocabulary
        self._validate_vocab()

        if len(self.data_tables) == 0 and self.main_table_key is not None:
            raise ValueError(
                "No data tables found, but main table key is set. "
                "Please set main table key to None or add data tables."
            )
        if len(self.data_tables) > 0:
            if self.main_table_key is None:
                raise ValueError(
                    "Main table key can not be empty if non-empty data tables."
                )
            if self.main_table_key not in self.data_tables:
                raise ValueError(
                    f"Main table key {self.main_table_key} not found in data tables."
                )

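    # Sketch of the key-value scheme produced by ``_tokenize`` (values made up
    # for illustration, exercising a private method): one record per cell,
    # keeping only non-zero entries, with old column indices remapped through
    # ``ind2ind``:
    #
    #     import numpy as np
    #     from scipy.sparse import csr_matrix
    #
    #     data = csr_matrix(np.array([[0.0, 1.5, 0.0, 2.0]]))
    #     out = db._tokenize(data, ind2ind={0: 10, 1: 11, 2: 12, 3: 13})
    #     # out == {"id": [0],
    #     #         "genes": [array([11, 13])],
    #     #         "expressions": [array([1.5, 2.0])]}
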
[docs]
    def update_datatables(
        self,
        new_tables: List[DataTable],
        use_names: Optional[List[str]] = None,
        overwrite: bool = False,
        immediate_save: Optional[bool] = None,
    ) -> None:
        """
        Update the data tables in the DataBank with new data tables.

        Args:
            new_tables (list of :class:`DataTable`): New data tables to update.
            use_names (list of :obj:`str`): Names to use for the new data tables.
                If not provided, the names of the new data tables are used.
            overwrite (:obj:`bool`): Whether to overwrite existing data tables.
            immediate_save (:obj:`bool`): Whether to save the data immediately
                after updating. Will save to :attr:`self.meta_info.on_disk_path`.
                If not provided, follows :attr:`self.settings.immediate_save`
                instead. Default to None.
        """
        if not isinstance(new_tables, list) or not all(
            isinstance(t, DataTable) for t in new_tables
        ):
            raise ValueError("new_tables must be a list of DataTable.")

        if use_names is None:
            use_names = [t.name for t in new_tables]
        elif len(use_names) != len(new_tables):
            raise ValueError("use_names must have the same length as new_tables.")

        if not overwrite:
            overlaps = set(use_names) & set(self.data_tables.keys())
            if len(overlaps) > 0:
                raise ValueError(
                    f"Data table names {overlaps} already exist in the DataBank. "
                    "Please set overwrite=True if replacing the existing data table."
                )

        if immediate_save is None:
            immediate_save = self.settings.immediate_save

        for dt, name in zip(new_tables, use_names):
            self.data_tables[name] = dt
        self._validate_data()
        self.sync("data_tables") if immediate_save else self.track("data_tables")

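    # Usage sketch (illustrative only): register a table under a custom name,
    # overwriting any existing table of that name; ``table`` is assumed to be a
    # DataTable built elsewhere, e.g. by ``load_anndata``:
    #
    #     db.update_datatables(
    #         new_tables=[table],
    #         use_names=["normalized"],
    #         overwrite=True,
    #     )
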
[docs]
    def load_anndata(
        self,
        adata: AnnData,
        data_keys: Optional[List[str]] = None,
        token_col: str = "gene name",
    ) -> List[DataTable]:
        """
        Load an anndata object into data tables.

        Args:
            adata (:class:`AnnData`): Annotated data object to load.
            data_keys (list of :obj:`str`): List of data keys to load. If None,
                all data keys in :attr:`adata.X`, :attr:`adata.layers` and
                :attr:`adata.obsm` will be loaded.
            token_col (:obj:`str`): Column name of the gene token. Tokens will
                be converted to indices by :attr:`self.gene_vocab`.

        Returns:
            list of :class:`DataTable`: List of data tables loaded.
        """
        if not isinstance(adata, AnnData):
            raise ValueError("adata must be an AnnData object.")

        if data_keys is None:
            data_keys = ["X"] + list(adata.layers.keys()) + list(adata.obsm.keys())

        if token_col not in adata.var:
            raise ValueError(f"token_col {token_col} not found in adata.var.")
        if not isinstance(adata.var[token_col][0], str):
            raise ValueError(f"token_col {token_col} must be of type str.")

        # validate matching between tokens and vocab
        tokens = adata.var[token_col].tolist()
        match_ratio = sum(1 for t in tokens if t in self.gene_vocab) / len(tokens)
        if match_ratio < 0.9:
            raise ValueError(
                f"Only {match_ratio*100:.0f}% of the tokens in adata.var[{token_col}] "
                "are in the vocab. Please check if using the correct vocab and "
                "token_col."
            )

        # build mapping to scBank datatable keys
        _ind2ind = _map_ind(tokens, self.gene_vocab)  # old index to new index

        data_tables = []
        for key in data_keys:
            data = self._load_anndata_layer(adata, key, _ind2ind)
            data_table = DataTable(
                name=key,
                data=data,
            )
            data_tables.append(data_table)
        return data_tables

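    # Usage sketch (illustrative only): tokenize an in-memory AnnData object
    # into data tables without saving, assuming ``adata.var["gene name"]``
    # holds gene symbols that mostly appear in ``db.gene_vocab``:
    #
    #     tables = db.load_anndata(adata, data_keys=["X"], token_col="gene name")
    #     db.update_datatables(new_tables=tables, immediate_save=False)
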
[docs]
    def append_study(
        self,
        study_id: int,
        study_data: Union[AnnData, "DataBank"],
    ) -> None:
        """
        Append a study to the current DataBank.

        Args:
            study_id (int): Study ID.
            study_data (AnnData or DataBank): Study data.
        """
        raise NotImplementedError

[docs]
    def delete_study(self, study_id: int) -> None:
        """
        Delete a study from the current DataBank.
        """
        raise NotImplementedError

[docs]
    def filter(
        self,
        study_ids: Optional[List[int]] = None,
        cell_ids: Optional[List[int]] = None,
        inplace: bool = True,
    ) -> Self:
        """
        Filter the current DataBank by study ID and cell ID.

        Args:
            study_ids (list): Study IDs to filter.
            cell_ids (list): Cell IDs to filter.
            inplace (bool): Whether to also filter inplace.

        Returns:
            DataBank: Filtered DataBank.
        """
        raise NotImplementedError

[docs]
    def custom_filter(
        self,
        field: str,
        filter_func: callable,
        inplace: bool = True,
    ) -> Self:
        """
        Filter the current DataBank by applying a custom filter function to a field.

        Args:
            field (str): Field to filter.
            filter_func (callable): Filter function.
            inplace (bool): Whether to also filter inplace.

        Returns:
            DataBank: Filtered DataBank.
        """
        raise NotImplementedError

[docs]
    def load_table(self, table_name: str) -> Dataset:
        """
        Load a data table from the current DataBank.
        """
        raise NotImplementedError

[docs]
    def load(self, path: Union[Path, str]) -> Dataset:
        """
        Load scBank data from a data directory. Since DataBank is designed to
        work with large-scale data, this only loads the main data table into
        memory by default. It also loads the meta info and performs a
        validation check.
        """
        if isinstance(path, str):
            path = Path(path)
        assert path.is_dir(), f"Data path {path} should be a directory."

        raise NotImplementedError

[docs]
    def load_all(self, path: Union[Path, str]) -> Dict[str, Dataset]:
        """
        Load scBank data from a data directory. This will load all the data
        tables into memory.
        """
        if isinstance(path, str):
            path = Path(path)
        assert path.is_dir(), f"Data path {path} should be a directory."

        raise NotImplementedError

[docs]
    def track(self, attr_keys: Union[List[str], str, None] = None) -> List:
        """
        Track all the changes made to the current DataBank that have not been
        synced to disk, and return them as a list.

        Args:
            attr_keys (list of :obj:`str`): List of attribute keys to look for
                changes. If None, all attributes will be checked.
        """
        if attr_keys is None:
            attr_keys = [
                "meta_info",
                "data_tables",
                "gene_vocab",
            ]
        elif isinstance(attr_keys, str):
            attr_keys = [attr_keys]

        changed_attrs = []
        # TODO: implement the checking of the changes, currently just return all
        changed_attrs.extend(attr_keys)
        return changed_attrs

[docs]
    def sync(self, attr_keys: Union[List[str], str, None] = None) -> None:
        """
        Sync the current DataBank to the data directory: save the updated
        data/vocab to files, then update the meta info and save it as well.

        **NOTE**: This will overwrite the existing data directory.

        Args:
            attr_keys (list of :obj:`str`): List of attribute keys to sync. If
                None, will sync all the attributes with tracked changes.
        """
        if attr_keys is None:
            attr_keys = self.track()
        elif isinstance(attr_keys, str):
            attr_keys = [attr_keys]

        # TODO: implement. Remember particularly to update the md5 in the meta
        # info when updating the gene vocabulary.
        on_disk_path = self.meta_info.on_disk_path
        data_format = self.meta_info.on_disk_format
        if "meta_info" in attr_keys:
            self.meta_info.save(on_disk_path)
        if "data_tables" in attr_keys:
            for data_table in self.data_tables.values():
                save_to = on_disk_path / f"{data_table.name}.datatable.{data_format}"
                logger.info(f"Saving data table {data_table.name} to {save_to}.")
                data_table.save(
                    path=save_to,
                    format=data_format,
                )
        if "gene_vocab" in attr_keys:
            if self.gene_vocab is not None:
                self._gene_vocab.save_json(on_disk_path / "gene_vocab.json")

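    # Usage sketch (illustrative only, assuming the Setting dataclass is
    # mutable): with ``immediate_save=False``, edits are only recorded by
    # ``track`` and written out in one explicit ``sync`` call:
    #
    #     db.settings.immediate_save = False
    #     db.main_table_key = "X"     # tracked, not yet on disk
    #     db.sync(["meta_info"])      # now written to meta_info.on_disk_path
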
[docs]
    def save(self, path: Union[Path, str, None], replace: bool = False) -> None:
        """
        Save scBank data to a data directory.

        Args:
            path (Path): Path to save scBank data. If None, will save to the
                directory at :attr:`self.meta_info.on_disk_path`.
            replace (bool): Whether to replace existing data in the directory.
        """
        if path is None:
            path = self.meta_info.on_disk_path
        elif isinstance(path, str):
            path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        raise NotImplementedError

def _map_ind(tokens: List[str], vocab: Mapping[str, int]) -> Mapping[int, int]:
    """
    Create a mapping from old index to new index for a list of tokens.

    Args:
        tokens (list): List of tokens, in the order of old indices.
        vocab (Mapping[str, int]): Vocabulary mapping from token to new index.
    """
    ind2ind = {}
    unmatched_tokens = []
    for i, t in enumerate(tokens):
        if t in vocab:
            ind2ind[i] = vocab[t]
        else:
            unmatched_tokens.append(t)
    if len(unmatched_tokens) > 0:
        logger.warning(
            f"{len(unmatched_tokens)}/{len(tokens)} tokens/genes unmatched "
            "during vocabulary conversion."
        )
    return ind2ind


def _nparray2mapped_values(
    data: np.ndarray,
    new_indices: np.ndarray,
    mode: Literal["plain", "numba"] = "plain",
) -> Dict[str, List]:
    """
    Convert a numpy array to mapped values, only including the non-zero values.

    Args:
        data (np.ndarray): Data matrix.
        new_indices (np.ndarray): New indices.
        mode (str): Either "plain" or "numba"; "numba" uses the jit-compiled
            converter below.

    Returns:
        Dict[str, List]: Mapping from column name to list of values.
    """
    if mode == "plain":
        convert_func = _nparray2indexed_values
    elif mode == "numba":
        convert_func = _nparray2indexed_values_numba
    else:
        raise ValueError(f"Unknown mode {mode}.")

    tokenized_data = {}
    row_ids, col_inds, values = convert_func(data, new_indices)
    tokenized_data["id"] = row_ids
    tokenized_data["genes"] = col_inds
    tokenized_data["expressions"] = values
    return tokenized_data


def _nparray2indexed_values(
    data: np.ndarray,
    new_indices: np.ndarray,
) -> Tuple[List, List, List]:
    """
    Convert a numpy array to indexed values, only including the non-zero values.

    Args:
        data (np.ndarray): Data matrix.
        new_indices (np.ndarray): New indices.

    Returns:
        Tuple[List, List, List]: Row IDs, column indices, and values.
    """
    row_ids, col_inds, values = [], [], []
    for i in range(len(data)):  # TODO: accelerate with numba? joblib?
        row = data[i]
        idx = np.nonzero(row)[0]
        expressions = row[idx]
        genes = new_indices[idx]
        row_ids.append(i)
        col_inds.append(genes)
        values.append(expressions)
    return row_ids, col_inds, values


from numba import njit, prange


@njit(parallel=True)
def _nparray2indexed_values_numba(
    data: np.ndarray,
    new_indices: np.ndarray,
) -> Tuple[List, List, List]:
    """
    Convert a numpy array to indexed values, only including the non-zero values.
    Uses numba to accelerate.

    Args:
        data (np.ndarray): Data matrix.
        new_indices (np.ndarray): New indices.

    Returns:
        Tuple[List, List, List]: Row IDs, column indices, and values.
    """
    # placeholders, filled in parallel below
    row_ids, col_inds, values = (
        [1] * len(data),
        [np.empty(0, dtype=np.int64)] * len(data),
        [np.empty(0, dtype=data.dtype)] * len(data),
    )
    for i in prange(len(data)):
        row = data[i]
        idx = np.nonzero(row)[0]
        expressions = row[idx]
        genes = new_indices[idx]
        row_ids[i] = i
        col_inds[i] = genes
        values[i] = expressions
    return row_ids, col_inds, values
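
# Sketch of the index mapping built by ``_map_ind`` (values made up for
# illustration): tokens absent from the vocab are simply skipped, which is why
# downstream code marks missing genes with -100:
#
#     vocab = {"TP53": 7, "GAPDH": 2}
#     _map_ind(["GAPDH", "FOO", "TP53"], vocab)
#     # -> {0: 2, 2: 7}, plus a warning that 1/3 tokens were unmatched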