import math
from functools import lru_cache
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.modules.mha import FlashCrossAttention
from numpy import dtype
from torch import Tensor
from torch.nn.modules.transformer import _get_clones

from .layers import MultiheadAttention
class FlashscGPTMHA(nn.Module):
    """
    Custom multi-head attention layer for scGPT.

    Two separate attention passes share the ``Wqkv`` projection:

    * full self-attention over the pcpt genes, computed with ``FlashAttention``
      and projected by ``out_proj``;
    * cross-attention for the gen genes, computed with the ``MultiheadAttention``
      module from ``.layers``. Each gen gene may attend to every pcpt gene and
      to itself, but not to the other gen genes.

    Args:
        embed_dim: total hidden size (``num_heads * head_dim``).
        num_heads: number of attention heads. ``embed_dim`` must be divisible
            by it, and the resulting head dim must be a multiple of 8 and <= 128.
        bias: whether ``Wqkv`` and ``out_proj`` use a bias term.
        batch_first: must be True; inputs are (batch, seq, feature).
        attention_dropout: dropout probability for both attention modules.
        causal: forwarded to the FlashAttention self-attention call.
        device: factory kwarg for the created parameters.
        dtype: factory kwarg for the created parameters.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        bias=True,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        device=None,
        dtype=None,
    ) -> None:
        assert batch_first, "FlashscGPTMHA only supports batch_first=True"
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.num_heads = num_heads
        assert (
            self.embed_dim % num_heads == 0
        ), "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert (
            self.head_dim % 8 == 0 and self.head_dim <= 128
        ), "Only support head_dim <= 128 and divisible by 8"
        # One projection produces q, k and v for both the pcpt and gen tokens.
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.self_attn = FlashAttention(attention_dropout=attention_dropout)
        # Cross attention uses a non-flash implementation, presumably so that
        # the boolean attn_mask built in forward() can be applied. (A flash
        # variant would put each gen query in the batch dimension, with its kv
        # being all pcpt genes plus that single gen gene.)
        self.cross_attn = MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=attention_dropout,
            batch_first=batch_first,
            **factory_kwargs,
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    @staticmethod
    def _make_cross_attn_mask(q_len, k_len, device):
        """
        Build the boolean cross-attention mask (True means "not allowed").

        Keys are laid out as ``[pcpt tokens (k_len - q_len), gen tokens (q_len)]``.
        Every query may attend to all pcpt keys; among the gen keys only the
        diagonal (the query's own gen token) is allowed.

        Returns:
            Bool tensor of shape (q_len, k_len).
        """
        mask = torch.zeros((q_len, k_len), device=device, dtype=torch.bool)
        # Explicit start index (not -q_len) so q_len == 0 does not select
        # every column.
        mask[:, k_len - q_len :] = ~torch.eye(q_len, device=device, dtype=torch.bool)
        return mask

    @staticmethod
    def _build_cross_key_padding_mask(
        pcpt_key_padding_mask, gen_key_padding_mask, batch_size, pcpt_len, gen_len, device
    ):
        """
        Merge the pcpt and gen padding masks for the cross-attention call.

        The inputs use the "True = valid token" convention, and a missing mask
        is treated as all-valid. The result uses PyTorch's "True = ignore"
        convention and has shape (batch, pcpt_len + gen_len). Returns None
        when both inputs are None.
        """
        if pcpt_key_padding_mask is None and gen_key_padding_mask is None:
            return None
        if pcpt_key_padding_mask is None:
            pcpt_key_padding_mask = torch.ones(
                (batch_size, pcpt_len), device=device, dtype=torch.bool
            )
        if gen_key_padding_mask is None:
            gen_key_padding_mask = torch.ones(
                (batch_size, gen_len), device=device, dtype=torch.bool
            )
        return ~torch.cat([pcpt_key_padding_mask, gen_key_padding_mask], dim=1)

    def forward(
        self,
        pcpt_total_embs: Tensor,
        gen_total_embs: Optional[Tensor],
        pcpt_key_padding_mask: Optional[Tensor] = None,
        gen_key_padding_mask: Optional[Tensor] = None,
        need_weights=False,
    ):
        """
        Args:
            pcpt_total_embs: (batch, pcpt_len, hidden_dim), where
                hidden_dim = num_heads * head_dim.
            gen_total_embs: (batch, gen_len, hidden_dim), or None to run only
                the pcpt self-attention.
            pcpt_key_padding_mask: bool tensor of shape (batch, pcpt_len);
                1 means valid and 0 means not valid.
            gen_key_padding_mask: bool tensor of shape (batch, gen_len);
                1 means valid and 0 means not valid.
            need_weights: forwarded to the FlashAttention self-attention call.

        Returns:
            ``((pcpt_context, gen_context), (pcpt_attn_weights, gen_attn_weights))``.
            ``gen_context`` is None when ``gen_total_embs`` is None, and
            ``gen_attn_weights`` is always None.
        """
        pcpt_qkv = rearrange(
            self.Wqkv(pcpt_total_embs),
            "b s (three h d) -> b s three h d",
            three=3,
            h=self.num_heads,
        )
        # Full self-attention over the pcpt genes.
        pcpt_context, pcpt_attn_weights = self.self_attn(
            pcpt_qkv,
            key_padding_mask=pcpt_key_padding_mask,
            need_weights=need_weights,
            causal=self.causal,
        )
        pcpt_context = rearrange(pcpt_context, "b s h d -> b s (h d)")
        # FlashAttention returns half precision (float16 was observed here).
        # Cast to out_proj's own parameter dtype instead of a hard-coded
        # float32, so the layer also works when built in fp16/bf16.
        pcpt_context = self.out_proj(pcpt_context.to(self.out_proj.weight.dtype))
        if gen_total_embs is None:
            return (pcpt_context, None), (pcpt_attn_weights, None)

        gen_qkv = rearrange(
            self.Wqkv(gen_total_embs),
            "b s (three h d) -> b s three h d",
            three=3,
            h=self.num_heads,
        )
        # Queries: the gen genes only, (batch, gen_len, hidden_dim).
        cross_q = rearrange(gen_qkv[:, :, 0], "b s h d -> b s (h d)")
        # Keys/values: pcpt genes followed by gen genes,
        # (batch, pcpt_len + gen_len, 2, hidden_dim).
        cross_kv = rearrange(
            torch.cat([pcpt_qkv[:, :, 1:], gen_qkv[:, :, 1:]], dim=1),
            "b s two h d -> b s two (h d)",
        )
        attention_mask = self._make_cross_attn_mask(
            cross_q.shape[1], cross_kv.shape[1], cross_q.device
        )
        key_padding_mask = self._build_cross_key_padding_mask(
            pcpt_key_padding_mask,
            gen_key_padding_mask,
            pcpt_qkv.shape[0],
            pcpt_qkv.shape[1],
            gen_qkv.shape[1],
            pcpt_qkv.device,
        )
        gen_context, _ = self.cross_attn(
            cross_q,
            cross_kv[:, :, 0],
            cross_kv[:, :, 1],
            key_padding_mask=key_padding_mask,
            attn_mask=attention_mask,
        )  # (batch, gen_len, hidden_dim)
        # Attention weights are not returned for the gen tokens.
        gen_attn_weights = None
        return (pcpt_context, gen_context), (pcpt_attn_weights, gen_attn_weights)
class FlashscGPTLayer(nn.Module):
    r"""Transformer encoder layer (attention + feed-forward) for scGPT.

    Adapted from ``torch.nn.TransformerEncoderLayer``. Attention is performed
    by ``FlashscGPTMHA``, which handles the pcpt and gen token streams
    separately. Both streams share the same attention, feed-forward and
    layer-norm parameters.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of attention heads (required).
        dim_feedforward: the dimension of the feedforward network (default=2048).
        dropout: the dropout value, also used as attention dropout (default=0.1).
        activation: the activation of the intermediate layer, "relu" or "gelu"
            (default="relu").
        layer_norm_eps: the eps value in layer normalization components
            (default=1e-5).
        batch_first: inputs and outputs are (batch, seq, feature). Default:
            ``True``; ``FlashscGPTMHA`` asserts that it is True.
        device: factory kwarg for the created parameters.
        dtype: factory kwarg for the created parameters.
        norm_scheme: "pre" or "post" (default="post").

    Raises:
        RuntimeError: if ``activation`` is neither "relu" nor "gelu".
        ValueError: if ``norm_scheme`` is neither "pre" nor "post".

    Example (illustrative; FlashAttention typically needs CUDA half-precision
    inputs)::

        >>> layer = FlashscGPTLayer(d_model=512, nhead=8)
        >>> pcpt = torch.rand(32, 10, 512)
        >>> gen = torch.rand(32, 4, 512)
        >>> pcpt_out, gen_out = layer(pcpt, gen)
    """

    __constants__ = ["batch_first"]

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        layer_norm_eps=1e-5,
        batch_first=True,
        device=None,
        dtype=None,
        norm_scheme="post",  # "pre" or "post"
    ) -> None:
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.self_attn = FlashscGPTMHA(
            embed_dim=d_model,
            num_heads=nhead,
            batch_first=batch_first,
            attention_dropout=dropout,
            **factory_kwargs,
        )
        # Feed-forward sub-block.
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = self._get_activation_fn(activation)
        self.norm_scheme = norm_scheme
        if norm_scheme not in ["pre", "post"]:
            raise ValueError("norm_scheme must be either pre or post")

    @staticmethod
    def _get_activation_fn(activation):
        """Map "relu"/"gelu" to the torch.nn.functional callable."""
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu
        raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

    def __setstate__(self, state):
        # Mirrors torch.nn.TransformerEncoderLayer: supply F.relu when
        # unpickling a state that has no "activation" entry.
        if "activation" not in state:
            state["activation"] = F.relu
        super().__setstate__(state)

    def _reverse_key_padding_mask(self, src_key_padding_mask):
        """
        Convert a PyTorch-style padding mask (True = padded token) into the
        convention used by the inner flash MHA (True = valid token).

        Returns None when no mask is given or when the mask contains no
        padded token at all.
        """
        if src_key_padding_mask is None:
            return None
        if not src_key_padding_mask.any().item():
            # no padding tokens in src
            return None
        return ~src_key_padding_mask

    def _ff_block(self, x: Tensor) -> Tensor:
        """Feed-forward: linear1 -> activation -> dropout -> linear2."""
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

    def _pre_norm_update(self, x: Tensor, attn_out: Tensor) -> Tensor:
        """Residual + feed-forward update for the "pre" scheme.

        ``x`` is expected to already be ``norm1``-ed by the caller.
        """
        # NOTE(review): unlike the textbook pre-LN layer (x + f(norm(x))), this
        # path keeps the *normalized* tensor as the residual stream (norm1
        # before attention, norm2 before the FF residual). The original
        # ordering is preserved, presumably for compatibility with weights
        # trained this way; confirm before changing.
        x = x + self.dropout1(attn_out)
        x = self.norm2(x)
        return x + self.dropout2(self._ff_block(x))

    def _post_norm_update(self, x: Tensor, attn_out: Tensor) -> Tensor:
        """Residual + feed-forward update for the "post" scheme."""
        x = self.norm1(x + self.dropout1(attn_out))
        return self.norm2(x + self.dropout2(self._ff_block(x)))

    def forward(
        self,
        pcpt_total_embs: Tensor,
        gen_total_embs: Optional[Tensor],
        pcpt_key_padding_mask: Optional[Tensor] = None,
        gen_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Pass the pcpt and gen inputs through the layer.

        Args:
            pcpt_total_embs: pcpt token embeddings, (batch, pcpt_len, d_model).
            gen_total_embs: gen token embeddings, (batch, gen_len, d_model),
                or None to process only the pcpt tokens.
            pcpt_key_padding_mask: optional mask over pcpt tokens, True marking
                padding (PyTorch convention).
            gen_key_padding_mask: optional mask over gen tokens, True marking
                padding (PyTorch convention).

        Returns:
            ``(pcpt_total_embs, gen_total_embs)``; the second element is None
            when ``gen_total_embs`` is None.
        """
        # The inner MHA expects True = valid, so flip the PyTorch convention.
        pcpt_mask = self._reverse_key_padding_mask(pcpt_key_padding_mask)
        gen_mask = self._reverse_key_padding_mask(gen_key_padding_mask)

        pre_norm = self.norm_scheme == "pre"
        if pre_norm:
            pcpt_total_embs = self.norm1(pcpt_total_embs)
            if gen_total_embs is not None:
                gen_total_embs = self.norm1(gen_total_embs)

        pcpt_attn_out, gen_attn_out = self.self_attn(
            pcpt_total_embs,
            gen_total_embs,
            pcpt_key_padding_mask=pcpt_mask,
            gen_key_padding_mask=gen_mask,
        )[0]

        # The pcpt stream is fully processed before the gen stream, which
        # keeps the dropout call order unchanged.
        update = self._pre_norm_update if pre_norm else self._post_norm_update
        pcpt_total_embs = update(pcpt_total_embs, pcpt_attn_out)
        if gen_total_embs is not None:
            gen_total_embs = update(gen_total_embs, gen_attn_out)
        return pcpt_total_embs, gen_total_embs
class FlashscGPTGenerator(nn.Module):
    r"""Stack of N encoder layers operating on (pcpt, gen) embedding pairs.

    Each layer receives and returns the ``(pcpt_total_embs, gen_total_embs)``
    pair, as ``FlashscGPTLayer`` does. An optional final ``norm`` is applied
    to both outputs.

    Args:
        encoder_layer: a layer instance (e.g. ``FlashscGPTLayer``) that is
            cloned ``num_layers`` times (required).
        num_layers: the number of stacked layers (required).
        norm: optional normalization module applied to the final outputs.
        mask_check: stored on the module; not used by ``forward``.

    Example (illustrative)::

        >>> layer = FlashscGPTLayer(d_model=512, nhead=8)
        >>> generator = FlashscGPTGenerator(layer, num_layers=6)
        >>> pcpt_out, gen_out = generator(pcpt_embs, gen_embs)
    """

    __constants__ = ["norm"]

    def __init__(
        self,
        encoder_layer,
        num_layers,
        norm=None,
        mask_check=True,
    ):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.mask_check = mask_check

    @staticmethod
    def _check_key_padding_mask(mask: Optional[Tensor]) -> None:
        """Reject padding masks that are neither bool nor floating point."""
        if mask is None:
            return
        if mask.dtype != torch.bool and not torch.is_floating_point(mask):
            raise AssertionError(
                "only bool and floating types of key_padding_mask are supported"
            )

    def forward(
        self,
        pcpt_total_embs: Tensor,
        gen_total_embs: Optional[Tensor],
        pcpt_key_padding_mask: Optional[Tensor] = None,
        gen_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Pass the inputs through the stacked layers in turn.

        Args:
            pcpt_total_embs: pcpt token embeddings (required).
            gen_total_embs: gen token embeddings, or None to process only the
                pcpt tokens.
            pcpt_key_padding_mask: optional padding mask for the pcpt tokens
                (bool or floating dtype).
            gen_key_padding_mask: optional padding mask for the gen tokens
                (bool or floating dtype).

        Returns:
            ``(pcpt_total_embs, gen_total_embs)``; the second element is None
            when ``gen_total_embs`` is None.

        Raises:
            AssertionError: if a padding mask has an unsupported dtype.
        """
        # Validate both masks (previously only the pcpt mask was checked).
        self._check_key_padding_mask(pcpt_key_padding_mask)
        self._check_key_padding_mask(gen_key_padding_mask)

        for mod in self.layers:
            pcpt_total_embs, gen_total_embs = mod(
                pcpt_total_embs,
                gen_total_embs,
                pcpt_key_padding_mask,
                gen_key_padding_mask,
            )

        if self.norm is not None:
            pcpt_total_embs = self.norm(pcpt_total_embs)
            # The layers allow gen_total_embs to be None; don't normalize None.
            if gen_total_embs is not None:
                gen_total_embs = self.norm(gen_total_embs)
        return pcpt_total_embs, gen_total_embs