# Source code for scgpt.model.generation_model
import os
import math
from typing import Mapping, Optional, Tuple, Any, Union
import torch
from torch import nn, Tensor
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.distributions import Bernoulli
from torch.utils.data import dataset
from tqdm import trange
from .model import (
ExprDecoder,
MVCDecoder,
ContinuousValueEncoder,
FastTransformerEncoderWrapper,
FlashTransformerEncoderLayer,
)
from ..utils import map_raw_id_to_vocab_id
from .. import logger
# [docs]
class TransformerGenerator(nn.Module):
    """Transformer encoder over gene tokens, expression values and perturbation flags.

    Each input position is embedded as the sum of a gene-token embedding, a
    continuous-value embedding and a perturbation-flag embedding. The sum is
    batch-normalised over the embedding dimension, passed through a transformer
    encoder, and decoded into per-position expression predictions. Optional
    heads produce cell-level classification, MVC predictions and an ECS loss.
    """

    def __init__(
        self,
        ntoken: int,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        nlayers_cls: int,
        n_cls: int,
        vocab: Any,
        dropout: float = 0.5,
        pad_token: str = "<pad>",
        pad_value: int = 0,
        pert_pad_id: int = 2,
        do_mvc: bool = False,
        domain_spec_batchnorm: Union[bool, str] = False,
        cell_emb_style: str = "cls",
        mvc_decoder_style: str = "inner product",
        ecs_threshold: float = 0.3,
        explicit_zero_prob: bool = False,
        use_fast_transformer: bool = False,
        fast_transformer_backend: str = "flash",
        pre_norm: bool = False,
    ):
        """
        Args:
            ntoken: number of rows of the gene-token embedding table.
            d_model: embedding size used throughout the model.
            nhead: number of attention heads of each encoder layer.
            d_hid: feed-forward size of each encoder layer.
            nlayers: number of transformer encoder layers.
            nlayers_cls: number of layers of the classification decoder.
            n_cls: number of output classes of the classification decoder.
            vocab: object indexable by token string; ``vocab[pad_token]`` must
                give the padding token id.
            dropout: dropout rate passed to the value encoder and the
                encoder layers.
            pad_token: token whose id is used as padding index of the gene
                embedding.
            pad_value: stored as ``self.pad_value``; not used elsewhere in
                this class.
            pert_pad_id: padding index of the 3-entry perturbation-flag
                embedding.
            do_mvc: if True, create the MVC decoder (required for
                ``forward(..., MVC=True)``).
            domain_spec_batchnorm: stored only; a plain ``BatchNorm1d`` is
                always used.
            cell_emb_style: one of ``"cls"``, ``"avg-pool"``, ``"w-pool"``.
            mvc_decoder_style: ``arch_style`` passed to the MVC decoder.
            ecs_threshold: threshold used in the ECS loss.
            explicit_zero_prob: passed to the decoders; when True they are
                expected to also return ``"zero_probs"``.
            use_fast_transformer: if True, use the backend selected by
                ``fast_transformer_backend`` instead of the stock PyTorch
                encoder.
            fast_transformer_backend: ``"linear"`` or ``"flash"``.
            pre_norm: selects the ``"pre"`` norm scheme (only forwarded to
                the flash encoder layer).

        Raises:
            ValueError: on an unknown ``cell_emb_style``, or an unknown
                ``fast_transformer_backend`` when ``use_fast_transformer``
                is True.
        """
        super().__init__()
        self.model_type = "Transformer"
        self.d_model = d_model
        self.pad_token_id = vocab[pad_token]
        self.pad_value = pad_value
        self.pert_pad_id = pert_pad_id
        self.ecs_threshold = ecs_threshold
        self.domain_spec_batchnorm = domain_spec_batchnorm
        self.cell_emb_style = cell_emb_style
        self.explicit_zero_prob = explicit_zero_prob
        self.norm_scheme = "pre" if pre_norm else "post"
        if cell_emb_style not in ["cls", "avg-pool", "w-pool"]:
            raise ValueError(f"Unknown cell_emb_style: {cell_emb_style}")
        # Fail fast: without this check an unknown backend silently leaves
        # ``self.transformer_encoder`` undefined until the first forward pass.
        if use_fast_transformer and fast_transformer_backend not in (
            "linear",
            "flash",
        ):
            raise ValueError(
                f"Unknown fast_transformer_backend: {fast_transformer_backend}"
            )

        self.encoder = GeneEncoder(ntoken, d_model, padding_idx=vocab[pad_token])
        self.value_encoder = ContinuousValueEncoder(d_model, dropout)
        self.pert_encoder = nn.Embedding(3, d_model, padding_idx=pert_pad_id)

        print("Using simple batchnorm instead of domain specific batchnorm")
        self.bn = nn.BatchNorm1d(d_model, eps=6.1e-5)

        if use_fast_transformer:
            if fast_transformer_backend == "linear":
                self.transformer_encoder = FastTransformerEncoderWrapper(
                    d_model, nhead, d_hid, nlayers, dropout
                )
            elif fast_transformer_backend == "flash":
                encoder_layers = FlashTransformerEncoderLayer(
                    d_model,
                    nhead,
                    d_hid,
                    dropout,
                    batch_first=True,
                    norm_scheme=self.norm_scheme,
                )
                self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        else:
            encoder_layers = TransformerEncoderLayer(
                d_model, nhead, d_hid, dropout, batch_first=True
            )
            self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

        self.decoder = ExprDecoder(
            d_model,
            explicit_zero_prob=explicit_zero_prob,
        )
        self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls)
        if do_mvc:
            self.mvc_decoder = MVCDecoder(
                d_model,
                arch_style=mvc_decoder_style,
                explicit_zero_prob=explicit_zero_prob,
            )

        self.sim = Similarity(temp=0.5)
        # NOTE: attribute name (including its spelling) is kept as-is because
        # external code may access it.
        self.creterion_cce = nn.CrossEntropyLoss()

        self.init_weights()

    def init_weights(self) -> None:
        """Re-initialise the gene embedding weights uniformly in [-0.1, 0.1]."""
        initrange = 0.1
        self.encoder.embedding.weight.data.uniform_(-initrange, initrange)

    def _encode(
        self,
        src: Tensor,
        values: Tensor,
        input_pert_flags,
        src_key_padding_mask: Tensor,
    ) -> Tensor:
        """Embed the inputs and run the transformer encoder.

        Args:
            src: gene token ids, shape (batch, seq_len).
            values: expression values, shape (batch, seq_len).
            input_pert_flags: integer perturbation flags indexing the
                3-entry ``pert_encoder`` embedding, shape (batch, seq_len).
            src_key_padding_mask: padding mask, shape (batch, seq_len).

        Returns:
            Encoder output of shape (batch, seq_len, embsize).

        Side effects:
            Stores the gene token embeddings in ``self.cur_gene_token_embs``
            (read later by the MVC branch of ``forward``).
        """
        src = self.encoder(src)  # (batch, seq_len, embsize)
        self.cur_gene_token_embs = src
        values = self.value_encoder(values)  # (batch, seq_len, embsize)
        perts = self.pert_encoder(input_pert_flags)  # (batch, seq_len, embsize)
        total_embs = src + values + perts

        # BatchNorm1d normalises dim 1, so move the embedding axis there and back.
        total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1)
        output = self.transformer_encoder(
            total_embs, src_key_padding_mask=src_key_padding_mask
        )
        return output  # (batch, seq_len, embsize)

    def _get_cell_emb_from_layer(
        self, layer_output: Tensor, weights: Optional[Tensor] = None
    ) -> Tensor:
        """Pool per-position outputs into one L2-normalised vector per cell.

        Args:
            layer_output(:obj:`Tensor`): shape (batch, seq_len, embsize)
            weights(:obj:`Tensor`): shape (batch, seq_len), optional and only used
                when :attr:`self.cell_emb_style` is "w-pool".

        Returns:
            :obj:`Tensor`: shape (batch, embsize)

        Raises:
            ValueError: in "w-pool" mode when ``weights`` is missing or not 2D.
        """
        if self.cell_emb_style == "cls":
            # The first position is taken as the cell representation.
            cell_emb = layer_output[:, 0, :]  # (batch, embsize)
        elif self.cell_emb_style == "avg-pool":
            cell_emb = torch.mean(layer_output, dim=1)
        elif self.cell_emb_style == "w-pool":
            if weights is None:
                raise ValueError("weights is required when cell_emb_style is w-pool")
            if weights.dim() != 2:
                raise ValueError("weights should be 2D")
            cell_emb = torch.sum(layer_output * weights.unsqueeze(2), dim=1)
        cell_emb = F.normalize(cell_emb, p=2, dim=1)  # (batch, embsize)

        return cell_emb

    def forward(
        self,
        src: Tensor,
        values: Tensor,
        input_pert_flags: Tensor,
        src_key_padding_mask: Tensor,
        CLS: bool = False,
        CCE: bool = False,
        MVC: bool = False,
        ECS: bool = False,
        do_sample: bool = False,
    ) -> Mapping[str, Tensor]:
        """
        Args:
            src (:obj:`Tensor`): token ids, shape [batch_size, seq_len]
            values (:obj:`Tensor`): token values, shape [batch_size, seq_len]
            input_pert_flags (:obj:`Tensor`): perturbation flags, shape
                [batch_size, seq_len]
            src_key_padding_mask (:obj:`Tensor`): mask for src, shape [batch_size,
                seq_len]
            CLS (:obj:`bool`): if True, return the celltype classification objective
                (CLS) output
            CCE (:obj:`bool`): accepted for interface compatibility; this
                method does not compute a CCE output.
            MVC (:obj:`bool`): if True, return the masked value prediction for cell
                embedding MVC output (requires the model to be built with
                ``do_mvc=True``)
            ECS (:obj:`bool`): if True, return the elastic cell similarity objective
                (ECS) output.
            do_sample (:obj:`bool`): if True and ``explicit_zero_prob`` is set,
                multiply predictions by a Bernoulli sample drawn from the
                predicted zero probabilities. Forced to True in eval mode
                when ``explicit_zero_prob`` is set.

        Returns:
            dict of output Tensors. Always contains ``"mlm_output"``; may also
            contain ``"mlm_zero_probs"``, ``"cls_output"``, ``"mvc_output"``,
            ``"mvc_zero_probs"`` and ``"loss_ecs"`` depending on the flags.
        """
        if self.explicit_zero_prob and not do_sample and not self.training:
            do_sample = True
            logger.warning("Auto set do_sample to True when model is in eval mode.")

        transformer_output = self._encode(
            src, values, input_pert_flags, src_key_padding_mask
        )
        output = {}
        mlm_output = self.decoder(transformer_output)
        if self.explicit_zero_prob and do_sample:
            bernoulli = Bernoulli(probs=mlm_output["zero_probs"])
            output["mlm_output"] = bernoulli.sample() * mlm_output["pred"]
        else:
            output["mlm_output"] = mlm_output["pred"]  # (batch, seq_len)
        if self.explicit_zero_prob:
            output["mlm_zero_probs"] = mlm_output["zero_probs"]

        # ``values`` doubles as the pooling weights in "w-pool" mode.
        cell_emb = self._get_cell_emb_from_layer(transformer_output, values)
        if CLS:
            output["cls_output"] = self.cls_decoder(cell_emb)  # (batch, n_cls)
        if MVC:
            mvc_output = self.mvc_decoder(
                cell_emb,
                self.cur_gene_token_embs,
            )  # (batch, seq_len)
            if self.explicit_zero_prob and do_sample:
                bernoulli = Bernoulli(probs=mvc_output["zero_probs"])
                output["mvc_output"] = bernoulli.sample() * mvc_output["pred"]
            else:
                output["mvc_output"] = mvc_output["pred"]  # (batch, seq_len)
            if self.explicit_zero_prob:
                output["mvc_zero_probs"] = mvc_output["zero_probs"]
        if ECS:
            # Here using customized cosine similarity instead of F.cosine_similarity
            # to avoid the pytorch issue of similarity larger than 1.0, pytorch # 78064
            # normalize the embedding
            cell_emb_normed = F.normalize(cell_emb, p=2, dim=1)
            cos_sim = torch.mm(cell_emb_normed, cell_emb_normed.t())  # (batch, batch)

            # mask out diagonal elements (built directly on the target device)
            mask = torch.eye(cos_sim.size(0), device=cos_sim.device).bool()
            cos_sim = cos_sim.masked_fill(mask, 0.0)
            # only optimize positive similarities
            cos_sim = F.relu(cos_sim)

            output["loss_ecs"] = torch.mean(1 - (cos_sim - self.ecs_threshold) ** 2)

        return output

    def encode_batch(
        self,
        src: Tensor,
        values: Tensor,
        src_key_padding_mask: Tensor,
        batch_size: int,
        output_to_cpu: bool = True,
        input_pert_flags: Optional[Tensor] = None,
    ) -> Tensor:
        """Encode a large set of cells in mini-batches.

        Args:
            src: Tensor, shape [N, seq_len]
            values: Tensor, shape [N, seq_len]
            src_key_padding_mask: Tensor, shape [N, seq_len]
            batch_size: number of rows encoded per step.
            output_to_cpu: if True, move every chunk's output to CPU before
                concatenation.
            input_pert_flags: optional Tensor, shape [N, seq_len], of
                perturbation flags. When omitted, every position uses
                ``self.pert_pad_id``, i.e. the padding entry of the
                perturbation embedding.
                NOTE(review): pass explicit flags if the model was trained
                with a dedicated "unperturbed" flag value.

        Returns:
            output Tensor of shape [N, seq_len, embsize]
        """
        outputs = []
        N = src.size(0)
        device = next(self.parameters()).device
        for i in trange(0, N, batch_size):
            src_chunk = src[i : i + batch_size].to(device)
            if input_pert_flags is None:
                pert_chunk = torch.full(
                    src_chunk.shape,
                    self.pert_pad_id,
                    dtype=torch.long,
                    device=device,
                )
            else:
                pert_chunk = input_pert_flags[i : i + batch_size].to(device)
            # ``_encode`` takes four inputs; the perturbation flags were
            # previously missing from this call, which raised a TypeError.
            output = self._encode(
                src_chunk,
                values[i : i + batch_size].to(device),
                pert_chunk,
                src_key_padding_mask[i : i + batch_size].to(device),
            )
            if output_to_cpu:
                output = output.cpu()
            outputs.append(output)
        return torch.cat(outputs, dim=0)

    def pred_perturb(
        self,
        batch_data,
        include_zero_gene="batch-wise",
        gene_ids=None,
        amp=True,
    ) -> Tensor:
        """Predict expression values for a batch in eval mode.

        Args:
            batch_data: batch object exposing ``.to(device)``, ``.pert`` (its
                length is the batch size) and ``.x``. Column 0 of ``.x`` is
                read as expression values and column 1 as perturbation
                flags; both are reshaped to (batch_size, -1).
            include_zero_gene: ``"all"`` feeds every gene; ``"batch-wise"``
                feeds only genes that are non-zero in at least one cell of
                the batch.
            gene_ids: lookup passed to ``map_raw_id_to_vocab_id`` to turn raw
                gene indices into vocabulary ids; required.
            amp: enables CUDA autocast for the forward pass.

        Returns:
            Tensor of shape (batch_size, n_genes); genes that were not fed to
            the model are left at zero.

        Raises:
            ValueError: if ``include_zero_gene`` is not ``"all"`` or
                ``"batch-wise"``.
        """
        # Previously an unsupported mode fell through to an UnboundLocalError.
        if include_zero_gene not in ["all", "batch-wise"]:
            raise ValueError(
                f"Unsupported include_zero_gene: {include_zero_gene!r}; "
                'expected "all" or "batch-wise"'
            )
        assert gene_ids is not None, "gene_ids is required"

        self.eval()
        device = next(self.parameters()).device
        # Relies on ``batch_data.to`` moving the batch in place.
        batch_data.to(device)
        batch_size = len(batch_data.pert)
        x: torch.Tensor = batch_data.x
        ori_gene_values = x[:, 0].view(batch_size, -1)  # (batch_size, n_genes)
        pert_flags = x[:, 1].long().view(batch_size, -1)

        if include_zero_gene == "all":
            input_gene_ids = torch.arange(ori_gene_values.size(1), device=device)
        else:  # batch-wise
            input_gene_ids = (
                ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]
            )
        input_values = ori_gene_values[:, input_gene_ids]
        input_pert_flags = pert_flags[:, input_gene_ids]

        mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids)
        mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1)

        # No position is padded: every selected gene is a real input.
        src_key_padding_mask = torch.zeros_like(
            input_values, dtype=torch.bool, device=device
        )
        with torch.cuda.amp.autocast(enabled=amp):
            output_dict = self(
                mapped_input_gene_ids,
                input_values,
                input_pert_flags,
                src_key_padding_mask=src_key_padding_mask,
                CLS=False,
                CCE=False,
                MVC=False,
                ECS=False,
                do_sample=True,
            )
        output_values = output_dict["mlm_output"].float()
        pred_gene_values = torch.zeros_like(ori_gene_values)
        pred_gene_values[:, input_gene_ids] = output_values
        return pred_gene_values
