from typing import Optional, Tuple
import torch
from torch import nn
# The code is modified from https://github.com/wgchang/DSBN/blob/master/model/dsbn.py
class _DomainSpecificBatchNorm(nn.Module):
    """Base class holding one batch-norm layer per domain.

    ``num_domains`` independent layers are built from the class returned by
    :attr:`bn_handle`. Each :meth:`forward` call is routed to the layer
    selected by ``domain_label``.

    Subclasses must override :attr:`bn_handle` and :meth:`_check_input_dim`.

    Args:
        num_features: Passed as the first argument to every ``bn_handle`` call.
        num_domains: Number of per-domain layers to create.
        eps: Forwarded positionally to ``bn_handle``.
        momentum: Forwarded positionally to ``bn_handle``.
        affine: Forwarded positionally to ``bn_handle``.
        track_running_stats: Forwarded positionally to ``bn_handle``.
    """

    # NOTE(review): presumably mirrors the state-dict version used by torch's
    # own batch-norm modules -- confirm before changing.
    _version = 2

    def __init__(
        self,
        num_features: int,
        num_domains: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
    ):
        super().__init__()
        # Label used by the most recent successful forward(); None until then.
        self._cur_domain = None
        self.num_domains = num_domains
        self.bns = nn.ModuleList(
            [
                self.bn_handle(num_features, eps, momentum, affine, track_running_stats)
                for _ in range(num_domains)
            ]
        )

    @property
    def bn_handle(self) -> nn.Module:
        """Callable used to build each per-domain layer; set by subclasses."""
        raise NotImplementedError

    @property
    def cur_domain(self) -> Optional[int]:
        """Domain label most recently assigned, or ``None`` if never set."""
        return self._cur_domain

    @cur_domain.setter
    def cur_domain(self, domain_label: int):
        self._cur_domain = domain_label

    def reset_running_stats(self):
        """Call ``reset_running_stats()`` on every per-domain layer."""
        for bn in self.bns:
            bn.reset_running_stats()

    def reset_parameters(self):
        """Call ``reset_parameters()`` on every per-domain layer."""
        for bn in self.bns:
            bn.reset_parameters()

    def _check_input_dim(self, input: torch.Tensor):
        """Validate the dimensionality of ``input``; set by subclasses."""
        raise NotImplementedError

    def forward(self, x: torch.Tensor, domain_label: int) -> torch.Tensor:
        """Apply the layer belonging to ``domain_label`` to ``x``.

        Args:
            x: Input tensor, validated by :meth:`_check_input_dim`.
            domain_label: Index into the per-domain layers; must satisfy
                ``0 <= domain_label < num_domains``.

        Returns:
            The output of the selected layer.

        Raises:
            ValueError: If ``domain_label`` is negative or not smaller than
                ``num_domains``.
        """
        self._check_input_dim(x)
        # A negative label would otherwise be accepted by ModuleList as a
        # Python-style negative index and silently pick another domain's layer.
        if domain_label < 0:
            raise ValueError(f"Domain label {domain_label} must be non-negative")
        if domain_label >= self.num_domains:
            raise ValueError(
                f"Domain label {domain_label} exceeds the number of domains {self.num_domains}"
            )
        bn = self.bns[domain_label]
        # Only recorded once the label has been validated.
        self.cur_domain = domain_label
        return bn(x)
# [docs]
class DomainSpecificBatchNorm1d(_DomainSpecificBatchNorm):
    """Domain-specific batch norm whose per-domain layers are ``nn.BatchNorm1d``."""

    @property
    def bn_handle(self) -> nn.Module:
        """Return the ``nn.BatchNorm1d`` class used to build each layer."""
        return nn.BatchNorm1d

    def _check_input_dim(self, input: torch.Tensor):
        """Raise ``ValueError`` when ``input`` has more than three dimensions."""
        ndim = input.dim()
        if ndim <= 3:
            return
        raise ValueError("expected at most 3D input (got {}D input)".format(ndim))
# [docs]
class DomainSpecificBatchNorm2d(_DomainSpecificBatchNorm):
    """Domain-specific batch norm whose per-domain layers are ``nn.BatchNorm2d``."""

    @property
    def bn_handle(self) -> nn.Module:
        """Return the ``nn.BatchNorm2d`` class used to build each layer."""
        return nn.BatchNorm2d

    def _check_input_dim(self, input: torch.Tensor):
        """Raise ``ValueError`` unless ``input`` has exactly four dimensions."""
        ndim = input.dim()
        if ndim == 4:
            return
        raise ValueError("expected 4D input (got {}D input)".format(ndim))