Source code for scgpt.model.grad_reverse
from typing import Tuple

import torch
from torch.autograd import Function
[docs]
class GradReverse(Function):
    """Autograd function implementing a gradient reversal layer.

    The forward pass is the identity on ``x``. The backward pass returns the
    incoming gradient multiplied by ``-lambd``. No gradient is produced for
    ``lambd``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, lambd: float) -> torch.Tensor:
        """Return ``x`` unchanged and store ``lambd`` for the backward pass.

        Args:
            ctx: Autograd context object supplied by ``Function.apply``.
            x: Input tensor.
            lambd: Scale applied to the reversed gradient in ``backward``.

        Returns:
            ``x.view_as(x)``: a view with the same shape and values as ``x``.
        """
        ctx.lambd = lambd
        # Returning a view rather than ``x`` itself is the usual idiom for an
        # identity forward in a custom autograd Function.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Reverse and scale the incoming gradient.

        Args:
            ctx: Autograd context holding ``lambd`` from ``forward``.
            grad_output: Gradient with respect to the forward output.

        Returns:
            A tuple with one entry per ``forward`` input: the gradient for
            ``x`` (``-grad_output * lambd``) and ``None`` for ``lambd``.
        """
        return grad_output.neg() * ctx.lambd, None
[docs]
def grad_reverse(x: torch.Tensor, lambd: float = 1.0) -> torch.Tensor:
    """Apply the gradient reversal function to ``x``.

    Args:
        x: Input tensor, passed through unchanged in the forward pass.
        lambd: Scale for the reversed gradient (defaults to ``1.0``).

    Returns:
        The output of ``GradReverse.apply(x, lambd)``.
    """
    # ``Function.apply`` takes its arguments positionally.
    reversed_output = GradReverse.apply(x, lambd)
    return reversed_output