Source code for scgpt.tokenizer.gene_tokenizer

import json
import pickle
from pathlib import Path
from collections import Counter, OrderedDict
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing_extensions import Self

import numpy as np
import pandas as pd
import torch
import torchtext.vocab as torch_vocab
from torchtext.vocab import Vocab
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers import AutoTokenizer, BertTokenizer

from .. import logger


class GeneTokenizer(PreTrainedTokenizer):
    pass

class GeneVocab(Vocab):
    """
    Vocabulary for genes.
    """

    def __init__(
        self,
        gene_list_or_vocab: Union[List[str], Vocab],
        specials: Optional[List[str]] = None,
        special_first: bool = True,
        default_token: Optional[str] = "<pad>",
    ) -> None:
        """
        Initialize the vocabulary.
        Note: adding specials only works when initializing from a gene list.

        Args:
            gene_list_or_vocab (List[str] or Vocab): List of gene names or a
                Vocab object.
            specials (List[str]): List of special tokens.
            special_first (bool): Whether to add special tokens to the
                beginning of the vocabulary.
            default_token (str): Default token, set to "<pad>" by default if
                "<pad>" is in the vocabulary.
        """
        if isinstance(gene_list_or_vocab, Vocab):
            _vocab = gene_list_or_vocab
            if specials is not None:
                raise ValueError(
                    "received non-empty specials when initializing from a Vocab object."
                )
        elif isinstance(gene_list_or_vocab, list):
            _vocab = self._build_vocab_from_iterator(
                gene_list_or_vocab,
                specials=specials,
                special_first=special_first,
            )
        else:
            raise ValueError(
                "gene_list_or_vocab must be a list of gene names or a Vocab object."
            )
        super().__init__(_vocab.vocab)
        if default_token is not None and default_token in self:
            self.set_default_token(default_token)

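    # Usage sketch (editor's illustration, not part of the original module):
    # building a GeneVocab from a gene list with special tokens prepended.
    # The gene symbols below are arbitrary examples.
    #
    #     vocab = GeneVocab(["HES4", "SSU72", "TNFRSF4"], specials=["<pad>", "<cls>"])
    #     vocab["<pad>"]    # -> 0, specials come first when special_first=True
    #     vocab["HES4"]     # -> 2, genes follow the specials
    #     vocab["UNSEEN"]   # -> 0, falls back to the "<pad>" default index
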
    @classmethod
    def from_file(cls, file_path: Union[Path, str]) -> Self:
        """
        Load the vocabulary from a file. The file should be either a pickle
        of a gene list or Vocab object, or a json file of a token-to-index
        mapping.
        """
        if isinstance(file_path, str):
            file_path = Path(file_path)
        if file_path.suffix == ".pkl":
            with file_path.open("rb") as f:
                vocab = pickle.load(f)
            return cls(vocab)
        elif file_path.suffix == ".json":
            with file_path.open("r") as f:
                token2idx = json.load(f)
            return cls.from_dict(token2idx)
        else:
            raise ValueError(
                f"{file_path} is not a valid file type. "
                "Only .pkl and .json are supported."
            )

    @classmethod
    def from_dict(
        cls,
        token2idx: Dict[str, int],
        default_token: Optional[str] = "<pad>",
    ) -> Self:
        """
        Load the vocabulary from a dictionary.

        Args:
            token2idx (Dict[str, int]): Dictionary mapping tokens to indices.
        """
        # initiate an empty vocabulary first
        _vocab = cls([])

        # add the tokens to the vocabulary, GeneVocab requires consecutive indices
        for t, i in sorted(token2idx.items(), key=lambda x: x[1]):
            _vocab.insert_token(t, i)

        if default_token is not None and default_token in _vocab:
            _vocab.set_default_token(default_token)

        return _vocab

    def _build_vocab_from_iterator(
        self,
        iterator: Iterable,
        min_freq: int = 1,
        specials: Optional[List[str]] = None,
        special_first: bool = True,
    ) -> Vocab:
        """
        Build a Vocab from an iterator. This function is modified from
        torchtext.vocab.build_vocab_from_iterator. The original function
        always splits tokens into characters, which is not what we want.

        Args:
            iterator (Iterable): Iterator used to build the Vocab. Must yield
                lists or iterators of tokens.
            min_freq (int): The minimum frequency needed to include a token
                in the vocabulary.
            specials (List[str]): Special symbols to add. The order of
                supplied tokens will be preserved.
            special_first (bool): Whether to add special tokens to the
                beginning of the vocabulary.

        Returns:
            torchtext.vocab.Vocab: A `Vocab` object.
        """
        counter = Counter()
        counter.update(iterator)

        if specials is not None:
            # exclude special tokens from the frequency counts; they are
            # inserted explicitly below
            for tok in specials:
                del counter[tok]

        # sort by token first, then stable-sort by descending frequency, so
        # frequency ties keep a deterministic lexicographic order
        sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0])
        sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True)
        ordered_dict = OrderedDict(sorted_by_freq_tuples)

        if specials is not None:
            if special_first:
                specials = specials[::-1]
            for symbol in specials:
                ordered_dict.update({symbol: min_freq})
                ordered_dict.move_to_end(symbol, last=not special_first)

        word_vocab = torch_vocab.vocab(ordered_dict, min_freq=min_freq)
        return word_vocab

    @property
    def pad_token(self) -> Optional[str]:
        """
        Get the pad token.
        """
        if getattr(self, "_pad_token", None) is None:
            self._pad_token = None
        return self._pad_token

    @pad_token.setter
    def pad_token(self, pad_token: str) -> None:
        """
        Set the pad token. Will not add the pad token to the vocabulary.

        Args:
            pad_token (str): Pad token, should be in the vocabulary.
        """
        if pad_token not in self:
            raise ValueError(f"{pad_token} is not in the vocabulary.")
        self._pad_token = pad_token

    def save_json(self, file_path: Union[Path, str]) -> None:
        """
        Save the vocabulary to a json file.
        """
        if isinstance(file_path, str):
            file_path = Path(file_path)
        with file_path.open("w") as f:
            json.dump(self.get_stoi(), f, indent=2)

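    # Usage sketch (editor's illustration): round-tripping a vocabulary
    # through a token-to-index dict and a json file. The file path is a
    # placeholder.
    #
    #     vocab = GeneVocab.from_dict({"<pad>": 0, "<cls>": 1, "HES4": 2})
    #     vocab.save_json("vocab.json")             # writes the stoi mapping
    #     same = GeneVocab.from_file("vocab.json")  # rebuilds the same mapping
    #     assert same["HES4"] == 2
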
    def set_default_token(self, default_token: str) -> None:
        """
        Set the default token.

        Args:
            default_token (str): Default token.
        """
        if default_token not in self:
            raise ValueError(f"{default_token} is not in the vocabulary.")
        self.set_default_index(self[default_token])

def get_default_gene_vocab() -> GeneVocab:
    """
    Get the default gene vocabulary, consisting of gene symbols and ids.
    """
    vocab_file = Path(__file__).parent / "default_gene_vocab.json"
    if not vocab_file.exists():
        logger.info(
            f"No existing default vocab, will build one and save to {vocab_file}"
        )
        return _build_default_gene_vocab(save_vocab_to=vocab_file)
    logger.info(f"Loading gene vocabulary from {vocab_file}")
    return GeneVocab.from_file(vocab_file)

def _build_default_gene_vocab(
    download_source_to: str = "/tmp",
    save_vocab_to: Union[Path, str, None] = None,
) -> GeneVocab:
    """
    Build the default gene vocabulary from HGNC gene symbols.

    Args:
        download_source_to (str): Directory to download the source data to.
        save_vocab_to (Path or str): Path to save the vocabulary. If None,
            the vocabulary will not be saved. Defaults to None.
    """
    gene_collection_file = (
        Path(download_source_to) / "human.gene_name_symbol.from_genenames.org.tsv"
    )
    if not gene_collection_file.exists():
        # download the gene table from genenames.org and cache it locally
        url = (
            "https://www.genenames.org/cgi-bin/download/custom?col=gd_app_sym&"
            "col=md_ensembl_id&status=Approved&status=Entry%20Withdrawn&hgnc_dbtag"
            "=on&order_by=gd_app_sym_sort&format=text&submit=submit"
        )
        import requests

        r = requests.get(url)
        gene_collection_file.write_text(r.text)

    logger.info(f"Building gene vocabulary from {gene_collection_file}")
    df = pd.read_csv(gene_collection_file, sep="\t")
    gene_list = df["Approved symbol"].dropna().unique().tolist()
    gene_vocab = GeneVocab(gene_list)  # no special tokens set in default vocab
    if save_vocab_to is not None:
        gene_vocab.save_json(Path(save_vocab_to))
    return gene_vocab

def tokenize_batch(
    data: np.ndarray,
    gene_ids: np.ndarray,
    return_pt: bool = True,
    append_cls: bool = True,
    include_zero_gene: bool = False,
    cls_id: Union[int, str] = "<cls>",
) -> List[Tuple[Union[torch.Tensor, np.ndarray], Union[torch.Tensor, np.ndarray]]]:
    """
    Tokenize a batch of data. Returns a list of tuples (gene_id, count).

    Args:
        data (array-like): A batch of data, with shape (batch_size, n_features).
            n_features equals the number of all genes.
        gene_ids (array-like): A batch of gene ids, with shape (n_features,).
        return_pt (bool): Whether to return torch tensors of gene_ids and
            counts, default to True.
        cls_id (int): The id used for the prepended cls token; normally the
            vocabulary index of "<cls>". The string default is a placeholder
            and must be replaced with a real id when append_cls is True.

    Returns:
        list: A list of tuples (gene_id, count) of non-zero gene expressions.
    """
    if data.shape[1] != len(gene_ids):
        raise ValueError(
            f"Number of features in data ({data.shape[1]}) does not match "
            f"number of gene_ids ({len(gene_ids)})."
        )

    tokenized_data = []
    for i in range(len(data)):
        row = data[i]
        if include_zero_gene:
            values = row
            genes = gene_ids
        else:
            # keep only the genes with non-zero expression in this cell
            idx = np.nonzero(row)[0]
            values = row[idx]
            genes = gene_ids[idx]
        if append_cls:
            genes = np.insert(genes, 0, cls_id)
            values = np.insert(values, 0, 0)
        if return_pt:
            genes = torch.from_numpy(genes).long()
            values = torch.from_numpy(values).float()
        tokenized_data.append((genes, values))
    return tokenized_data

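# Usage sketch (editor's illustration): tokenizing a 2-cell batch. With
# include_zero_gene=False only non-zero entries are kept, and cls_id should
# be the vocabulary id of the cls token (here assumed to be 1).
#
#     data = np.array([[0.0, 3.0, 0.0], [1.0, 0.0, 2.0]])
#     gene_ids = np.array([10, 11, 12])
#     out = tokenize_batch(data, gene_ids, append_cls=True, cls_id=1)
#     out[0]  # -> (tensor([ 1, 11]), tensor([0., 3.]))
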
def pad_batch(
    batch: List[Tuple],
    max_len: int,
    vocab: Vocab,
    pad_token: str = "<pad>",
    pad_value: int = 0,
    cls_appended: bool = True,
) -> Dict[str, torch.Tensor]:
    """
    Pad a batch of data to max_len. Returns a dictionary of stacked gene id
    and count tensors.

    Args:
        batch (list): A list of tuples (gene_id, count).
        max_len (int): The maximum length of the batch.
        vocab (Vocab): The vocabulary containing the pad token.
        pad_token (str): The token to pad with.

    Returns:
        Dict[str, torch.Tensor]: A dictionary with keys "genes" and "values",
            each of shape (batch_size, max_len).
    """
    pad_id = vocab[pad_token]
    gene_ids_list = []
    values_list = []
    for i in range(len(batch)):
        gene_ids, values = batch[i]
        if len(gene_ids) > max_len:
            # sample max_len genes; keep the cls token at position 0 if present
            if not cls_appended:
                idx = np.random.choice(len(gene_ids), max_len, replace=False)
            else:
                idx = np.random.choice(len(gene_ids) - 1, max_len - 1, replace=False)
                idx = idx + 1
                idx = np.insert(idx, 0, 0)
            gene_ids = gene_ids[idx]
            values = values[idx]
        if len(gene_ids) < max_len:
            gene_ids = torch.cat(
                [
                    gene_ids,
                    torch.full(
                        (max_len - len(gene_ids),), pad_id, dtype=gene_ids.dtype
                    ),
                ]
            )
            values = torch.cat(
                [
                    values,
                    torch.full((max_len - len(values),), pad_value, dtype=values.dtype),
                ]
            )
        gene_ids_list.append(gene_ids)
        values_list.append(values)
    batch_padded = {
        "genes": torch.stack(gene_ids_list, dim=0),
        "values": torch.stack(values_list, dim=0),
    }
    return batch_padded

def tokenize_and_pad_batch(
    data: np.ndarray,
    gene_ids: np.ndarray,
    max_len: int,
    vocab: Vocab,
    pad_token: str,
    pad_value: int,
    append_cls: bool = True,
    include_zero_gene: bool = False,
    cls_token: str = "<cls>",
    return_pt: bool = True,
) -> Dict[str, torch.Tensor]:
    """
    Tokenize and pad a batch of data. Returns a dictionary of stacked gene id
    and count tensors.
    """
    cls_id = vocab[cls_token]
    tokenized_data = tokenize_batch(
        data,
        gene_ids,
        return_pt=return_pt,
        append_cls=append_cls,
        include_zero_gene=include_zero_gene,
        cls_id=cls_id,
    )

    batch_padded = pad_batch(
        tokenized_data, max_len, vocab, pad_token, pad_value, cls_appended=append_cls
    )
    return batch_padded

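# Usage sketch (editor's illustration): the full tokenize-then-pad pipeline.
# Assumes a GeneVocab containing "<pad>" and "<cls>", and data/gene_ids as in
# the tokenize_batch sketch above.
#
#     batch = tokenize_and_pad_batch(
#         data, gene_ids, max_len=6, vocab=vocab,
#         pad_token="<pad>", pad_value=0, append_cls=True,
#     )
#     batch["genes"].shape, batch["values"].shape  # -> both (batch_size, 6)
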
def random_mask_value(
    values: Union[torch.Tensor, np.ndarray],
    mask_ratio: float = 0.15,
    mask_value: int = -1,
    pad_value: int = 0,
) -> torch.Tensor:
    """
    Randomly mask a batch of data.

    Args:
        values (array-like): A batch of tokenized data, with shape
            (batch_size, n_features).
        mask_ratio (float): The ratio of genes to mask, default to 0.15.
        mask_value (int): The value to mask with, default to -1.
        pad_value (int): The value of padding in the values, will be kept
            unchanged.

    Returns:
        torch.Tensor: A tensor of masked data.
    """
    if isinstance(values, torch.Tensor):
        # it is crucial to clone the tensor; otherwise masking would modify
        # the original tensor in place
        values = values.clone().detach().numpy()
    else:
        values = values.copy()

    for i in range(len(values)):
        row = values[i]
        non_padding_idx = np.nonzero(row - pad_value)[0]
        n_mask = int(len(non_padding_idx) * mask_ratio)
        mask_idx = np.random.choice(non_padding_idx, n_mask, replace=False)
        row[mask_idx] = mask_value
    return torch.from_numpy(values).float()

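# Usage sketch (editor's illustration): masking ~15% of the non-padding
# values for masked-value prediction; positions equal to pad_value are never
# selected for masking.
#
#     masked = random_mask_value(batch["values"], mask_ratio=0.15,
#                                mask_value=-1, pad_value=0)
#     (masked == -1).sum()  # number of masked positions across the batch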