Source code for scgpt.model.dsbn

from typing import Optional, Tuple

import torch
from torch import nn

# The code is modified from https://github.com/wgchang/DSBN/blob/master/model/dsbn.py
class _DomainSpecificBatchNorm(nn.Module):
    """Base module that keeps one batch-norm layer per domain.

    ``__init__`` builds ``num_domains`` independent layers from the class
    returned by :attr:`bn_handle`, so each domain has its own running
    statistics (and, when ``affine`` is true, its own affine parameters).
    :meth:`forward` routes the input through the layer selected by
    ``domain_label``.

    Subclasses must override :attr:`bn_handle` and :meth:`_check_input_dim`.

    Args:
        num_features: Passed as ``num_features`` to every per-domain layer.
        num_domains: Number of per-domain layers to create; valid domain
            labels are ``0 .. num_domains - 1``.
        eps: Passed as ``eps`` to every per-domain layer.
        momentum: Passed as ``momentum`` to every per-domain layer.
        affine: Passed as ``affine`` to every per-domain layer.
        track_running_stats: Passed as ``track_running_stats`` to every
            per-domain layer.
    """

    # NOTE(review): presumably mirrors torch's batch-norm ``_version`` used for
    # state-dict versioning — confirm before changing.
    _version = 2

    def __init__(
        self,
        num_features: int,
        num_domains: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
    ):
        super().__init__()
        self._cur_domain = None
        self.num_domains = num_domains
        # Keyword arguments make the mapping onto the batch-norm constructor
        # explicit instead of relying on its positional parameter order.
        self.bns = nn.ModuleList(
            [
                self.bn_handle(
                    num_features,
                    eps=eps,
                    momentum=momentum,
                    affine=affine,
                    track_running_stats=track_running_stats,
                )
                for _ in range(num_domains)
            ]
        )

    @property
    def bn_handle(self) -> nn.Module:
        """Batch-norm class instantiated once per domain; subclasses override."""
        raise NotImplementedError

    @property
    def cur_domain(self) -> Optional[int]:
        """Domain label most recently set (by :meth:`forward` or the setter).

        ``None`` until the first assignment.
        """
        return self._cur_domain

    @cur_domain.setter
    def cur_domain(self, domain_label: int):
        self._cur_domain = domain_label

    def reset_running_stats(self):
        """Call ``reset_running_stats()`` on every per-domain layer."""
        for bn in self.bns:
            bn.reset_running_stats()

    def reset_parameters(self):
        """Call ``reset_parameters()`` on every per-domain layer."""
        for bn in self.bns:
            bn.reset_parameters()

    def _check_input_dim(self, input: torch.Tensor):
        """Validate the rank of ``input``; subclasses override."""
        raise NotImplementedError

    def forward(self, x: torch.Tensor, domain_label: int) -> torch.Tensor:
        """Normalize ``x`` with the layer that belongs to ``domain_label``.

        Also records ``domain_label`` as :attr:`cur_domain`.

        Raises:
            ValueError: If ``x`` fails :meth:`_check_input_dim`, or if
                ``domain_label`` is outside ``[0, num_domains)``.
        """
        self._check_input_dim(x)
        # ModuleList accepts negative indices, so without this guard a label
        # such as -1 would silently pick the last domain's layer.
        if domain_label < 0:
            raise ValueError(f"Domain label {domain_label} must be non-negative")
        if domain_label >= self.num_domains:
            raise ValueError(
                f"Domain label {domain_label} exceeds the number of domains {self.num_domains}"
            )
        bn = self.bns[domain_label]
        self.cur_domain = domain_label
        return bn(x)


class DomainSpecificBatchNorm1d(_DomainSpecificBatchNorm):
    """Domain-specific batch norm with one ``nn.BatchNorm1d`` per domain."""

    @property
    def bn_handle(self) -> nn.Module:
        """Return ``nn.BatchNorm1d``, the class instantiated once per domain."""
        return nn.BatchNorm1d

    def _check_input_dim(self, input: torch.Tensor):
        """Raise ``ValueError`` unless ``input`` is 2D or 3D.

        ``nn.BatchNorm1d`` accepts only 2D or 3D input. Checking the exact rank
        here, rather than only an upper bound, rejects 0D/1D tensors with this
        module's own error instead of letting them through.
        """
        if input.dim() not in (2, 3):
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(input.dim())
            )
class DomainSpecificBatchNorm2d(_DomainSpecificBatchNorm):
    """Domain-specific batch norm with one ``nn.BatchNorm2d`` per domain."""

    @property
    def bn_handle(self) -> nn.Module:
        """Return ``nn.BatchNorm2d``, the class instantiated once per domain."""
        return nn.BatchNorm2d

    def _check_input_dim(self, input: torch.Tensor):
        """Raise ``ValueError`` unless ``input`` is exactly 4-dimensional."""
        ndim = input.dim()
        if ndim == 4:
            return
        raise ValueError("expected 4D input (got {}D input)".format(ndim))