Source code for scgpt.model.model
import gc
import math
from typing import Dict, Mapping, Optional, Tuple, Any, Union
import warnings
import torch
import numpy as np
from torch import nn, Tensor
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.distributions import Bernoulli
from tqdm import trange
try:
    from flash_attn.flash_attention import FlashMHA
    from .flash_layers import FlashscGPTLayer, FlashscGPTGenerator
except ImportError:
    warnings.warn("flash_attn is not installed; flash-attention based layers will be unavailable")
from .dsbn import DomainSpecificBatchNorm1d
from .grad_reverse import grad_reverse
class TransformerModel(nn.Module):
def __init__(
self,
ntoken: int,
d_model: int,
nhead: int,
d_hid: int,
nlayers: int,
nlayers_cls: int = 3,
n_cls: int = 1,
vocab: Any = None,
dropout: float = 0.5,
pad_token: str = "<pad>",
pad_value: int = 0,
do_mvc: bool = False,
do_dab: bool = False,
use_batch_labels: bool = False,
num_batch_labels: Optional[int] = None,
domain_spec_batchnorm: Union[bool, str] = False,
input_emb_style: str = "continuous",
n_input_bins: Optional[int] = None,
cell_emb_style: str = "cls",
mvc_decoder_style: str = "inner product",
ecs_threshold: float = 0.3,
explicit_zero_prob: bool = False,
        use_generative_training: bool = False,
use_fast_transformer: bool = False,
fast_transformer_backend: str = "flash",
pre_norm: bool = False,
use_sim_decoder: bool = False,
):
super().__init__()
self.model_type = "Transformer"
self.d_model = d_model
self.do_dab = do_dab
self.ecs_threshold = ecs_threshold
self.use_batch_labels = use_batch_labels
self.domain_spec_batchnorm = domain_spec_batchnorm
self.input_emb_style = input_emb_style
self.cell_emb_style = cell_emb_style
self.explicit_zero_prob = explicit_zero_prob
self.norm_scheme = "pre" if pre_norm else "post"
if self.input_emb_style not in ["category", "continuous", "scaling"]:
raise ValueError(
f"input_emb_style should be one of category, continuous, scaling, "
f"got {input_emb_style}"
)
if cell_emb_style not in ["cls", "avg-pool", "w-pool"]:
raise ValueError(f"Unknown cell_emb_style: {cell_emb_style}")
# TODO: add dropout in the GeneEncoder
self.encoder = GeneEncoder(ntoken, d_model, padding_idx=vocab[pad_token])
self.flag_encoder = nn.Embedding(2, d_model)
# Value Encoder, NOTE: the scaling style is also handled in _encode method
if input_emb_style == "continuous":
self.value_encoder = ContinuousValueEncoder(d_model, dropout)
elif input_emb_style == "category":
assert n_input_bins > 0
self.value_encoder = CategoryValueEncoder(
n_input_bins, d_model, padding_idx=pad_value
)
else:
self.value_encoder = nn.Identity() # nn.Softmax(dim=1)
# TODO: consider row-wise normalization or softmax
            # TODO: Correctly handle the mask_value when using scaling
# Batch Encoder
if use_batch_labels:
self.batch_encoder = BatchLabelEncoder(num_batch_labels, d_model)
if domain_spec_batchnorm:
            use_affine = domain_spec_batchnorm == "do_affine"
print(f"Use domain specific batchnorm with affine={use_affine}")
self.dsbn = DomainSpecificBatchNorm1d(
d_model, num_batch_labels, eps=6.1e-5, affine=use_affine
)
# else:
# print("Using simple batchnorm instead of domain specific batchnorm")
# self.bn = nn.BatchNorm1d(d_model, eps=6.1e-5)
        if use_generative_training:
            encoder_layers = FlashscGPTLayer(
                d_model,
                nhead,
                d_hid,
                dropout,
                batch_first=True,
                norm_scheme=self.norm_scheme,
                dtype=torch.float32,
            )
            self.transformer_encoder = FlashscGPTGenerator(encoder_layers, nlayers)
        elif use_fast_transformer:
if fast_transformer_backend == "linear":
self.transformer_encoder = FastTransformerEncoderWrapper(
d_model, nhead, d_hid, nlayers, dropout
)
elif fast_transformer_backend == "flash":
encoder_layers = FlashTransformerEncoderLayer(
d_model,
nhead,
d_hid,
dropout,
batch_first=True,
norm_scheme=self.norm_scheme,
)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
else:
encoder_layers = TransformerEncoderLayer(
d_model, nhead, d_hid, dropout, batch_first=True
)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
self.decoder = ExprDecoder(
d_model,
explicit_zero_prob=explicit_zero_prob,
use_batch_labels=use_batch_labels,
)
if n_cls > 1:
if use_sim_decoder:
self.cls_decoder = SimDecoder(d_model, n_cls, nlayers=nlayers_cls)
else:
self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls)
if do_mvc:
self.mvc_decoder = MVCDecoder(
d_model,
arch_style=mvc_decoder_style,
explicit_zero_prob=explicit_zero_prob,
use_batch_labels=use_batch_labels,
)
if do_dab:
self.grad_reverse_discriminator = AdversarialDiscriminator(
d_model,
n_cls=num_batch_labels,
reverse_grad=True,
)
# self.sim = Similarity(temp=0.5) # TODO: auto set temp
# self.creterion_cce = nn.CrossEntropyLoss()
self.init_weights()
def init_weights(self) -> None:
initrange = 0.1
# TODO: check if this initialization is helpful and shall we apply to all?
self.encoder.embedding.weight.data.uniform_(-initrange, initrange)
def _encode(
self,
src: Tensor,
values: Tensor,
src_key_padding_mask: Tensor,
batch_labels: Optional[Tensor] = None, # (batch,)
) -> Tensor:
self._check_batch_labels(batch_labels)
        src = self.encoder(src)  # (batch, seq_len, embsize)
self.cur_gene_token_embs = src
values = self.value_encoder(values) # (batch, seq_len, embsize)
if self.input_emb_style == "scaling":
values = values.unsqueeze(2)
total_embs = src * values
else:
total_embs = src + values
if self.domain_spec_batchnorm:
batch_label = int(batch_labels[0].item())
total_embs = self.dsbn(total_embs.permute(0, 2, 1), batch_label).permute(
0, 2, 1
) # the batch norm always works on dim 1
# else:
# total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1)
output = self.transformer_encoder(
total_embs, src_key_padding_mask=src_key_padding_mask
)
return output # (batch, seq_len, embsize)
def transformer_generate(
self,
pcpt_genes: Tensor,
pcpt_values: Tensor,
pcpt_key_padding_mask: Tensor,
gen_genes: Tensor,
gen_key_padding_mask: Tensor,
batch_labels: Optional[Tensor] = None, # (batch,)
input_cell_emb: Optional[Tensor] = None, # (batch, seq_len, embsize)
) -> Tuple[Tensor, Tensor]:
self._check_batch_labels(batch_labels)
pcpt_token_embs = self.encoder(pcpt_genes) # (batch, pcpt_len, embsize)
pcpt_values = self.value_encoder(pcpt_values) # (batch, pcpt_len, embsize)
pcpt_total_embs = pcpt_token_embs + pcpt_values
assert self.input_emb_style != "scaling"
if gen_genes is not None:
gen_token_embs = self.encoder(gen_genes) # (batch, gen_len, embsize)
self.cur_gene_token_embs = torch.cat(
[pcpt_token_embs, gen_token_embs], dim=1
)
gen_flags = self.flag_encoder(
torch.tensor(1).to(pcpt_values.device)
).expand(gen_genes.shape[0], gen_genes.shape[1], -1)
gen_total_embs = gen_token_embs + gen_flags
else:
self.cur_gene_token_embs = pcpt_token_embs
gen_total_embs = None
if self.domain_spec_batchnorm:
batch_label = int(batch_labels[0].item())
pcpt_total_embs = self.dsbn(
pcpt_total_embs.permute(0, 2, 1), batch_label
).permute(0, 2, 1)
if gen_genes is not None:
gen_total_embs = self.dsbn(
gen_total_embs.permute(0, 2, 1), batch_label
).permute(0, 2, 1)
# else:
# pcpt_total_embs = self.bn(pcpt_total_embs.permute(0, 2, 1)).permute(0, 2, 1)
# if gen_genes is not None:
# gen_total_embs = self.bn(gen_total_embs.permute(0, 2, 1)).permute(
# 0, 2, 1
# )
if input_cell_emb is not None:
pcpt_total_embs[:, 0, :] = input_cell_emb
pcpt_output, gen_output = self.transformer_encoder(
pcpt_total_embs,
gen_total_embs,
pcpt_key_padding_mask=pcpt_key_padding_mask,
gen_key_padding_mask=gen_key_padding_mask,
)
return pcpt_output, gen_output
def _get_cell_emb_from_layer(
self, layer_output: Tensor, weights: Tensor = None
) -> Tensor:
"""
Args:
layer_output(:obj:`Tensor`): shape (batch, seq_len, embsize)
weights(:obj:`Tensor`): shape (batch, seq_len), optional and only used
when :attr:`self.cell_emb_style` is "w-pool".
Returns:
:obj:`Tensor`: shape (batch, embsize)
"""
if self.cell_emb_style == "cls":
cell_emb = layer_output[:, 0, :] # (batch, embsize)
elif self.cell_emb_style == "avg-pool":
cell_emb = torch.mean(layer_output, dim=1)
elif self.cell_emb_style == "w-pool":
if weights is None:
raise ValueError("weights is required when cell_emb_style is w-pool")
if weights.dim() != 2:
raise ValueError("weights should be 2D")
cell_emb = torch.sum(layer_output * weights.unsqueeze(2), dim=1)
cell_emb = F.normalize(cell_emb, p=2, dim=1) # (batch, embsize)
return cell_emb
def _check_batch_labels(self, batch_labels: Tensor) -> None:
if self.use_batch_labels or self.domain_spec_batchnorm:
assert batch_labels is not None
elif batch_labels is not None:
raise ValueError(
"batch_labels should only be provided when `self.use_batch_labels`"
" or `self.domain_spec_batchnorm` is True"
)
def generate(
self,
cell_emb: Tensor,
src: Tensor,
values: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
gen_iters: int = 1,
batch_labels: Optional[Tensor] = None, # (batch,)
) -> Tensor:
"""
Args:
cell_emb(:obj:`Tensor`): shape (batch, embsize)
src(:obj:`Tensor`): shape (batch, seq_len)
values(:obj:`Tensor`): shape (batch, seq_len), optional
src_key_padding_mask(:obj:`Tensor`): shape (batch, seq_len), optional
gen_iters(:obj:`int`): number of generation iterations
batch_labels(:obj:`Tensor`): shape (batch,), optional
"""
# TODO: should have a tag indicate the generation mode
# TODO: if gen_iters > 1, should have a tag indicate the current iteration
try:
self._check_batch_labels(batch_labels)
        except (ValueError, AssertionError):
warnings.warn(
"batch_labels is required but not provided, using zeros instead"
)
batch_labels = torch.zeros(
cell_emb.shape[0], dtype=torch.long, device=cell_emb.device
)
src = self.encoder(src) # (batch, seq_len, embsize)
if values is not None:
values = self.value_encoder(values) # (batch, seq_len, embsize)
if self.input_emb_style == "scaling":
values = values.unsqueeze(2)
total_embs = src * values
else:
total_embs = src + values
else:
total_embs = src
if self.domain_spec_batchnorm:
batch_label = int(batch_labels[0].item())
total_embs = self.dsbn(total_embs.permute(0, 2, 1), batch_label).permute(
0, 2, 1
) # the batch norm always works on dim 1
        elif hasattr(self, "bn"):
            total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1)
total_embs[:, 0, :] = cell_emb
if src_key_padding_mask is None:
src_key_padding_mask = torch.zeros(
total_embs.shape[:2], dtype=torch.bool, device=total_embs.device
)
transformer_output = self.transformer_encoder(
total_embs, src_key_padding_mask=src_key_padding_mask
)
if self.use_batch_labels:
batch_emb = self.batch_encoder(batch_labels) # (batch, embsize)
mlm_output = self.decoder(
transformer_output
if not self.use_batch_labels
else torch.cat(
[
transformer_output,
batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1),
],
dim=2,
),
# else transformer_output + batch_emb.unsqueeze(1),
)
output = mlm_output["pred"] # (batch, seq_len)
return output # (batch, seq_len)
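    # A minimal usage sketch of the iterative generation path above, assuming a
    # model built with use_batch_labels=False; the tensor names (gene_ids,
    # gene_values, padding_mask, query_gene_ids) are hypothetical placeholders:
    #
    #   transformer_output = model._encode(gene_ids, gene_values, padding_mask)
    #   cell_emb = model._get_cell_emb_from_layer(transformer_output)  # (batch, embsize)
    #   pred_values = model.generate(cell_emb, query_gene_ids)         # (batch, query_len)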
def _extend_output(
self,
output: Mapping[str, Tensor],
transformer_output: Tensor,
batch_emb: Optional[Tensor] = None,
CLS: bool = False,
MVC: bool = False,
ECS: bool = False,
do_sample: bool = False,
) -> Mapping[str, Tensor]:
cell_emb = self._get_cell_emb_from_layer(transformer_output)
output["cell_emb"] = cell_emb
if CLS:
output["cls_output"] = self.cls_decoder(cell_emb) # (batch, n_cls)
if MVC:
mvc_output = self.mvc_decoder(
cell_emb
if not self.use_batch_labels
else torch.cat([cell_emb, batch_emb], dim=1),
# else cell_emb + batch_emb,
self.cur_gene_token_embs,
)
if self.explicit_zero_prob and do_sample:
bernoulli = Bernoulli(probs=mvc_output["zero_probs"])
output["mvc_output"] = bernoulli.sample() * mvc_output["pred"]
else:
output["mvc_output"] = mvc_output["pred"] # (batch, seq_len)
if self.explicit_zero_prob:
output["mvc_zero_probs"] = mvc_output["zero_probs"]
if ECS:
# Here using customized cosine similarity instead of F.cosine_similarity
            # to avoid the pytorch issue of similarity larger than 1.0 (pytorch issue #78064)
# normalize the embedding
cell_emb_normed = F.normalize(cell_emb, p=2, dim=1)
cos_sim = torch.mm(cell_emb_normed, cell_emb_normed.t()) # (batch, batch)
            # mask out diagonal elements
mask = torch.eye(cos_sim.size(0)).bool().to(cos_sim.device)
cos_sim = cos_sim.masked_fill(mask, 0.0)
# only optimize positive similarities
cos_sim = F.relu(cos_sim)
output["loss_ecs"] = torch.mean(1 - (cos_sim - self.ecs_threshold) ** 2)
if self.do_dab:
output["dab_output"] = self.grad_reverse_discriminator(cell_emb)
return output
def forward(
self,
*args,
**kwargs,
) -> Mapping[str, Tensor]:
"""
Wrapper to call either generative_forward or perceptual_forward, depending
on the value of the "generative_training" kwarg.
"""
if "generative_training" not in kwargs:
# raise ValueError("generative_training kwarg is required")
warnings.warn(
"generative_training kwarg is required but not provided! "
"Using False and calling perceptual_forward instead"
)
return self.perceptual_forward(*args, **kwargs)
# get the generative training flag and pop it out
do_generative_training = kwargs.pop("generative_training")
if do_generative_training:
return self.generative_forward(*args, **kwargs)
else:
return self.perceptual_forward(*args, **kwargs)
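    # A minimal dispatch sketch for the wrapper above; `model` is a hypothetical
    # TransformerModel instance built with use_generative_training=True, and the
    # positional arguments follow generative_forward / perceptual_forward:
    #
    #   out = model(
    #       pcpt_genes, pcpt_values, pcpt_key_padding_mask,
    #       gen_genes, gen_key_padding_mask,
    #       MVC=True, generative_training=True,
    #   )  # routed to generative_forward
    #   out = model(src, values, src_key_padding_mask, generative_training=False)
    #   # routed to perceptual_forward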
def generative_forward(
self,
pcpt_genes: Tensor,
pcpt_values: Tensor,
pcpt_key_padding_mask: Tensor,
gen_genes: Tensor,
gen_key_padding_mask: Tensor,
batch_labels: Optional[Tensor] = None,
CLS: bool = False,
CCE: bool = False,
MVC: bool = False,
ECS: bool = False,
do_sample: bool = False,
input_cell_emb: Optional[Tensor] = None,
) -> Mapping[str, Tensor]:
"""
Args:
pcpt_genes (:obj:`Tensor`): token ids of the perceptual part, shape
[batch_size, seq_len]
pcpt_values (:obj:`Tensor`): token values of the perceptual part, shape
[batch_size, seq_len]
pcpt_key_padding_mask (:obj:`Tensor`): mask for pcpt_genes, shape
[batch_size, seq_len]
gen_genes (:obj:`Tensor`): token ids of the generative part, shape
[batch_size, seq_len]
gen_key_padding_mask (:obj:`Tensor`): mask for gen_genes, shape
[batch_size, seq_len]
batch_labels (:obj:`Tensor`): batch labels, shape [batch_size]
            do_sample (:obj:`bool`): whether to sample from the Bernoulli
                distribution for the predicted zero probabilities.
input_cell_emb (:obj:`Tensor`): cell embeddings, shape [batch_size,
embsize]
Returns:
:obj:`Mapping[str, Tensor]`:
                - pcpt_preds (:obj:`Tensor`): predictions for the perceptual genes,
                  shape [batch_size, pcpt_len]
                - gen_preds (:obj:`Tensor`): predictions for the generative genes,
                  shape [batch_size, gen_len]
- cell_emb (:obj:`Tensor`): cell embeddings, shape [batch_size,
embsize]
"""
pcpt_output, gen_output = self.transformer_generate(
pcpt_genes,
pcpt_values,
pcpt_key_padding_mask,
gen_genes,
gen_key_padding_mask,
batch_labels,
input_cell_emb=input_cell_emb,
)
if gen_output is None:
transformer_output = pcpt_output
else:
transformer_output = torch.cat([pcpt_output, gen_output], dim=1)
if self.use_batch_labels:
batch_emb = self.batch_encoder(batch_labels)
output = {}
decoder_output = self.decoder(
transformer_output
if not self.use_batch_labels
else torch.cat(
[
transformer_output,
batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1),
],
dim=2,
),
)
        if self.explicit_zero_prob and do_sample:
            bernoulli = Bernoulli(probs=decoder_output["zero_probs"])
            full_preds = bernoulli.sample() * decoder_output["pred"]
        else:
            full_preds = decoder_output["pred"]  # (batch, seq_len)
        output["pcpt_preds"] = full_preds[:, : pcpt_genes.shape[1]]
        output["gen_preds"] = full_preds[:, pcpt_genes.shape[1] :]
if self.explicit_zero_prob:
output["zero_probs"] = decoder_output["zero_probs"]
output = self._extend_output(
output,
transformer_output,
batch_emb=batch_emb if self.use_batch_labels else None,
CLS=CLS,
MVC=MVC,
ECS=ECS,
do_sample=do_sample,
)
return output
def perceptual_forward(
self,
src: Tensor,
values: Tensor,
src_key_padding_mask: Tensor,
batch_labels: Optional[Tensor] = None,
CLS: bool = False,
CCE: bool = False,
MVC: bool = False,
ECS: bool = False,
do_sample: bool = False,
) -> Mapping[str, Tensor]:
"""
Args:
src (:obj:`Tensor`): token ids, shape [batch_size, seq_len]
values (:obj:`Tensor`): token values, shape [batch_size, seq_len]
src_key_padding_mask (:obj:`Tensor`): mask for src, shape [batch_size,
seq_len]
batch_labels (:obj:`Tensor`): batch labels, shape [batch_size]
CLS (:obj:`bool`): if True, return the celltype classification objective
(CLS) output
CCE (:obj:`bool`): if True, return the contrastive cell embedding objective
(CCE) output
MVC (:obj:`bool`): if True, return the masked value prediction for cell
embedding MVC output
ECS (:obj:`bool`): if True, return the elastic cell similarity objective
(ECS) output.
Returns:
dict of output Tensors.
"""
transformer_output = self._encode(
src, values, src_key_padding_mask, batch_labels
)
if self.use_batch_labels:
batch_emb = self.batch_encoder(batch_labels) # (batch, embsize)
output = {}
mlm_output = self.decoder(
transformer_output
if not self.use_batch_labels
else torch.cat(
[
transformer_output,
batch_emb.unsqueeze(1).repeat(1, transformer_output.shape[1], 1),
],
dim=2,
),
# else transformer_output + batch_emb.unsqueeze(1),
)
if self.explicit_zero_prob and do_sample:
bernoulli = Bernoulli(probs=mlm_output["zero_probs"])
output["mlm_output"] = bernoulli.sample() * mlm_output["pred"]
else:
output["mlm_output"] = mlm_output["pred"] # (batch, seq_len)
if self.explicit_zero_prob:
output["mlm_zero_probs"] = mlm_output["zero_probs"]
output = self._extend_output(
output,
transformer_output,
batch_emb=batch_emb if self.use_batch_labels else None,
CLS=CLS,
MVC=MVC,
ECS=ECS,
do_sample=do_sample,
)
return output
def encode_batch(
self,
src: Tensor,
values: Tensor,
src_key_padding_mask: Tensor,
batch_size: int,
batch_labels: Optional[Tensor] = None,
output_to_cpu: bool = True,
time_step: Optional[int] = None,
return_np: bool = False,
) -> Tensor:
"""
Args:
src (Tensor): shape [N, seq_len]
values (Tensor): shape [N, seq_len]
src_key_padding_mask (Tensor): shape [N, seq_len]
batch_size (int): batch size for encoding
            batch_labels (Tensor): shape [N,], integer batch labels
            output_to_cpu (bool): whether to move the output to cpu
            time_step (int): the time step index in the transformer output to return.
                The time step is along the second dimension. If None, return all.
return_np (bool): whether to return numpy array
Returns:
output Tensor of shape [N, seq_len, embsize]
"""
N = src.size(0)
device = next(self.parameters()).device
# initialize the output tensor
array_func = np.zeros if return_np else torch.zeros
float32_ = np.float32 if return_np else torch.float32
shape = (
(N, self.d_model)
if time_step is not None
else (N, src.size(1), self.d_model)
)
outputs = array_func(shape, dtype=float32_)
for i in trange(0, N, batch_size):
raw_output = self._encode(
src[i : i + batch_size].to(device),
values[i : i + batch_size].to(device),
src_key_padding_mask[i : i + batch_size].to(device),
batch_labels[i : i + batch_size].to(device)
if batch_labels is not None
else None,
)
output = raw_output.detach()
if output_to_cpu:
output = output.cpu()
if return_np:
output = output.numpy()
if time_step is not None:
output = output[:, time_step, :]
outputs[i : i + batch_size] = output
return outputs
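    # A minimal usage sketch for encode_batch; the tensor names are hypothetical.
    # With time_step=0 only the <cls> position is kept, which matches the "cls"
    # cell_emb_style and yields one embedding per cell:
    #
    #   cell_embs = model.encode_batch(
    #       gene_ids, gene_values, padding_mask,
    #       batch_size=64, time_step=0, return_np=True,
    #   )  # (N, embsize)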
def generate_square_subsequent_mask(sz: int) -> Tensor:
"""Generates an upper-triangular matrix of -inf, with zeros on diag."""
return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1)
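# For reference, the mask produced above for sz=3 is (rows are query positions,
# columns are key positions):
#
#   tensor([[0., -inf, -inf],
#           [0.,   0., -inf],
#           [0.,   0.,   0.]])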
class FastTransformerEncoderWrapper(nn.Module):
def __init__(
self,
d_model: int,
nhead: int,
d_hid: int,
nlayers: int,
dropout: float = 0.5,
):
super().__init__()
self.fast_transformer_encoder = self.build_fast_transformer_encoder(
d_model, nhead, d_hid, nlayers, dropout
)
@staticmethod
def build_fast_transformer_encoder(
d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float
) -> nn.Module:
from fast_transformers.builders import TransformerEncoderBuilder
if d_model % nhead != 0:
raise ValueError(
f"d_model must be divisible by nhead, "
f"got d_model={d_model} and nhead={nhead}"
)
builder = TransformerEncoderBuilder.from_kwargs(
n_layers=nlayers,
n_heads=nhead,
query_dimensions=d_model // nhead,
value_dimensions=d_model // nhead,
feed_forward_dimensions=d_hid,
attention_type="linear",
attention_dropout=dropout,
dropout=dropout,
activation="gelu",
)
assert builder.attention_type == "linear"
return builder.get()
@staticmethod
def build_length_mask(
src: Tensor,
src_key_padding_mask: torch.BoolTensor,
) -> "LengthMask":
from fast_transformers.masking import LengthMask
seq_len = src.shape[1]
num_paddings = src_key_padding_mask.sum(dim=1)
actual_seq_len = seq_len - num_paddings # (N,)
length_mask = LengthMask(actual_seq_len, max_len=seq_len, device=src.device)
if src_key_padding_mask[length_mask.bool_matrix].sum() != 0:
raise ValueError(
"Found padding tokens in the middle of the sequence. "
"src_key_padding_mask and length_mask are not compatible."
)
return length_mask
def forward(
self,
src: Tensor,
src_key_padding_mask: torch.BoolTensor,
) -> Tensor:
"""
Args:
src: Tensor, shape [N, seq_len, embsize]
src_key_padding_mask: Tensor, shape [N, seq_len]
Returns:
output Tensor of shape [N, seq_len, embsize]
"""
if src_key_padding_mask.shape != src.shape[:2]:
raise ValueError(
f"src_key_padding_mask shape {src_key_padding_mask.shape} "
f"does not match first two dims of src shape {src.shape[:2]}"
)
if src_key_padding_mask.dtype != torch.bool:
raise ValueError(
f"src_key_padding_mask needs to be of type torch.bool, "
f"got {src_key_padding_mask.dtype}"
)
length_mask = self.build_length_mask(src, src_key_padding_mask)
output = self.fast_transformer_encoder(src, length_mask=length_mask)
return output
class FlashTransformerEncoderLayer(nn.Module):
r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
The class is modified from torch.nn.TransformerEncoderLayer to support the
FlashAttention.
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of intermediate layer, relu or gelu (default=relu).
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``True``.
Examples::
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> out = encoder_layer(src)
Alternatively, when ``batch_first`` is ``True``:
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
>>> src = torch.rand(32, 10, 512)
>>> out = encoder_layer(src)
"""
__constants__ = ["batch_first"]
def __init__(
self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
layer_norm_eps=1e-5,
batch_first=True,
device=None,
dtype=None,
norm_scheme="post", # "pre" or "post"
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.self_attn = FlashMHA(
embed_dim=d_model,
num_heads=nhead,
batch_first=batch_first,
attention_dropout=dropout,
**factory_kwargs,
)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = self._get_activation_fn(activation)
self.norm_scheme = norm_scheme
if self.norm_scheme not in ["pre", "post"]:
raise ValueError(f"norm_scheme should be pre or post, not {norm_scheme}")
@staticmethod
def _get_activation_fn(activation):
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = F.relu
super().__setstate__(state)
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
**kwargs,
) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
if src_mask is not None:
raise ValueError("FlashTransformerEncoderLayer does not support src_mask")
        if src_key_padding_mask is None or not src_key_padding_mask.any().item():
            # no padding tokens in src
            src_key_padding_mask_ = None
else:
if src_key_padding_mask.dtype != torch.bool:
src_key_padding_mask = src_key_padding_mask.bool()
            # NOTE: FlashMHA marks padding tokens with 0 (False), which is the
            # opposite of the PyTorch convention, hence the inversion below.
src_key_padding_mask_ = ~src_key_padding_mask
if self.norm_scheme == "pre":
src = self.norm1(src)
src2 = self.self_attn(src, key_padding_mask=src_key_padding_mask_)[0]
src = src + self.dropout1(src2)
src = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
else:
src2 = self.self_attn(src, key_padding_mask=src_key_padding_mask_)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
class GeneEncoder(nn.Module):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
):
super().__init__()
self.embedding = nn.Embedding(
num_embeddings, embedding_dim, padding_idx=padding_idx
)
self.enc_norm = nn.LayerNorm(embedding_dim)
def forward(self, x: Tensor) -> Tensor:
        x = self.embedding(x)  # (batch, seq_len, embsize)
        x = self.enc_norm(x)
return x
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
)
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: Tensor, shape [seq_len, batch_size, embedding_dim]
"""
x = x + self.pe[: x.size(0)]
return self.dropout(x)
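    # The buffer registered in __init__ implements the standard sinusoidal
    # encoding:
    #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    # Note the sequence-first input layout [seq_len, batch, d_model]; this module
    # is not instantiated by TransformerModel above, which encodes gene identity
    # rather than position.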
class ContinuousValueEncoder(nn.Module):
"""
    Encode real number values to a vector using a neural network projection.
"""
def __init__(self, d_model: int, dropout: float = 0.1, max_value: int = 512):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
self.linear1 = nn.Linear(1, d_model)
self.activation = nn.ReLU()
self.linear2 = nn.Linear(d_model, d_model)
self.norm = nn.LayerNorm(d_model)
self.max_value = max_value
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: Tensor, shape [batch_size, seq_len]
"""
# TODO: test using actual embedding layer if input is categorical
# expand last dimension
x = x.unsqueeze(-1)
# clip x to [-inf, max_value]
x = torch.clamp(x, max=self.max_value)
x = self.activation(self.linear1(x))
x = self.linear2(x)
x = self.norm(x)
return self.dropout(x)
class CategoryValueEncoder(nn.Module):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
):
super().__init__()
self.embedding = nn.Embedding(
num_embeddings, embedding_dim, padding_idx=padding_idx
)
self.enc_norm = nn.LayerNorm(embedding_dim)
def forward(self, x: Tensor) -> Tensor:
x = x.long()
x = self.embedding(x) # (batch, seq_len, embsize)
x = self.enc_norm(x)
return x
class BatchLabelEncoder(nn.Module):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
):
super().__init__()
self.embedding = nn.Embedding(
num_embeddings, embedding_dim, padding_idx=padding_idx
)
self.enc_norm = nn.LayerNorm(embedding_dim)
def forward(self, x: Tensor) -> Tensor:
x = self.embedding(x) # (batch, embsize)
x = self.enc_norm(x)
return x
class Similarity(nn.Module):
"""
Dot product or cosine similarity
"""
def __init__(self, temp):
super().__init__()
self.temp = temp
self.cos = nn.CosineSimilarity(dim=-1)
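    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        # Temperature-scaled cosine similarity. This method is not present in the
        # listing above; it is a reconstruction based on the attributes set in
        # __init__ (self.cos and self.temp).
        return self.cos(x, y) / self.temp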
class ExprDecoder(nn.Module):
def __init__(
self,
d_model: int,
explicit_zero_prob: bool = False,
use_batch_labels: bool = False,
):
super().__init__()
d_in = d_model * 2 if use_batch_labels else d_model
self.fc = nn.Sequential(
nn.Linear(d_in, d_model),
nn.LeakyReLU(),
nn.Linear(d_model, d_model),
nn.LeakyReLU(),
nn.Linear(d_model, 1),
)
self.explicit_zero_prob = explicit_zero_prob
if explicit_zero_prob:
self.zero_logit = nn.Sequential(
nn.Linear(d_in, d_model),
nn.LeakyReLU(),
nn.Linear(d_model, d_model),
nn.LeakyReLU(),
nn.Linear(d_model, 1),
)
def forward(self, x: Tensor) -> Dict[str, Tensor]:
"""x is the output of the transformer, (batch, seq_len, d_model)"""
pred_value = self.fc(x).squeeze(-1) # (batch, seq_len)
if not self.explicit_zero_prob:
return dict(pred=pred_value)
zero_logits = self.zero_logit(x).squeeze(-1) # (batch, seq_len)
zero_probs = torch.sigmoid(zero_logits)
return dict(pred=pred_value, zero_probs=zero_probs)
# TODO: note that the return currently is only for training. Since decoder
# is not used in the test setting for the integration task, the eval/inference
# logic is not implemented yet. However, remember to implement it when
# the decoder is used in any test setting. The inference logic will need
# to sample from the bernoulli distribution with the zero_probs.
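    # A minimal sketch of the inference-time sampling described in the TODO above
    # (not implemented in this class); it mirrors the sampling already used in
    # perceptual_forward, and `expr_decoder` / `transformer_output` are
    # hypothetical names:
    #
    #   out = expr_decoder(transformer_output)
    #   bernoulli = Bernoulli(probs=out["zero_probs"])
    #   sampled_pred = bernoulli.sample() * out["pred"]  # (batch, seq_len)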
class ClsDecoder(nn.Module):
"""
Decoder for classification task.
"""
def __init__(
self,
d_model: int,
n_cls: int,
nlayers: int = 3,
activation: callable = nn.ReLU,
):
super().__init__()
# module list
self._decoder = nn.ModuleList()
for i in range(nlayers - 1):
self._decoder.append(nn.Linear(d_model, d_model))
self._decoder.append(activation())
self._decoder.append(nn.LayerNorm(d_model))
self.out_layer = nn.Linear(d_model, n_cls)
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: Tensor, shape [batch_size, embsize]
"""
for layer in self._decoder:
x = layer(x)
return self.out_layer(x)
class SimDecoder(nn.Module):
"""
Decoder for classification task with similarity matrix.
"""
def __init__(
self,
d_model: int,
n_cls: int,
nlayers: int = 3,
activation: callable = nn.ReLU,
projection_dim: int = 2048,
):
super().__init__()
# module list
self._decoder = nn.ModuleList()
for i in range(nlayers - 1):
self._decoder.append(nn.Linear(d_model, d_model))
self._decoder.append(activation())
self._decoder.append(nn.LayerNorm(d_model))
self.out_layer = nn.Linear(d_model, projection_dim)
self.cls_token_matrix = nn.Parameter(torch.randn(n_cls, projection_dim))
self.embed_norm = nn.LayerNorm(projection_dim)
self.token_norm = nn.LayerNorm(projection_dim)
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: Tensor, shape [batch_size, embsize]
"""
for layer in self._decoder:
x = layer(x)
x = self.out_layer(x)
x = self.embed_norm(x)
sim = self.sim_matrix(x, self.cls_token_matrix)
return sim
def sim_matrix(self, a, b, eps=1e-8):
# b = self.token_norm(b)
a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
# sim_mt = torch.mm(a, b.transpose(0, 1))
return sim_mt
class MVCDecoder(nn.Module):
"""
Decoder for the masked value prediction for cell embeddings.
"""
def __init__(
self,
d_model: int,
arch_style: str = "inner product",
query_activation: nn.Module = nn.Sigmoid,
hidden_activation: nn.Module = nn.PReLU,
explicit_zero_prob: bool = False,
use_batch_labels: bool = False,
) -> None:
"""
Args:
d_model (:obj:`int`): dimension of the gene embedding.
arch_style (:obj:`str`): architecture style of the decoder, choice from
1. "inner product" or 2. "concat query" or 3. "sum query".
query_activation (:obj:`nn.Module`): activation function for the query
vectors.
hidden_activation (:obj:`nn.Module`): activation function for the hidden
layers.
"""
super().__init__()
d_in = d_model * 2 if use_batch_labels else d_model
if arch_style in ["inner product", "inner product, detach"]:
self.gene2query = nn.Linear(d_model, d_model)
self.query_activation = query_activation()
self.W = nn.Linear(d_model, d_in, bias=False)
if explicit_zero_prob: # by default, gene-wise prob rate
self.W_zero_logit = nn.Linear(d_model, d_in)
elif arch_style == "concat query":
self.gene2query = nn.Linear(d_model, 64)
self.query_activation = query_activation()
self.fc1 = nn.Linear(d_model + 64, 64)
self.hidden_activation = hidden_activation()
self.fc2 = nn.Linear(64, 1)
elif arch_style == "sum query":
self.gene2query = nn.Linear(d_model, d_model)
self.query_activation = query_activation()
self.fc1 = nn.Linear(d_model, 64)
self.hidden_activation = hidden_activation()
self.fc2 = nn.Linear(64, 1)
else:
raise ValueError(f"Unknown arch_style: {arch_style}")
self.arch_style = arch_style
self.do_detach = arch_style.endswith("detach")
self.explicit_zero_prob = explicit_zero_prob
def forward(
self, cell_emb: Tensor, gene_embs: Tensor
) -> Union[Tensor, Dict[str, Tensor]]:
"""
Args:
cell_emb: Tensor, shape (batch, embsize=d_model)
gene_embs: Tensor, shape (batch, seq_len, embsize=d_model)
"""
gene_embs = gene_embs.detach() if self.do_detach else gene_embs
if self.arch_style in ["inner product", "inner product, detach"]:
query_vecs = self.query_activation(self.gene2query(gene_embs))
cell_emb = cell_emb.unsqueeze(2) # (batch, embsize, 1)
            # the predicted gene expression values, (batch, seq_len)
pred_value = torch.bmm(self.W(query_vecs), cell_emb).squeeze(2)
if not self.explicit_zero_prob:
return dict(pred=pred_value)
            # zero logits need to be based on the cell_emb, because of the input exprs
zero_logits = torch.bmm(self.W_zero_logit(query_vecs), cell_emb).squeeze(2)
zero_probs = torch.sigmoid(zero_logits)
return dict(pred=pred_value, zero_probs=zero_probs)
elif self.arch_style == "concat query":
query_vecs = self.query_activation(self.gene2query(gene_embs))
# expand cell_emb to (batch, seq_len, embsize)
cell_emb = cell_emb.unsqueeze(1).expand(-1, gene_embs.shape[1], -1)
h = self.hidden_activation(
self.fc1(torch.cat([cell_emb, query_vecs], dim=2))
)
if self.explicit_zero_prob:
raise NotImplementedError
return self.fc2(h).squeeze(2) # (batch, seq_len)
elif self.arch_style == "sum query":
query_vecs = self.query_activation(self.gene2query(gene_embs))
cell_emb = cell_emb.unsqueeze(1)
h = self.hidden_activation(self.fc1(cell_emb + query_vecs))
if self.explicit_zero_prob:
raise NotImplementedError
return self.fc2(h).squeeze(2) # (batch, seq_len)
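    # A minimal shape sketch for the "inner product" style above; the variable
    # names are hypothetical:
    #
    #   cell_emb:  (batch, d_model)             # (batch, 2 * d_model) with batch labels
    #   gene_embs: (batch, seq_len, d_model)    # e.g. model.cur_gene_token_embs
    #   out = mvc_decoder(cell_emb, gene_embs)
    #   out["pred"]: (batch, seq_len)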
class AdversarialDiscriminator(nn.Module):
"""
Discriminator for the adversarial training for batch correction.
"""
def __init__(
self,
d_model: int,
n_cls: int,
nlayers: int = 3,
activation: callable = nn.LeakyReLU,
reverse_grad: bool = False,
):
super().__init__()
# module list
self._decoder = nn.ModuleList()
for i in range(nlayers - 1):
self._decoder.append(nn.Linear(d_model, d_model))
self._decoder.append(activation())
self._decoder.append(nn.LayerNorm(d_model))
self.out_layer = nn.Linear(d_model, n_cls)
self.reverse_grad = reverse_grad
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: Tensor, shape [batch_size, embsize]
"""
if self.reverse_grad:
x = grad_reverse(x, lambd=1.0)
for layer in self._decoder:
x = layer(x)
return self.out_layer(x)
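# A minimal usage sketch for the discriminator above, as wired up through
# TransformerModel(do_dab=True); the cross-entropy loss shown here is an assumed
# training objective, not part of this module:
#
#   dab_logits = model.grad_reverse_discriminator(cell_emb)  # (batch, num_batch_labels)
#   loss_dab = F.cross_entropy(dab_logits, batch_labels)
#   # With reverse_grad=True, grad_reverse flips the gradient sign flowing back
#   # into the encoder, discouraging batch-identifiable cell embeddings.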