Source code for scgpt.model.model

import gc
import math
from typing import Dict, Mapping, Optional, Tuple, Any, Union
import warnings

import torch
import numpy as np
from torch import nn, Tensor
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.distributions import Bernoulli
from tqdm import trange

try:
    from flash_attn.flash_attention import FlashMHA
    from .flash_layers import FlashscGPTLayer, FlashscGPTGenerator
except ImportError:
    warnings.warn("flash_attn is not installed")

from .dsbn import DomainSpecificBatchNorm1d
from .grad_reverse import grad_reverse


class TransformerModel(nn.Module):
    def __init__(
        self,
        ntoken: int,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        nlayers_cls: int = 3,
        n_cls: int = 1,
        vocab: Any = None,
        dropout: float = 0.5,
        pad_token: str = "<pad>",
        pad_value: int = 0,
        do_mvc: bool = False,
        do_dab: bool = False,
        use_batch_labels: bool = False,
        num_batch_labels: Optional[int] = None,
        domain_spec_batchnorm: Union[bool, str] = False,
        input_emb_style: str = "continuous",
        n_input_bins: Optional[int] = None,
        cell_emb_style: str = "cls",
        mvc_decoder_style: str = "inner product",
        ecs_threshold: float = 0.3,
        explicit_zero_prob: bool = False,
        use_generative_training: bool = False,
        use_fast_transformer: bool = False,
        fast_transformer_backend: str = "flash",
        pre_norm: bool = False,
        use_sim_decoder: bool = False,
    ):
        super().__init__()
        self.model_type = "Transformer"
        self.d_model = d_model
        self.do_dab = do_dab
        self.ecs_threshold = ecs_threshold
        self.use_batch_labels = use_batch_labels
        self.domain_spec_batchnorm = domain_spec_batchnorm
        self.input_emb_style = input_emb_style
        self.cell_emb_style = cell_emb_style
        self.explicit_zero_prob = explicit_zero_prob
        self.norm_scheme = "pre" if pre_norm else "post"
        if self.input_emb_style not in ["category", "continuous", "scaling"]:
            raise ValueError(
                f"input_emb_style should be one of category, continuous, scaling, "
                f"got {input_emb_style}"
            )
        if cell_emb_style not in ["cls", "avg-pool", "w-pool"]:
            raise ValueError(f"Unknown cell_emb_style: {cell_emb_style}")

        # TODO: add dropout in the GeneEncoder
        self.encoder = GeneEncoder(ntoken, d_model, padding_idx=vocab[pad_token])
        self.flag_encoder = nn.Embedding(2, d_model)

        # Value Encoder, NOTE: the scaling style is also handled in the _encode method
        if input_emb_style == "continuous":
            self.value_encoder = ContinuousValueEncoder(d_model, dropout)
        elif input_emb_style == "category":
            assert n_input_bins > 0
            self.value_encoder = CategoryValueEncoder(
                n_input_bins, d_model, padding_idx=pad_value
            )
        else:
            self.value_encoder = nn.Identity()  # nn.Softmax(dim=1)
            # TODO: consider row-wise normalization or softmax
            # TODO: correctly handle the mask_value when using scaling

        # Batch Encoder
        if use_batch_labels:
            self.batch_encoder = BatchLabelEncoder(num_batch_labels, d_model)

        if domain_spec_batchnorm:
            use_affine = True if domain_spec_batchnorm == "do_affine" else False
            print(f"Use domain specific batchnorm with affine={use_affine}")
            self.dsbn = DomainSpecificBatchNorm1d(
                d_model, num_batch_labels, eps=6.1e-5, affine=use_affine
            )
        # else:
        #     print("Using simple batchnorm instead of domain specific batchnorm")
        #     self.bn = nn.BatchNorm1d(d_model, eps=6.1e-5)  # bug

        if use_generative_training:
            encoder_layers = FlashscGPTLayer(
                d_model,
                nhead,
                d_hid,
                dropout,
                batch_first=True,
                norm_scheme=self.norm_scheme,
            )
            self.transformer_encoder = FlashscGPTGenerator(encoder_layers, nlayers)
        elif use_fast_transformer:
            if fast_transformer_backend == "linear":
                self.transformer_encoder = FastTransformerEncoderWrapper(
                    d_model, nhead, d_hid, nlayers, dropout
                )
            elif fast_transformer_backend == "flash":
                encoder_layers = FlashTransformerEncoderLayer(
                    d_model,
                    nhead,
                    d_hid,
                    dropout,
                    batch_first=True,
                    norm_scheme=self.norm_scheme,
                )
                self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        else:
            encoder_layers = TransformerEncoderLayer(
                d_model, nhead, d_hid, dropout, batch_first=True
            )
            self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

        self.decoder = ExprDecoder(
            d_model,
            explicit_zero_prob=explicit_zero_prob,
            use_batch_labels=use_batch_labels,
        )
        if n_cls > 1:
            if use_sim_decoder:
                self.cls_decoder = SimDecoder(d_model, n_cls, nlayers=nlayers_cls)
            else:
                self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls)
        if do_mvc:
            self.mvc_decoder = MVCDecoder(
                d_model,
                arch_style=mvc_decoder_style,
                explicit_zero_prob=explicit_zero_prob,
                use_batch_labels=use_batch_labels,
            )
        if do_dab:
            self.grad_reverse_discriminator = AdversarialDiscriminator(
                d_model,
                n_cls=num_batch_labels,
                reverse_grad=True,
            )
        # self.sim = Similarity(temp=0.5)  # TODO: auto set temp
        # self.creterion_cce = nn.CrossEntropyLoss()

        self.init_weights()
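A minimal instantiation sketch, assuming a plain dict can stand in for the vocab object (only `vocab[pad_token]` is indexed above); the hyperparameter values and gene names are illustrative, not the published scGPT configuration, and the snippet relies on this module's own imports (torch, nn):

vocab = {"<pad>": 0, "<cls>": 1, "GENE_A": 2, "GENE_B": 3}  # hypothetical toy vocab
model = TransformerModel(
    ntoken=len(vocab),
    d_model=64,
    nhead=4,
    d_hid=128,
    nlayers=2,
    vocab=vocab,
    dropout=0.1,
    pad_token="<pad>",
    pad_value=0,
    input_emb_style="continuous",
    use_fast_transformer=False,  # fall back to torch.nn.TransformerEncoder
)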
    def init_weights(self) -> None:
        initrange = 0.1
        # TODO: check if this initialization is helpful and shall we apply to all?
        self.encoder.embedding.weight.data.uniform_(-initrange, initrange)
    def _encode(
        self,
        src: Tensor,
        values: Tensor,
        src_key_padding_mask: Tensor,
        batch_labels: Optional[Tensor] = None,  # (batch,)
    ) -> Tensor:
        self._check_batch_labels(batch_labels)

        src = self.encoder(src)  # (batch, seq_len, embsize)
        self.cur_gene_token_embs = src

        values = self.value_encoder(values)  # (batch, seq_len, embsize)
        if self.input_emb_style == "scaling":
            values = values.unsqueeze(2)
            total_embs = src * values
        else:
            total_embs = src + values

        if self.domain_spec_batchnorm:
            batch_label = int(batch_labels[0].item())
            total_embs = self.dsbn(total_embs.permute(0, 2, 1), batch_label).permute(
                0, 2, 1
            )  # the batch norm always works on dim 1

        output = self.transformer_encoder(
            total_embs, src_key_padding_mask=src_key_padding_mask
        )
        return output  # (batch, seq_len, embsize)
    def transformer_generate(
        self,
        pcpt_genes: Tensor,
        pcpt_values: Tensor,
        pcpt_key_padding_mask: Tensor,
        gen_genes: Tensor,
        gen_key_padding_mask: Tensor,
        batch_labels: Optional[Tensor] = None,  # (batch,)
        input_cell_emb: Optional[Tensor] = None,  # (batch, seq_len, embsize)
    ) -> Tuple[Tensor, Tensor]:
        self._check_batch_labels(batch_labels)

        pcpt_token_embs = self.encoder(pcpt_genes)  # (batch, pcpt_len, embsize)
        pcpt_values = self.value_encoder(pcpt_values)  # (batch, pcpt_len, embsize)
        pcpt_total_embs = pcpt_token_embs + pcpt_values
        assert self.input_emb_style != "scaling"

        if gen_genes is not None:
            gen_token_embs = self.encoder(gen_genes)  # (batch, gen_len, embsize)
            self.cur_gene_token_embs = torch.cat(
                [pcpt_token_embs, gen_token_embs], dim=1
            )
            gen_flags = self.flag_encoder(
                torch.tensor(1).to(pcpt_values.device)
            ).expand(gen_genes.shape[0], gen_genes.shape[1], -1)
            gen_total_embs = gen_token_embs + gen_flags
        else:
            self.cur_gene_token_embs = pcpt_token_embs
            gen_total_embs = None

        if self.domain_spec_batchnorm:
            batch_label = int(batch_labels[0].item())
            pcpt_total_embs = self.dsbn(
                pcpt_total_embs.permute(0, 2, 1), batch_label
            ).permute(0, 2, 1)
            if gen_genes is not None:
                gen_total_embs = self.dsbn(
                    gen_total_embs.permute(0, 2, 1), batch_label
                ).permute(0, 2, 1)

        if input_cell_emb is not None:
            pcpt_total_embs[:, 0, :] = input_cell_emb

        pcpt_output, gen_output = self.transformer_encoder(
            pcpt_total_embs,
            gen_total_embs,
            pcpt_key_padding_mask=pcpt_key_padding_mask,
            gen_key_padding_mask=gen_key_padding_mask,
        )
        return pcpt_output, gen_output
    def _get_cell_emb_from_layer(
        self, layer_output: Tensor, weights: Tensor = None
    ) -> Tensor:
        """
        Args:
            layer_output(:obj:`Tensor`): shape (batch, seq_len, embsize)
            weights(:obj:`Tensor`): shape (batch, seq_len), optional and only used
                when :attr:`self.cell_emb_style` is "w-pool".

        Returns:
            :obj:`Tensor`: shape (batch, embsize)
        """
        if self.cell_emb_style == "cls":
            cell_emb = layer_output[:, 0, :]  # (batch, embsize)
        elif self.cell_emb_style == "avg-pool":
            cell_emb = torch.mean(layer_output, dim=1)
        elif self.cell_emb_style == "w-pool":
            if weights is None:
                raise ValueError("weights is required when cell_emb_style is w-pool")
            if weights.dim() != 2:
                raise ValueError("weights should be 2D")
            cell_emb = torch.sum(layer_output * weights.unsqueeze(2), dim=1)
            cell_emb = F.normalize(cell_emb, p=2, dim=1)  # (batch, embsize)
        return cell_emb

    def _check_batch_labels(self, batch_labels: Tensor) -> None:
        if self.use_batch_labels or self.domain_spec_batchnorm:
            assert batch_labels is not None
        elif batch_labels is not None:
            raise ValueError(
                "batch_labels should only be provided when `self.use_batch_labels`"
                " or `self.domain_spec_batchnorm` is True"
            )
    def generate(
        self,
        cell_emb: Tensor,
        src: Tensor,
        values: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        gen_iters: int = 1,
        batch_labels: Optional[Tensor] = None,  # (batch,)
    ) -> Tensor:
        """
        Args:
            cell_emb(:obj:`Tensor`): shape (batch, embsize)
            src(:obj:`Tensor`): shape (batch, seq_len)
            values(:obj:`Tensor`): shape (batch, seq_len), optional
            src_key_padding_mask(:obj:`Tensor`): shape (batch, seq_len), optional
            gen_iters(:obj:`int`): number of generation iterations
            batch_labels(:obj:`Tensor`): shape (batch,), optional
        """
        # TODO: should have a tag indicating the generation mode
        # TODO: if gen_iters > 1, should have a tag indicating the current iteration
        try:
            self._check_batch_labels(batch_labels)
        except Exception:
            warnings.warn(
                "batch_labels is required but not provided, using zeros instead"
            )
            batch_labels = torch.zeros(
                cell_emb.shape[0], dtype=torch.long, device=cell_emb.device
            )

        src = self.encoder(src)  # (batch, seq_len, embsize)

        if values is not None:
            values = self.value_encoder(values)  # (batch, seq_len, embsize)
            if self.input_emb_style == "scaling":
                values = values.unsqueeze(2)
                total_embs = src * values
            else:
                total_embs = src + values
        else:
            total_embs = src

        if self.domain_spec_batchnorm:
            batch_label = int(batch_labels[0].item())
            total_embs = self.dsbn(total_embs.permute(0, 2, 1), batch_label).permute(
                0, 2, 1
            )  # the batch norm always works on dim 1
        else:
            # NOTE: self.bn is currently commented out in __init__ (see the "bug"
            # note there), so this branch requires a plain BatchNorm1d to be defined.
            total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1)

        total_embs[:, 0, :] = cell_emb

        if src_key_padding_mask is None:
            src_key_padding_mask = torch.zeros(
                total_embs.shape[:2], dtype=torch.bool, device=total_embs.device
            )
        transformer_output = self.transformer_encoder(
            total_embs, src_key_padding_mask=src_key_padding_mask
        )

        if self.use_batch_labels:
            batch_emb = self.batch_encoder(batch_labels)  # (batch, embsize)
        mlm_output = self.decoder(
            transformer_output
            if not self.use_batch_labels
            else torch.cat(
                [
                    transformer_output,
                    batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1),
                ],
                dim=2,
            ),
        )
        output = mlm_output["pred"]  # (batch, seq_len)

        return output  # (batch, seq_len)
    def _extend_output(
        self,
        output: Mapping[str, Tensor],
        transformer_output: Tensor,
        batch_emb: Optional[Tensor] = None,
        CLS: bool = False,
        MVC: bool = False,
        ECS: bool = False,
        do_sample: bool = False,
    ) -> Mapping[str, Tensor]:
        cell_emb = self._get_cell_emb_from_layer(transformer_output)
        output["cell_emb"] = cell_emb

        if CLS:
            output["cls_output"] = self.cls_decoder(cell_emb)  # (batch, n_cls)
        if MVC:
            mvc_output = self.mvc_decoder(
                cell_emb
                if not self.use_batch_labels
                else torch.cat([cell_emb, batch_emb], dim=1),
                self.cur_gene_token_embs,
            )
            if self.explicit_zero_prob and do_sample:
                bernoulli = Bernoulli(probs=mvc_output["zero_probs"])
                output["mvc_output"] = bernoulli.sample() * mvc_output["pred"]
            else:
                output["mvc_output"] = mvc_output["pred"]  # (batch, seq_len)
            if self.explicit_zero_prob:
                output["mvc_zero_probs"] = mvc_output["zero_probs"]
        if ECS:
            # Here using a customized cosine similarity instead of F.cosine_similarity
            # to avoid the pytorch issue of similarity larger than 1.0, pytorch #78064
            # normalize the embedding
            cell_emb_normed = F.normalize(cell_emb, p=2, dim=1)
            cos_sim = torch.mm(cell_emb_normed, cell_emb_normed.t())  # (batch, batch)

            # mask out diagonal elements
            mask = torch.eye(cos_sim.size(0)).bool().to(cos_sim.device)
            cos_sim = cos_sim.masked_fill(mask, 0.0)
            # only optimize positive similarities
            cos_sim = F.relu(cos_sim)
            output["loss_ecs"] = torch.mean(1 - (cos_sim - self.ecs_threshold) ** 2)

        if self.do_dab:
            output["dab_output"] = self.grad_reverse_discriminator(cell_emb)

        return output
    def forward(
        self,
        *args,
        **kwargs,
    ) -> Mapping[str, Tensor]:
        """
        Wrapper to call either generative_forward or perceptual_forward, depending
        on the value of the "generative_training" kwarg.
        """
        if "generative_training" not in kwargs:
            warnings.warn(
                "generative_training kwarg is required but not provided! "
                "Using False and calling perceptual_forward instead"
            )
            return self.perceptual_forward(*args, **kwargs)

        # get the generative training flag and pop it out
        do_generative_training = kwargs.pop("generative_training")
        if do_generative_training:
            return self.generative_forward(*args, **kwargs)
        else:
            return self.perceptual_forward(*args, **kwargs)
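A hedged sketch of how the wrapper dispatches, reusing the toy `model` and `vocab` from the sketch above (built with the default use_generative_training=False, so perceptual_forward is the valid path); the tensors are illustrative:

genes = torch.tensor([[2, 3, 0]])             # (batch=1, seq_len=3) token ids
values = torch.tensor([[1.5, 0.7, 0.0]])      # (batch, seq_len) expression values
pad_mask = genes.eq(vocab["<pad>"])           # True at padding positions
out = model(genes, values, pad_mask, generative_training=False)
# out["mlm_output"]: (batch, seq_len) predictions; out["cell_emb"]: (batch, d_model)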
    def generative_forward(
        self,
        pcpt_genes: Tensor,
        pcpt_values: Tensor,
        pcpt_key_padding_mask: Tensor,
        gen_genes: Tensor,
        gen_key_padding_mask: Tensor,
        batch_labels: Optional[Tensor] = None,
        CLS: bool = False,
        CCE: bool = False,
        MVC: bool = False,
        ECS: bool = False,
        do_sample: bool = False,
        input_cell_emb: Optional[Tensor] = None,
    ) -> Mapping[str, Tensor]:
        """
        Args:
            pcpt_genes (:obj:`Tensor`): token ids of the perceptual part, shape
                [batch_size, seq_len]
            pcpt_values (:obj:`Tensor`): token values of the perceptual part, shape
                [batch_size, seq_len]
            pcpt_key_padding_mask (:obj:`Tensor`): mask for pcpt_genes, shape
                [batch_size, seq_len]
            gen_genes (:obj:`Tensor`): token ids of the generative part, shape
                [batch_size, seq_len]
            gen_key_padding_mask (:obj:`Tensor`): mask for gen_genes, shape
                [batch_size, seq_len]
            batch_labels (:obj:`Tensor`): batch labels, shape [batch_size]
            do_sample (:obj:`bool`): whether to sample from the bernoulli
                distribution for generated zero predictions.
            input_cell_emb (:obj:`Tensor`): cell embeddings, shape
                [batch_size, embsize]

        Returns:
            :obj:`Mapping[str, Tensor]`:
                - pred (:obj:`Tensor`): prediction, shape [batch_size, seq_len]
                - cell_emb (:obj:`Tensor`): cell embeddings, shape [batch_size, embsize]
        """
        pcpt_output, gen_output = self.transformer_generate(
            pcpt_genes,
            pcpt_values,
            pcpt_key_padding_mask,
            gen_genes,
            gen_key_padding_mask,
            batch_labels,
            input_cell_emb=input_cell_emb,
        )
        if gen_output is None:
            transformer_output = pcpt_output
        else:
            transformer_output = torch.cat([pcpt_output, gen_output], dim=1)

        if self.use_batch_labels:
            batch_emb = self.batch_encoder(batch_labels)

        output = {}
        decoder_output = self.decoder(
            transformer_output
            if not self.use_batch_labels
            else torch.cat(
                [
                    transformer_output,
                    batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1),
                ],
                dim=2,
            ),
        )
        if self.explicit_zero_prob and do_sample:
            bernoulli = Bernoulli(probs=decoder_output["zero_probs"])
            full_preds = bernoulli.sample() * decoder_output["pred"]
        else:
            full_preds = decoder_output["pred"]  # (batch, seq_len)
        output["pcpt_preds"] = full_preds[:, : pcpt_genes.shape[1]]
        output["gen_preds"] = full_preds[:, pcpt_genes.shape[1] :]
        if self.explicit_zero_prob:
            output["zero_probs"] = decoder_output["zero_probs"]

        output = self._extend_output(
            output,
            transformer_output,
            batch_emb=batch_emb if self.use_batch_labels else None,
            CLS=CLS,
            MVC=MVC,
            ECS=ECS,
            do_sample=do_sample,
        )

        return output
    def perceptual_forward(
        self,
        src: Tensor,
        values: Tensor,
        src_key_padding_mask: Tensor,
        batch_labels: Optional[Tensor] = None,
        CLS: bool = False,
        CCE: bool = False,
        MVC: bool = False,
        ECS: bool = False,
        do_sample: bool = False,
    ) -> Mapping[str, Tensor]:
        """
        Args:
            src (:obj:`Tensor`): token ids, shape [batch_size, seq_len]
            values (:obj:`Tensor`): token values, shape [batch_size, seq_len]
            src_key_padding_mask (:obj:`Tensor`): mask for src, shape
                [batch_size, seq_len]
            batch_labels (:obj:`Tensor`): batch labels, shape [batch_size]
            CLS (:obj:`bool`): if True, return the celltype classification
                objective (CLS) output
            CCE (:obj:`bool`): if True, return the contrastive cell embedding
                objective (CCE) output
            MVC (:obj:`bool`): if True, return the masked value prediction for
                cell embedding (MVC) output
            ECS (:obj:`bool`): if True, return the elastic cell similarity
                objective (ECS) output

        Returns:
            dict of output Tensors.
        """
        transformer_output = self._encode(
            src, values, src_key_padding_mask, batch_labels
        )
        if self.use_batch_labels:
            batch_emb = self.batch_encoder(batch_labels)  # (batch, embsize)

        output = {}
        mlm_output = self.decoder(
            transformer_output
            if not self.use_batch_labels
            else torch.cat(
                [
                    transformer_output,
                    batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1),
                ],
                dim=2,
            ),
        )
        if self.explicit_zero_prob and do_sample:
            bernoulli = Bernoulli(probs=mlm_output["zero_probs"])
            output["mlm_output"] = bernoulli.sample() * mlm_output["pred"]
        else:
            output["mlm_output"] = mlm_output["pred"]  # (batch, seq_len)
        if self.explicit_zero_prob:
            output["mlm_zero_probs"] = mlm_output["zero_probs"]

        output = self._extend_output(
            output,
            transformer_output,
            batch_emb=batch_emb if self.use_batch_labels else None,
            CLS=CLS,
            MVC=MVC,
            ECS=ECS,
            do_sample=do_sample,
        )

        return output
    def encode_batch(
        self,
        src: Tensor,
        values: Tensor,
        src_key_padding_mask: Tensor,
        batch_size: int,
        batch_labels: Optional[Tensor] = None,
        output_to_cpu: bool = True,
        time_step: Optional[int] = None,
        return_np: bool = False,
    ) -> Tensor:
        """
        Args:
            src (Tensor): shape [N, seq_len]
            values (Tensor): shape [N, seq_len]
            src_key_padding_mask (Tensor): shape [N, seq_len]
            batch_size (int): batch size for encoding
            batch_labels (Tensor): shape [N, n_batch_labels]
            output_to_cpu (bool): whether to move the output to cpu
            time_step (int): the time step index in the transformer output to return.
                The time step is along the second dimension. If None, return all.
            return_np (bool): whether to return numpy array

        Returns:
            output Tensor of shape [N, seq_len, embsize]
        """
        N = src.size(0)
        device = next(self.parameters()).device

        # initialize the output tensor
        array_func = np.zeros if return_np else torch.zeros
        float32_ = np.float32 if return_np else torch.float32
        shape = (
            (N, self.d_model)
            if time_step is not None
            else (N, src.size(1), self.d_model)
        )
        outputs = array_func(shape, dtype=float32_)

        for i in trange(0, N, batch_size):
            raw_output = self._encode(
                src[i : i + batch_size].to(device),
                values[i : i + batch_size].to(device),
                src_key_padding_mask[i : i + batch_size].to(device),
                batch_labels[i : i + batch_size].to(device)
                if batch_labels is not None
                else None,
            )
            output = raw_output.detach()
            if output_to_cpu:
                output = output.cpu()
            if return_np:
                output = output.numpy()
            if time_step is not None:
                output = output[:, time_step, :]
            outputs[i : i + batch_size] = output

        return outputs
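A hedged usage sketch for embedding extraction, reusing `model`, `genes`, `values`, and `pad_mask` from the sketches above; `time_step=0` keeps only the first position, so the result has shape (N, d_model):

emb = model.encode_batch(
    genes,
    values,
    pad_mask,
    batch_size=64,
    output_to_cpu=True,
    time_step=0,
    return_np=True,
)
# emb: numpy array of shape (N, d_model)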
def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Generates an upper-triangular matrix of -inf, with zeros on diag."""
    return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1)
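This is the standard causal-mask construction; for a hypothetical 3-token sequence it produces:

mask = generate_square_subsequent_mask(3)
# tensor([[0., -inf, -inf],
#         [0.,   0., -inf],
#         [0.,   0.,   0.]])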
class FastTransformerEncoderWrapper(nn.Module):
    def __init__(
        self,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        dropout: float = 0.5,
    ):
        super().__init__()
        self.fast_transformer_encoder = self.build_fast_transformer_encoder(
            d_model, nhead, d_hid, nlayers, dropout
        )
    @staticmethod
    def build_fast_transformer_encoder(
        d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float
    ) -> nn.Module:
        from fast_transformers.builders import TransformerEncoderBuilder

        if d_model % nhead != 0:
            raise ValueError(
                f"d_model must be divisible by nhead, "
                f"got d_model={d_model} and nhead={nhead}"
            )
        builder = TransformerEncoderBuilder.from_kwargs(
            n_layers=nlayers,
            n_heads=nhead,
            query_dimensions=d_model // nhead,
            value_dimensions=d_model // nhead,
            feed_forward_dimensions=d_hid,
            attention_type="linear",
            attention_dropout=dropout,
            dropout=dropout,
            activation="gelu",
        )
        assert builder.attention_type == "linear"
        return builder.get()
    @staticmethod
    def build_length_mask(
        src: Tensor,
        src_key_padding_mask: torch.BoolTensor,
    ) -> "LengthMask":
        from fast_transformers.masking import LengthMask

        seq_len = src.shape[1]
        num_paddings = src_key_padding_mask.sum(dim=1)
        actual_seq_len = seq_len - num_paddings  # (N,)
        length_mask = LengthMask(actual_seq_len, max_len=seq_len, device=src.device)

        if src_key_padding_mask[length_mask.bool_matrix].sum() != 0:
            raise ValueError(
                "Found padding tokens in the middle of the sequence. "
                "src_key_padding_mask and length_mask are not compatible."
            )
        return length_mask
    def forward(
        self,
        src: Tensor,
        src_key_padding_mask: torch.BoolTensor,
    ) -> Tensor:
        """
        Args:
            src: Tensor, shape [N, seq_len, embsize]
            src_key_padding_mask: Tensor, shape [N, seq_len]

        Returns:
            output Tensor of shape [N, seq_len, embsize]
        """
        if src_key_padding_mask.shape != src.shape[:2]:
            raise ValueError(
                f"src_key_padding_mask shape {src_key_padding_mask.shape} "
                f"does not match first two dims of src shape {src.shape[:2]}"
            )
        if src_key_padding_mask.dtype != torch.bool:
            raise ValueError(
                f"src_key_padding_mask needs to be of type torch.bool, "
                f"got {src_key_padding_mask.dtype}"
            )

        length_mask = self.build_length_mask(src, src_key_padding_mask)
        output = self.fast_transformer_encoder(src, length_mask=length_mask)
        return output
class FlashTransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    The class is modified from torch.nn.TransformerEncoderLayer to support the
    FlashAttention.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu
            (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """

    __constants__ = ["batch_first"]

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        layer_norm_eps=1e-5,
        batch_first=True,
        device=None,
        dtype=None,
        norm_scheme="post",  # "pre" or "post"
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.self_attn = FlashMHA(
            embed_dim=d_model,
            num_heads=nhead,
            batch_first=batch_first,
            attention_dropout=dropout,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = self._get_activation_fn(activation)
        self.norm_scheme = norm_scheme
        if self.norm_scheme not in ["pre", "post"]:
            raise ValueError(f"norm_scheme should be pre or post, not {norm_scheme}")

    @staticmethod
    def _get_activation_fn(activation):
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu

        raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

    def __setstate__(self, state):
        if "activation" not in state:
            state["activation"] = F.relu
        super().__setstate__(state)
    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        if src_mask is not None:
            raise ValueError("FlashTransformerEncoderLayer does not support src_mask")

        if src_key_padding_mask is None or not src_key_padding_mask.any().item():
            # no padding tokens in src
            src_key_padding_mask_ = None
        else:
            if src_key_padding_mask.dtype != torch.bool:
                src_key_padding_mask = src_key_padding_mask.bool()
            # NOTE: the FlashMHA uses mask 0 for padding tokens, which is the opposite
            src_key_padding_mask_ = ~src_key_padding_mask

        if self.norm_scheme == "pre":
            src = self.norm1(src)
            src2 = self.self_attn(src, key_padding_mask=src_key_padding_mask_)[0]
            src = src + self.dropout1(src2)
            src = self.norm2(src)
            src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
            src = src + self.dropout2(src2)
        else:
            src2 = self.self_attn(src, key_padding_mask=src_key_padding_mask_)[0]
            src = src + self.dropout1(src2)
            src = self.norm1(src)
            src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
            src = src + self.dropout2(src2)
            src = self.norm2(src)
        return src
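The one non-obvious step above is the mask inversion: torch.nn-style padding masks use True for padded positions, while the code inverts the mask before handing it to FlashMHA. A tiny illustration of the two conventions (independent of flash_attn being installed; the tensor values are illustrative):

torch_style = torch.tensor([[False, False, True]])  # last token is padding
flash_style = ~torch_style                          # tensor([[ True,  True, False]])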
class GeneEncoder(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings, embedding_dim, padding_idx=padding_idx
        )
        self.enc_norm = nn.LayerNorm(embedding_dim)
    def forward(self, x: Tensor) -> Tensor:
        x = self.embedding(x)  # (batch, seq_len, embsize)
        x = self.enc_norm(x)
        return x
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)
    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)
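A minimal usage sketch (illustrative shapes); note this module expects sequence-first input [seq_len, batch, d_model], unlike the batch-first encoders above, and it is not wired into TransformerModel in this file:

pe = PositionalEncoding(d_model=64, dropout=0.0, max_len=100)
x = torch.zeros(10, 2, 64)   # (seq_len, batch, d_model)
y = pe(x)                    # same shape, sinusoidal offsets added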
class ContinuousValueEncoder(nn.Module):
    """
    Encode real number values to a vector using neural nets projection.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_value: int = 512):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.linear1 = nn.Linear(1, d_model)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.max_value = max_value
    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len]
        """
        # TODO: test using actual embedding layer if input is categorical
        # expand last dimension
        x = x.unsqueeze(-1)
        # clip x to [-inf, max_value]
        x = torch.clamp(x, max=self.max_value)
        x = self.activation(self.linear1(x))
        x = self.linear2(x)
        x = self.norm(x)
        return self.dropout(x)
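A small sketch of what the value encoder does to raw expression values (clamp to max_value, then a two-layer projection to d_model); the numbers are illustrative:

enc = ContinuousValueEncoder(d_model=64, dropout=0.0, max_value=512)
vals = torch.tensor([[0.0, 3.2, 700.0]])  # (batch, seq_len); 700 is clamped to 512
out = enc(vals)                           # (1, 3, 64)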
class CategoryValueEncoder(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings, embedding_dim, padding_idx=padding_idx
        )
        self.enc_norm = nn.LayerNorm(embedding_dim)
    def forward(self, x: Tensor) -> Tensor:
        x = x.long()
        x = self.embedding(x)  # (batch, seq_len, embsize)
        x = self.enc_norm(x)
        return x
class BatchLabelEncoder(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings, embedding_dim, padding_idx=padding_idx
        )
        self.enc_norm = nn.LayerNorm(embedding_dim)
    def forward(self, x: Tensor) -> Tensor:
        x = self.embedding(x)  # (batch, embsize)
        x = self.enc_norm(x)
        return x
class Similarity(nn.Module):
    """
    Dot product or cosine similarity.
    """

    def __init__(self, temp):
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)
    def forward(self, x, y):
        return self.cos(x, y) / self.temp
class ExprDecoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        explicit_zero_prob: bool = False,
        use_batch_labels: bool = False,
    ):
        super().__init__()
        d_in = d_model * 2 if use_batch_labels else d_model
        self.fc = nn.Sequential(
            nn.Linear(d_in, d_model),
            nn.LeakyReLU(),
            nn.Linear(d_model, d_model),
            nn.LeakyReLU(),
            nn.Linear(d_model, 1),
        )
        self.explicit_zero_prob = explicit_zero_prob
        if explicit_zero_prob:
            self.zero_logit = nn.Sequential(
                nn.Linear(d_in, d_model),
                nn.LeakyReLU(),
                nn.Linear(d_model, d_model),
                nn.LeakyReLU(),
                nn.Linear(d_model, 1),
            )
    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """x is the output of the transformer, (batch, seq_len, d_model)"""
        pred_value = self.fc(x).squeeze(-1)  # (batch, seq_len)
        if not self.explicit_zero_prob:
            return dict(pred=pred_value)
        zero_logits = self.zero_logit(x).squeeze(-1)  # (batch, seq_len)
        zero_probs = torch.sigmoid(zero_logits)
        return dict(pred=pred_value, zero_probs=zero_probs)
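A hedged sketch of the decoder's two-headed output when explicit_zero_prob is enabled (shapes are illustrative):

dec = ExprDecoder(d_model=64, explicit_zero_prob=True)
h = torch.randn(2, 5, 64)   # (batch, seq_len, d_model) transformer output
out = dec(h)
# out["pred"]: (2, 5) predicted values; out["zero_probs"]: (2, 5), values in [0, 1]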
    # TODO: note that the return currently is only for training. Since decoder
    # is not used in the test setting for the integration task, the eval/inference
    # logic is not implemented yet. However, remember to implement it when
    # the decoder is used in any test setting. The inference logic will need
    # to sample from the bernoulli distribution with the zero_probs.
class ClsDecoder(nn.Module):
    """
    Decoder for classification task.
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 3,
        activation: callable = nn.ReLU,
    ):
        super().__init__()
        # module list
        self._decoder = nn.ModuleList()
        for i in range(nlayers - 1):
            self._decoder.append(nn.Linear(d_model, d_model))
            self._decoder.append(activation())
            self._decoder.append(nn.LayerNorm(d_model))
        self.out_layer = nn.Linear(d_model, n_cls)
    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, embsize]
        """
        for layer in self._decoder:
            x = layer(x)
        return self.out_layer(x)
class SimDecoder(nn.Module):
    """
    Decoder for classification task with similarity matrix.
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 3,
        activation: callable = nn.ReLU,
        projection_dim: int = 2048,
    ):
        super().__init__()
        # module list
        self._decoder = nn.ModuleList()
        for i in range(nlayers - 1):
            self._decoder.append(nn.Linear(d_model, d_model))
            self._decoder.append(activation())
            self._decoder.append(nn.LayerNorm(d_model))
        self.out_layer = nn.Linear(d_model, projection_dim)
        self.cls_token_matrix = nn.Parameter(torch.randn(n_cls, projection_dim))
        self.embed_norm = nn.LayerNorm(projection_dim)
        self.token_norm = nn.LayerNorm(projection_dim)
    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, embsize]
        """
        for layer in self._decoder:
            x = layer(x)
        x = self.out_layer(x)
        x = self.embed_norm(x)
        sim = self.sim_matrix(x, self.cls_token_matrix)
        return sim
    def get_sim_matrix(self):
        return self.cls_token_matrix
    def sim_matrix(self, a, b, eps=1e-8):
        # b = self.token_norm(b)
        a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
        a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
        b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
        sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
        # sim_mt = torch.mm(a, b.transpose(0, 1))
        return sim_mt
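sim_matrix is a pairwise cosine similarity with an eps floor on the norms. A tiny numeric sketch (toy vectors, values rounded):

a = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
b = torch.tensor([[1.0, 0.0], [1.0, 1.0]])
SimDecoder(d_model=2, n_cls=2).sim_matrix(a, b)
# approximately tensor([[1.0000, 0.7071],
#                       [0.0000, 0.7071]])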
class MVCDecoder(nn.Module):
    """
    Decoder for the masked value prediction for cell embeddings.
    """

    def __init__(
        self,
        d_model: int,
        arch_style: str = "inner product",
        query_activation: nn.Module = nn.Sigmoid,
        hidden_activation: nn.Module = nn.PReLU,
        explicit_zero_prob: bool = False,
        use_batch_labels: bool = False,
    ) -> None:
        """
        Args:
            d_model (:obj:`int`): dimension of the gene embedding.
            arch_style (:obj:`str`): architecture style of the decoder, choice from
                1. "inner product" or 2. "concat query" or 3. "sum query".
            query_activation (:obj:`nn.Module`): activation function for the query
                vectors.
            hidden_activation (:obj:`nn.Module`): activation function for the hidden
                layers.
        """
        super().__init__()
        d_in = d_model * 2 if use_batch_labels else d_model
        if arch_style in ["inner product", "inner product, detach"]:
            self.gene2query = nn.Linear(d_model, d_model)
            self.query_activation = query_activation()
            self.W = nn.Linear(d_model, d_in, bias=False)
            if explicit_zero_prob:  # by default, gene-wise prob rate
                self.W_zero_logit = nn.Linear(d_model, d_in)
        elif arch_style == "concat query":
            self.gene2query = nn.Linear(d_model, 64)
            self.query_activation = query_activation()
            self.fc1 = nn.Linear(d_model + 64, 64)
            self.hidden_activation = hidden_activation()
            self.fc2 = nn.Linear(64, 1)
        elif arch_style == "sum query":
            self.gene2query = nn.Linear(d_model, d_model)
            self.query_activation = query_activation()
            self.fc1 = nn.Linear(d_model, 64)
            self.hidden_activation = hidden_activation()
            self.fc2 = nn.Linear(64, 1)
        else:
            raise ValueError(f"Unknown arch_style: {arch_style}")

        self.arch_style = arch_style
        self.do_detach = arch_style.endswith("detach")
        self.explicit_zero_prob = explicit_zero_prob
    def forward(
        self, cell_emb: Tensor, gene_embs: Tensor
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """
        Args:
            cell_emb: Tensor, shape (batch, embsize=d_model)
            gene_embs: Tensor, shape (batch, seq_len, embsize=d_model)
        """
        gene_embs = gene_embs.detach() if self.do_detach else gene_embs
        if self.arch_style in ["inner product", "inner product, detach"]:
            query_vecs = self.query_activation(self.gene2query(gene_embs))
            cell_emb = cell_emb.unsqueeze(2)  # (batch, embsize, 1)
            # the predicted gene expression values, (batch, seq_len)
            pred_value = torch.bmm(self.W(query_vecs), cell_emb).squeeze(2)
            if not self.explicit_zero_prob:
                return dict(pred=pred_value)
            # zero logits need to be based on the cell_emb, because of input exprs
            zero_logits = torch.bmm(self.W_zero_logit(query_vecs), cell_emb).squeeze(2)
            zero_probs = torch.sigmoid(zero_logits)
            return dict(pred=pred_value, zero_probs=zero_probs)
        elif self.arch_style == "concat query":
            query_vecs = self.query_activation(self.gene2query(gene_embs))
            # expand cell_emb to (batch, seq_len, embsize)
            cell_emb = cell_emb.unsqueeze(1).expand(-1, gene_embs.shape[1], -1)
            h = self.hidden_activation(
                self.fc1(torch.cat([cell_emb, query_vecs], dim=2))
            )
            if self.explicit_zero_prob:
                raise NotImplementedError
            return self.fc2(h).squeeze(2)  # (batch, seq_len)
        elif self.arch_style == "sum query":
            query_vecs = self.query_activation(self.gene2query(gene_embs))
            cell_emb = cell_emb.unsqueeze(1)
            h = self.hidden_activation(self.fc1(cell_emb + query_vecs))
            if self.explicit_zero_prob:
                raise NotImplementedError
            return self.fc2(h).squeeze(2)  # (batch, seq_len)
class AdversarialDiscriminator(nn.Module):
    """
    Discriminator for the adversarial training for batch correction.
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 3,
        activation: callable = nn.LeakyReLU,
        reverse_grad: bool = False,
    ):
        super().__init__()
        # module list
        self._decoder = nn.ModuleList()
        for i in range(nlayers - 1):
            self._decoder.append(nn.Linear(d_model, d_model))
            self._decoder.append(activation())
            self._decoder.append(nn.LayerNorm(d_model))
        self.out_layer = nn.Linear(d_model, n_cls)
        self.reverse_grad = reverse_grad
    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, embsize]
        """
        if self.reverse_grad:
            x = grad_reverse(x, lambd=1.0)
        for layer in self._decoder:
            x = layer(x)
        return self.out_layer(x)
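A short sketch of how this discriminator is typically used for the batch-correction (DAB) objective, assuming grad_reverse has the standard gradient-reversal semantics its name suggests: the classifier loss is minimized as usual, while the gradient flowing back into the cell embedding is sign-flipped. The tensors are illustrative and F refers to this module's torch.nn.functional import:

disc = AdversarialDiscriminator(d_model=64, n_cls=3, reverse_grad=True)
cell_emb = torch.randn(8, 64, requires_grad=True)     # (batch, embsize)
batch_ids = torch.randint(0, 3, (8,))                 # hypothetical batch labels
loss_dab = F.cross_entropy(disc(cell_emb), batch_ids)
loss_dab.backward()  # gradients w.r.t. cell_emb arrive reversed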