# Source code for scgpt.model.flash_layers

import math
from functools import lru_cache
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.modules.mha import FlashCrossAttention
from numpy import dtype
from torch import Tensor
from torch.nn.modules.transformer import _get_clones

from .layers import MultiheadAttention


class FlashscGPTMHA(nn.Module):
    """
    Custom MHA layer for scGPT.

    The layer takes two separate attention passes that share a single QKV
    projection: full self-attention on the pcpt genes, and cross-attention in
    which every gen gene attends to all pcpt genes plus itself.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        bias=True,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        Args:
            embed_dim: hidden size of the inputs (num_heads * head_dim).
            num_heads: number of attention heads; must divide ``embed_dim``.
            bias: whether the QKV and output projections carry a bias term.
            batch_first: must be truthy; only (batch, seq, feature) inputs
                are supported.
            attention_dropout: dropout used by both attention modules.
            causal: forwarded to the self-attention pass on the pcpt genes.
            device: device for the created parameters.
            dtype: dtype for the created parameters.
        """
        assert batch_first
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert (
            self.embed_dim % num_heads == 0
        ), "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert (
            self.head_dim % 8 == 0 and self.head_dim <= 128
        ), "Only support head_dim <= 128 and divisible by 8"

        # One packed projection produces q, k and v for both gene sets.
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.self_attn = FlashAttention(attention_dropout=attention_dropout)
        # NOTE: a FlashCrossAttention based variant (one query per gen gene,
        # batched together) was sketched in earlier revisions; the regular
        # MultiheadAttention together with an explicit mask is used instead.
        self.cross_attn = MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=attention_dropout,
            batch_first=batch_first,
            **factory_kwargs,
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def _split_qkv(self, qkv: Tensor) -> Tensor:
        """Reshape a packed (b, s, 3 * h * d) projection into (b, s, 3, h, d)."""
        batch, seq_len = qkv.shape[:2]
        return qkv.reshape(batch, seq_len, 3, self.num_heads, -1)

    @staticmethod
    def _make_cross_attn_mask(q_len: int, k_len: int, device) -> Tensor:
        """
        Build the (q_len, k_len) boolean mask for the cross-attention pass.

        Following the PyTorch convention, ``True`` means attention is NOT
        allowed. Every query may see all leading keys; inside the trailing
        ``q_len`` keys (the gen genes) each query may only see itself.
        """
        attention_mask = torch.zeros((q_len, k_len), device=device, dtype=torch.bool)
        # An explicit start index keeps the slice correct when q_len == 0
        # (a "-0:" slice would select every column).
        attention_mask[:, k_len - q_len :] = ~torch.eye(
            q_len, device=device, dtype=torch.bool
        )
        return attention_mask

    def forward(
        self,
        pcpt_total_embs: Tensor,
        gen_total_embs: Optional[Tensor],
        pcpt_key_padding_mask: Optional[Tensor] = None,
        gen_key_padding_mask: Optional[Tensor] = None,
        need_weights=False,
    ) -> Tuple[
        Tuple[Tensor, Optional[Tensor]], Tuple[Optional[Tensor], Optional[Tensor]]
    ]:
        """
        pcpt_total_embs: (batch, pcpt_len, hidden_dim) (where hidden_dim = num heads * head dim)
        gen_total_embs: (batch, gen_len, hidden_dim), or None to skip the
            cross-attention pass.
        pcpt_key_padding_mask: bool tensor of shape (batch, pcpt_len), 1 means valid and 0 means not valid.
        gen_key_padding_mask: bool tensor of shape (batch, gen_len), 1 means valid and 0 means not valid.
        need_weights: forwarded to the self-attention pass.

        Returns:
            ``((pcpt_context, gen_context), (pcpt_attn_weights, gen_attn_weights))``.
            ``gen_context`` is None when ``gen_total_embs`` is None and
            ``gen_attn_weights`` is always None.
        """
        pcpt_qkv = self._split_qkv(self.Wqkv(pcpt_total_embs))

        # full self attention on pcpt genes
        pcpt_context, pcpt_attn_weights = self.self_attn(
            pcpt_qkv,
            key_padding_mask=pcpt_key_padding_mask,
            need_weights=need_weights,
            causal=self.causal,
        )
        # Merge the heads: (b, s, h, d) -> (b, s, h * d).
        pcpt_context = pcpt_context.flatten(start_dim=-2)
        # The attention output can be lower precision than the projection
        # weights (earlier debug notes recorded float16 here), so cast to the
        # dtype out_proj actually uses instead of a hard-coded float32.
        pcpt_context = self.out_proj(pcpt_context.to(self.out_proj.weight.dtype))

        if gen_total_embs is None:
            return (pcpt_context, None), (pcpt_attn_weights, None)

        gen_qkv = self._split_qkv(self.Wqkv(gen_total_embs))

        # CROSS ATTENTION USING RAW PYTORCH IMPLEMENTATION
        # queries: the gen genes only, (batch, gen_len, hidden_dim)
        cross_q = gen_qkv[:, :, 0].flatten(start_dim=-2)
        # keys/values: pcpt genes followed by gen genes,
        # (batch, pcpt_len + gen_len, 2, hidden_dim)
        cross_kv = torch.cat([pcpt_qkv[:, :, 1:], gen_qkv[:, :, 1:]], dim=1).flatten(
            start_dim=-2
        )

        # Rebuilt on every call: the mask is cheap, and a cache defined inside
        # forward() would be recreated per call and never hit.
        attention_mask = self._make_cross_attn_mask(
            cross_q.shape[1], cross_kv.shape[1], cross_q.device
        )

        if pcpt_key_padding_mask is None and gen_key_padding_mask is None:
            key_padding_mask = None
        else:
            # Only one of the masks was given: treat the other side as fully valid.
            if pcpt_key_padding_mask is None:
                pcpt_key_padding_mask = torch.ones(
                    (pcpt_qkv.shape[0], pcpt_qkv.shape[1]),
                    device=pcpt_qkv.device,
                    dtype=torch.bool,
                )
            if gen_key_padding_mask is None:
                gen_key_padding_mask = torch.ones(
                    (gen_qkv.shape[0], gen_qkv.shape[1]),
                    device=gen_qkv.device,
                    dtype=torch.bool,
                )
            # Incoming masks mark valid tokens with True; the PyTorch style
            # cross attention expects True for padded tokens, hence the inversion.
            key_padding_mask = ~torch.cat(
                [pcpt_key_padding_mask, gen_key_padding_mask], dim=1
            )

        cross_context, _ = self.cross_attn(
            cross_q,
            cross_kv[:, :, 0, :],
            cross_kv[:, :, 1, :],
            key_padding_mask=key_padding_mask,
            attn_mask=attention_mask,
        )
        gen_context = cross_context  # (batch, gen_len, hidden_dim)
        gen_attn_weights = None

        return (pcpt_context, gen_context), (pcpt_attn_weights, gen_attn_weights)