# [docs]
def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Build a (sz, sz) causal mask: -inf strictly above the diagonal, 0 elsewhere."""
    neg_inf = float("-inf")
    all_neg_inf = torch.ones(sz, sz) * neg_inf
    # triu with diagonal=1 keeps only the strictly upper triangle and zeroes the rest.
    return torch.triu(all_neg_inf, diagonal=1)
# [docs]
class GeneEncoder(nn.Module):
    """Token embedding followed by layer normalisation."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
    ):
        """
        Args:
            num_embeddings: number of rows of the embedding table.
            embedding_dim: size of each embedding vector.
            padding_idx: optional index forwarded to ``nn.Embedding`` as its
                padding index.
        """
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.enc_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Map token ids (batch, seq_len) to normalised embeddings (batch, seq_len, embsize)."""
        return self.enc_norm(self.embedding(x))
# [docs]
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor, then apply dropout."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            d_model: embedding size (even or odd).
            dropout: dropout probability applied after adding the encoding.
            max_len: number of positions precomputed in the ``pe`` buffer.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the cosine term; for even d_model the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]; ``seq_len``
                must not exceed ``max_len``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)
# [docs]
class Similarity(nn.Module):
    """
    Cosine similarity along the last dimension, divided by a temperature.
    """

    def __init__(self, temp):
        """
        Args:
            temp: temperature the cosine similarity is divided by.
        """
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x, y):
        """Return ``cos(x, y) / temp`` computed over the last dimension.

        Without a ``forward`` the module could be constructed but calling it
        raised ``NotImplementedError``; ``temp`` and ``cos`` were otherwise
        unused.

        Args:
            x: Tensor whose last dimension is the feature dimension.
            y: Tensor broadcastable against ``x``.
        """
        return self.cos(x, y) / self.temp
# [docs]
class ClsDecoder(nn.Module):
    """
    Classification head: ``nlayers - 1`` hidden stages followed by an output projection.
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 3,
        activation: callable = nn.ReLU,
    ):
        """
        Args:
            d_model: input and hidden feature size.
            n_cls: number of output classes.
            nlayers: total layer count; ``nlayers - 1`` hidden stages are built.
            activation: zero-argument factory producing the activation module.
        """
        super().__init__()
        self._decoder = nn.ModuleList()
        # Each hidden stage is Linear -> activation -> LayerNorm, in that order.
        for _ in range(nlayers - 1):
            self._decoder.extend(
                [
                    nn.Linear(d_model, d_model),
                    activation(),
                    nn.LayerNorm(d_model),
                ]
            )
        self.out_layer = nn.Linear(d_model, n_cls)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, embsize]

        Returns:
            Tensor with ``n_cls`` features in the last dimension.
        """
        hidden = x
        for module in self._decoder:
            hidden = module(hidden)
        return self.out_layer(hidden)