# Source code for scgpt.model.grad_reverse

from typing import Tuple

import torch
from torch.autograd import Function


class GradReverse(Function):
    """Custom autograd function that flips the sign of the gradient.

    The forward pass hands the input back unchanged (as a view of itself).
    The backward pass negates the incoming gradient and multiplies it by the
    ``lambd`` value that was passed to ``forward``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, lambd: float) -> torch.Tensor:
        """Return ``x`` unchanged and remember ``lambd`` for ``backward``.

        Args:
            ctx: Autograd context object; used here only to stash ``lambd``.
            x: Input tensor.
            lambd: Factor applied to the negated gradient in ``backward``.

        Returns:
            A view of ``x`` with the same shape and values.
        """
        ctx.lambd = lambd
        # NOTE(review): view_as(x) presumably exists so the output is a
        # distinct tensor object from the input -- confirm against the
        # torch.autograd.Function guidance.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Negate ``grad_output`` and scale it by the stored ``lambd``.

        Args:
            ctx: Autograd context object carrying ``lambd`` from ``forward``.
            grad_output: Gradient flowing back into this function's output.

        Returns:
            A 2-tuple with one entry per ``forward`` input: the gradient for
            ``x`` (``-grad_output * lambd``) and ``None`` for ``lambd``.
        """
        return grad_output.neg() * ctx.lambd, None
def grad_reverse(x: torch.Tensor, lambd: float = 1.0) -> torch.Tensor:
    """Run ``x`` through :class:`GradReverse` with the given ``lambd``.

    Args:
        x: Input tensor.
        lambd: Value forwarded to ``GradReverse.apply``; defaults to ``1.0``.

    Returns:
        Whatever ``GradReverse.apply(x, lambd)`` returns.
    """
    reversed_out = GradReverse.apply(x, lambd)
    return reversed_out