class FlashscGPTLayer(nn.Module):
    r"""Encoder layer made up of scGPT attention and a feedforward network.

    The class is modified from torch.nn.TransformerEncoderLayer to support the
    FlashAttention. Unlike the stock layer it processes two streams at once,
    the pcpt embeddings and the (optional) gen embeddings, through
    :class:`FlashscGPTMHA` and then applies the same residual, normalization
    and feedforward modules to each stream.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value, also used as attention dropout (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: forwarded to :class:`FlashscGPTMHA`, which asserts that it
            is truthy, i.e. tensors are (batch, seq, feature). Default: ``True``.
        device: device for the created parameters.
        dtype: dtype for the created parameters.
        norm_scheme: ``"pre"`` or ``"post"``, selecting where the layer norms
            are applied. Default: ``"post"``.

    Raises:
        ValueError: if ``norm_scheme`` is neither ``"pre"`` nor ``"post"``.
        RuntimeError: if ``activation`` is neither ``"relu"`` nor ``"gelu"``.
    """

    __constants__ = ["batch_first"]

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        layer_norm_eps=1e-5,
        batch_first=True,
        device=None,
        dtype=None,
        norm_scheme="post",  # "pre" or "post"
    ) -> None:
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.self_attn = FlashscGPTMHA(
            embed_dim=d_model,
            num_heads=nhead,
            batch_first=batch_first,
            attention_dropout=dropout,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = self._get_activation_fn(activation)
        self.norm_scheme = norm_scheme
        if norm_scheme not in ["pre", "post"]:
            raise ValueError("norm_scheme must be either pre or post")

    @staticmethod
    def _get_activation_fn(activation):
        """Map the activation name ("relu" or "gelu") to its functional form."""
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu

        raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

    def __setstate__(self, state):
        # States that carry no activation entry fall back to relu.
        if "activation" not in state:
            state["activation"] = F.relu
        super().__setstate__(state)

    def _reverse_key_padding_mask(self, src_key_padding_mask):
        """
        Reverse the true false values of the key padding mask. This is because
        we follow pytorch rule that the mask is True for padded tokens, but
        in the inner flash MHA, it assumes the mask is False for padded tokens.

        Returns None when no mask is given or when the mask marks no padding.
        """
        if src_key_padding_mask is None:
            return None

        if not src_key_padding_mask.any().item():
            # no padding tokens in src
            return None
        return ~src_key_padding_mask

    def _feed_forward(self, x: Tensor) -> Tensor:
        """Apply linear1 -> activation -> dropout -> linear2."""
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

    def _post_attention(self, residual: Tensor, attn_out: Tensor) -> Tensor:
        """Combine the attention output with its residual and run the feedforward block.

        The same steps are applied to the pcpt stream and to the gen stream.
        """
        if self.norm_scheme == "pre":
            x = self.norm2(residual + self.dropout1(attn_out))
            return x + self.dropout2(self._feed_forward(x))
        x = self.norm1(residual + self.dropout1(attn_out))
        return self.norm2(x + self.dropout2(self._feed_forward(x)))

    def forward(
        self,
        pcpt_total_embs: Tensor,
        gen_total_embs: Optional[Tensor],
        pcpt_key_padding_mask: Optional[Tensor] = None,
        gen_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Pass the pcpt and gen embeddings through the encoder layer.

        Args:
            pcpt_total_embs: embeddings of the pcpt genes (required).
            gen_total_embs: embeddings of the gen genes, or None to process
                the pcpt stream only.
            pcpt_key_padding_mask: padding mask for the pcpt genes, True for
                padded tokens (optional).
            gen_key_padding_mask: padding mask for the gen genes, True for
                padded tokens (optional).

        Returns:
            ``(pcpt_total_embs, gen_total_embs)`` after attention and
            feedforward; the second item is None when no gen embeddings
            were given.
        """
        # The inner MHA expects True for valid tokens, so flip the masks.
        pcpt_key_padding_mask_ = self._reverse_key_padding_mask(pcpt_key_padding_mask)
        gen_key_padding_mask_ = self._reverse_key_padding_mask(gen_key_padding_mask)

        if self.norm_scheme == "pre":
            # NOTE: the normalized tensors replace the inputs, so the residual
            # below is taken from the normalized values. This is kept exactly
            # as in the original implementation.
            pcpt_total_embs = self.norm1(pcpt_total_embs)
            if gen_total_embs is not None:
                gen_total_embs = self.norm1(gen_total_embs)

        pcpt_attn_out, gen_attn_out = self.self_attn(
            pcpt_total_embs,
            gen_total_embs,
            pcpt_key_padding_mask=pcpt_key_padding_mask_,
            gen_key_padding_mask=gen_key_padding_mask_,
        )[0]

        # The pcpt stream is processed first, then the gen stream, which keeps
        # the order of dropout calls identical to the unrolled version.
        pcpt_total_embs = self._post_attention(pcpt_total_embs, pcpt_attn_out)
        if gen_total_embs is not None:
            gen_total_embs = self._post_attention(gen_total_embs, gen_attn_out)

        return pcpt_total_embs, gen_total_embs
class FlashscGPTGenerator(nn.Module):
    # takes in the set of different inputs in a mapping
    r"""Stack of N encoder layers operating on pcpt and gen embeddings.

    Modeled after ``torch.nn.TransformerEncoder``, but every layer receives
    and returns a pair ``(pcpt_total_embs, gen_total_embs)`` instead of a
    single sequence.

    Args:
        encoder_layer: the layer instance to clone (required). It is called
            as ``layer(pcpt_total_embs, gen_total_embs, pcpt_key_padding_mask,
            gen_key_padding_mask)`` and must return the updated pair.
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component applied to the final
            outputs (optional).
        mask_check: stored on the module; not consulted by ``forward``.
    """

    __constants__ = ["norm"]

    def __init__(
        self,
        encoder_layer,
        num_layers,
        norm=None,
        mask_check=True,
    ):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.mask_check = mask_check

    def forward(
        self,
        pcpt_total_embs: Tensor,
        gen_total_embs: Optional[Tensor],
        pcpt_key_padding_mask: Optional[Tensor] = None,
        gen_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Pass the inputs through the encoder layers in turn.

        Args:
            pcpt_total_embs: embeddings of the pcpt genes (required).
            gen_total_embs: embeddings of the gen genes, or None.
            pcpt_key_padding_mask: padding mask for the pcpt genes (optional).
            gen_key_padding_mask: padding mask for the gen genes (optional).

        Returns:
            ``(pcpt_total_embs, gen_total_embs)`` after all layers and the
            optional final norm; the second item stays None when no gen
            embeddings were given.

        Raises:
            AssertionError: if ``pcpt_key_padding_mask`` is neither a bool
                nor a floating point tensor.
        """
        if pcpt_key_padding_mask is not None:
            _skpm_dtype = pcpt_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(
                pcpt_key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )

        for mod in self.layers:
            pcpt_total_embs, gen_total_embs = mod(
                pcpt_total_embs,
                gen_total_embs,
                pcpt_key_padding_mask,
                gen_key_padding_mask,
            )

        if self.norm is not None:
            pcpt_total_embs = self.norm(pcpt_total_embs)
            # The gen stream is optional; normalizing None would raise.
            if gen_total_embs is not None:
                gen_total_embs = self.norm(gen_total_embs)

        return pcpt_total_embs, gen_total_embs