Source code for scgpt.model.generation_model

import os
import math
from typing import Mapping, Optional, Tuple, Any, Union

import torch
from torch import nn, Tensor
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.distributions import Bernoulli
from torch.utils.data import dataset
from tqdm import trange

from .model import (
    ExprDecoder,
    MVCDecoder,
    ContinuousValueEncoder,
    FastTransformerEncoderWrapper,
    FlashTransformerEncoderLayer,
)
from ..utils import map_raw_id_to_vocab_id
from .. import logger


class TransformerGenerator(nn.Module):
    def __init__(
        self,
        ntoken: int,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        nlayers_cls: int,
        n_cls: int,
        vocab: Any,
        dropout: float = 0.5,
        pad_token: str = "<pad>",
        pad_value: int = 0,
        pert_pad_id: int = 2,
        do_mvc: bool = False,
        domain_spec_batchnorm: Union[bool, str] = False,
        cell_emb_style: str = "cls",
        mvc_decoder_style: str = "inner product",
        ecs_threshold: float = 0.3,
        explicit_zero_prob: bool = False,
        use_fast_transformer: bool = False,
        fast_transformer_backend: str = "flash",
        pre_norm: bool = False,
    ):
        super().__init__()
        self.model_type = "Transformer"
        self.d_model = d_model
        self.pad_token_id = vocab[pad_token]
        self.pad_value = pad_value
        self.pert_pad_id = pert_pad_id
        self.ecs_threshold = ecs_threshold
        self.domain_spec_batchnorm = domain_spec_batchnorm
        self.cell_emb_style = cell_emb_style
        self.explicit_zero_prob = explicit_zero_prob
        self.norm_scheme = "pre" if pre_norm else "post"
        if cell_emb_style not in ["cls", "avg-pool", "w-pool"]:
            raise ValueError(f"Unknown cell_emb_style: {cell_emb_style}")

        self.encoder = GeneEncoder(ntoken, d_model, padding_idx=vocab[pad_token])
        self.value_encoder = ContinuousValueEncoder(d_model, dropout)
        self.pert_encoder = nn.Embedding(3, d_model, padding_idx=pert_pad_id)

        print("Using simple batchnorm instead of domain specific batchnorm")
        self.bn = nn.BatchNorm1d(d_model, eps=6.1e-5)

        if use_fast_transformer:
            if fast_transformer_backend == "linear":
                self.transformer_encoder = FastTransformerEncoderWrapper(
                    d_model, nhead, d_hid, nlayers, dropout
                )
            elif fast_transformer_backend == "flash":
                encoder_layers = FlashTransformerEncoderLayer(
                    d_model,
                    nhead,
                    d_hid,
                    dropout,
                    batch_first=True,
                    norm_scheme=self.norm_scheme,
                )
                self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        else:
            encoder_layers = TransformerEncoderLayer(
                d_model, nhead, d_hid, dropout, batch_first=True
            )
            self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

        # self.decoder = nn.Linear(d_model, 1)
        self.decoder = ExprDecoder(
            d_model,
            explicit_zero_prob=explicit_zero_prob,
        )
        self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls)
        if do_mvc:
            self.mvc_decoder = MVCDecoder(
                d_model,
                arch_style=mvc_decoder_style,
                explicit_zero_prob=explicit_zero_prob,
            )

        self.sim = Similarity(temp=0.5)
        self.creterion_cce = nn.CrossEntropyLoss()

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.encoder.embedding.weight.data.uniform_(-initrange, initrange)

    def _encode(
        self,
        src: Tensor,
        values: Tensor,
        input_pert_flags,
        src_key_padding_mask: Tensor,
    ) -> Tensor:
        src = self.encoder(src)  # (batch, seq_len, embsize)
        self.cur_gene_token_embs = src
        values = self.value_encoder(values)  # (batch, seq_len, embsize)
        perts = self.pert_encoder(input_pert_flags)  # (batch, seq_len, embsize)
        total_embs = src + values + perts

        total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1)

        output = self.transformer_encoder(
            total_embs, src_key_padding_mask=src_key_padding_mask
        )
        return output  # (batch, seq_len, embsize)

    def _get_cell_emb_from_layer(
        self, layer_output: Tensor, weights: Tensor = None
    ) -> Tensor:
        """
        Args:
            layer_output(:obj:`Tensor`): shape (batch, seq_len, embsize)
            weights(:obj:`Tensor`): shape (batch, seq_len), optional and only used
                when :attr:`self.cell_emb_style` is "w-pool".

        Returns:
            :obj:`Tensor`: shape (batch, embsize)
        """
        if self.cell_emb_style == "cls":
            cell_emb = layer_output[:, 0, :]  # (batch, embsize)
        elif self.cell_emb_style == "avg-pool":
            cell_emb = torch.mean(layer_output, dim=1)
        elif self.cell_emb_style == "w-pool":
            if weights is None:
                raise ValueError("weights is required when cell_emb_style is w-pool")
            if weights.dim() != 2:
                raise ValueError("weights should be 2D")
            cell_emb = torch.sum(layer_output * weights.unsqueeze(2), dim=1)
            cell_emb = F.normalize(cell_emb, p=2, dim=1)  # (batch, embsize)
        return cell_emb

    def forward(
        self,
        src: Tensor,
        values: Tensor,
        input_pert_flags: Tensor,
        src_key_padding_mask: Tensor,
        CLS: bool = False,
        CCE: bool = False,
        MVC: bool = False,
        ECS: bool = False,
        do_sample: bool = False,
    ) -> Mapping[str, Tensor]:
        """
        Args:
            src (:obj:`Tensor`): token ids, shape [batch_size, seq_len]
            values (:obj:`Tensor`): token values, shape [batch_size, seq_len]
            input_pert_flags (:obj:`Tensor`): perturbation flags, shape
                [batch_size, seq_len]
            src_key_padding_mask (:obj:`Tensor`): mask for src, shape
                [batch_size, seq_len]
            CLS (:obj:`bool`): if True, return the celltype classification
                objective (CLS) output
            CCE (:obj:`bool`): if True, return the contrastive cell embedding
                objective (CCE) output
            MVC (:obj:`bool`): if True, return the masked value prediction for
                cell embedding (MVC) output
            ECS (:obj:`bool`): if True, return the elastic cell similarity
                objective (ECS) output.

        Returns:
            dict of output Tensors.
        """
        if self.explicit_zero_prob and not do_sample and not self.training:
            do_sample = True
            logger.warning("Auto set do_sample to True when model is in eval mode.")

        transformer_output = self._encode(
            src, values, input_pert_flags, src_key_padding_mask
        )
        output = {}
        mlm_output = self.decoder(transformer_output)
        if self.explicit_zero_prob and do_sample:
            bernoulli = Bernoulli(probs=mlm_output["zero_probs"])
            output["mlm_output"] = bernoulli.sample() * mlm_output["pred"]
        else:
            output["mlm_output"] = mlm_output["pred"]  # (batch, seq_len)
        if self.explicit_zero_prob:
            output["mlm_zero_probs"] = mlm_output["zero_probs"]

        cell_emb = self._get_cell_emb_from_layer(transformer_output, values)
        if CLS:
            output["cls_output"] = self.cls_decoder(cell_emb)  # (batch, n_cls)
        if MVC:
            mvc_output = self.mvc_decoder(
                cell_emb,
                self.cur_gene_token_embs,
            )  # (batch, seq_len)
            if self.explicit_zero_prob and do_sample:
                bernoulli = Bernoulli(probs=mvc_output["zero_probs"])
                output["mvc_output"] = bernoulli.sample() * mvc_output["pred"]
            else:
                output["mvc_output"] = mvc_output["pred"]  # (batch, seq_len)
            if self.explicit_zero_prob:
                output["mvc_zero_probs"] = mvc_output["zero_probs"]
        if ECS:
            # Here using customized cosine similarity instead of F.cosine_similarity
            # to avoid the pytorch issue of similarity larger than 1.0,
            # see pytorch issue #78064.
            # normalize the embedding
            cell_emb_normed = F.normalize(cell_emb, p=2, dim=1)
            cos_sim = torch.mm(cell_emb_normed, cell_emb_normed.t())  # (batch, batch)

            # mask out diagonal elements
            mask = torch.eye(cos_sim.size(0)).bool().to(cos_sim.device)
            cos_sim = cos_sim.masked_fill(mask, 0.0)
            # only optimize positive similarities
            cos_sim = F.relu(cos_sim)

            output["loss_ecs"] = torch.mean(1 - (cos_sim - self.ecs_threshold) ** 2)

        return output

    def encode_batch(
        self,
        src: Tensor,
        values: Tensor,
        src_key_padding_mask: Tensor,
        batch_size: int,
        output_to_cpu: bool = True,
    ) -> Tensor:
        """
        Args:
            src: Tensor, shape [N, seq_len]
            values: Tensor, shape [N, seq_len]
            src_key_padding_mask: Tensor, shape [N, seq_len]
            batch_size: int, number of rows to encode per forward pass
            output_to_cpu: bool, if True, move each batch output to CPU

        Returns:
            output Tensor of shape [N, seq_len, embsize]
        """
        outputs = []
        N = src.size(0)
        device = next(self.parameters()).device
        for i in trange(0, N, batch_size):
            # _encode also expects perturbation flags; assume "no perturbation"
            # (flag 0) when encoding unperturbed batches.
            pert_flags = torch.zeros_like(src[i : i + batch_size], dtype=torch.long)
            output = self._encode(
                src[i : i + batch_size].to(device),
                values[i : i + batch_size].to(device),
                pert_flags.to(device),
                src_key_padding_mask[i : i + batch_size].to(device),
            )
            if output_to_cpu:
                output = output.cpu()
            outputs.append(output)
        return torch.cat(outputs, dim=0)

    def pred_perturb(
        self,
        batch_data,
        include_zero_gene="batch-wise",
        gene_ids=None,
        amp=True,
    ) -> Tensor:
        """
        Args:
            batch_data: a batch object providing ``x`` and ``pert`` attributes,
                e.g. a batch from the perturbation DataLoader.

        Returns:
            output Tensor of shape [N, seq_len]
        """
        self.eval()
        device = next(self.parameters()).device
        batch_data.to(device)
        batch_size = len(batch_data.pert)
        x: torch.Tensor = batch_data.x
        ori_gene_values = x[:, 0].view(batch_size, -1)  # (batch_size, n_genes)
        pert_flags = x[:, 1].long().view(batch_size, -1)

        if include_zero_gene in ["all", "batch-wise"]:
            assert gene_ids is not None
            if include_zero_gene == "all":
                input_gene_ids = torch.arange(ori_gene_values.size(1), device=device)
            else:  # batch-wise
                input_gene_ids = (
                    ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]
                )
            input_values = ori_gene_values[:, input_gene_ids]
            input_pert_flags = pert_flags[:, input_gene_ids]

            mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids)
            mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1)

            src_key_padding_mask = torch.zeros_like(
                input_values, dtype=torch.bool, device=device
            )
            with torch.cuda.amp.autocast(enabled=amp):
                output_dict = self(
                    mapped_input_gene_ids,
                    input_values,
                    input_pert_flags,
                    src_key_padding_mask=src_key_padding_mask,
                    CLS=False,
                    CCE=False,
                    MVC=False,
                    ECS=False,
                    do_sample=True,
                )
            output_values = output_dict["mlm_output"].float()
            pred_gene_values = torch.zeros_like(ori_gene_values)
            pred_gene_values[:, input_gene_ids] = output_values
        return pred_gene_values


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Generates an upper-triangular matrix of -inf, with zeros on diag."""
    return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1)
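

# A minimal sketch (not part of the original module) of what the helper above
# returns: an additive causal mask with 0.0 on and below the diagonal and
# ``-inf`` above it, suitable for the ``mask`` argument of
# ``nn.TransformerEncoder``. For example:
#
#     >>> generate_square_subsequent_mask(3)
#     tensor([[0., -inf, -inf],
#             [0., 0., -inf],
#             [0., 0., 0.]])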


class GeneEncoder(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings, embedding_dim, padding_idx=padding_idx
        )
        self.enc_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.embedding(x)  # (batch, seq_len, embsize)
        x = self.enc_norm(x)
        return x


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)


class Similarity(nn.Module):
    """
    Dot product or cosine similarity
    """

    def __init__(self, temp):
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x, y):
        return self.cos(x, y) / self.temp
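
    # Illustrative note (not part of the original class): dividing by ``temp``
    # rescales the cosine similarity, e.g. with ``temp=0.5`` two identical
    # non-zero vectors score roughly 2.0 instead of 1.0, which sharpens the
    # logits when the scores are later fed to a cross-entropy-style
    # contrastive loss such as ``creterion_cce`` above.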


class ClsDecoder(nn.Module):
    """
    Decoder for classification task.
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 3,
        activation: callable = nn.ReLU,
    ):
        super().__init__()
        # module list
        self._decoder = nn.ModuleList()
        for i in range(nlayers - 1):
            self._decoder.append(nn.Linear(d_model, d_model))
            self._decoder.append(activation())
            self._decoder.append(nn.LayerNorm(d_model))
        self.out_layer = nn.Linear(d_model, n_cls)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, embsize]
        """
        for layer in self._decoder:
            x = layer(x)
        return self.out_layer(x)
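

# A minimal, self-contained usage sketch, not part of scGPT itself. It assumes
# a plain dict works as the ``vocab`` argument (only ``vocab[pad_token]`` is
# accessed in ``__init__``) and uses arbitrary toy sizes throughout.
if __name__ == "__main__":
    toy_vocab = {"<pad>": 0, "geneA": 1, "geneB": 2, "geneC": 3}
    model = TransformerGenerator(
        ntoken=len(toy_vocab),
        d_model=16,
        nhead=2,
        d_hid=32,
        nlayers=2,
        nlayers_cls=2,
        n_cls=3,
        vocab=toy_vocab,
        dropout=0.1,
        pad_token="<pad>",
        pad_value=0,
    )
    batch_size, seq_len = 4, 3
    src = torch.randint(1, len(toy_vocab), (batch_size, seq_len))
    values = torch.rand(batch_size, seq_len)
    pert_flags = torch.zeros(batch_size, seq_len, dtype=torch.long)
    padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    out = model(src, values, pert_flags, padding_mask, CLS=True)
    print(out["mlm_output"].shape)  # expected: (batch_size, seq_len)
    print(out["cls_output"].shape)  # expected: (batch_size, n_cls)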