scgpt.loss 源代码

import torch
import torch.nn.functional as F


def masked_mse_loss(
    input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """
    Compute the masked MSE loss between input and target.

    ``input`` and ``target`` are both multiplied by ``mask`` before the summed
    squared error is taken. That sum is then divided by the sum of the mask,
    so a 0/1 mask yields the mean squared error over the selected positions.

    Args:
        input: Predicted values.
        target: Ground-truth values (presumably the same shape as ``input``).
        mask: Boolean or numeric mask; it is cast to float before use.

    Returns:
        A 0-dim tensor holding the loss. If the mask selects nothing, the loss
        is 0 instead of NaN.
    """
    mask = mask.float()
    loss = F.mse_loss(input * mask, target * mask, reduction="sum")
    num_masked = mask.sum()
    # An all-zero mask would give 0 / 0 = NaN. In that case every masked term,
    # and hence ``loss``, is 0 (for finite inputs), so dividing by 1 instead
    # gives a well-defined loss of 0 with zero gradients.
    num_masked = num_masked.masked_fill(num_masked == 0, 1.0)
    return loss / num_masked
def criterion_neg_log_bernoulli(
    input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """
    Compute the negative log-likelihood of a Bernoulli distribution.

    ``input`` is used as the Bernoulli ``probs`` (so it must lie in [0, 1]).
    The observed outcome is ``target > 0``. The per-element log-probabilities
    are weighted by ``mask``, summed, negated, and divided by the mask sum.

    Args:
        input: Bernoulli success probabilities.
        target: Values whose positivity defines the observed outcome.
        mask: Boolean or numeric mask; it is cast to float before use.

    Returns:
        A 0-dim tensor holding the mean masked negative log-likelihood. If the
        mask selects nothing, the result is 0 instead of NaN.
    """
    mask = mask.float()
    bernoulli = torch.distributions.Bernoulli(probs=input)
    masked_log_probs = bernoulli.log_prob((target > 0).float()) * mask
    num_masked = mask.sum()
    # Avoid 0 / 0 = NaN for an all-zero mask. The masked sum is 0 there
    # (log_prob is finite for probs in [0, 1]), so dividing by 1 gives 0.
    num_masked = num_masked.masked_fill(num_masked == 0, 1.0)
    return -masked_log_probs.sum() / num_masked
def masked_relative_error(
    input: torch.Tensor, target: torch.Tensor, mask: torch.LongTensor
) -> torch.Tensor:
    """
    Compute the masked relative error between input and target.

    The per-position error is ``|input - target| / (|target| + 1e-6)``. The
    result is the mean of that error over the positions selected by ``mask``.

    Args:
        input: Predicted values.
        target: Ground-truth values, indexed with the same mask as ``input``.
        mask: Mask selecting the positions to evaluate. Any non-zero / True
            entry selects a position, so both bool masks and integer 0/1
            masks work.

    Returns:
        A 0-dim tensor with the mean relative error.

    Raises:
        AssertionError: If ``mask`` selects no elements.
    """
    # Indexing with an integer tensor performs gather (advanced) indexing,
    # not masking: x[torch.tensor([1, 0])] picks x[1] and x[0]. Converting to
    # bool makes a 0/1 integer mask select positions as intended; a bool mask
    # is unaffected.
    mask = mask.bool()
    if not mask.any():
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for backward compatibility.
        raise AssertionError("mask must select at least one element")
    selected_input = input[mask]
    selected_target = target[mask]
    # Taking |target| keeps the error non-negative for negative targets and
    # keeps the denominator >= 1e-6. It is identical for targets >= 0.
    loss = torch.abs(selected_input - selected_target) / (
        torch.abs(selected_target) + 1e-6
    )
    return loss.mean()