import torch
import torch.nn.functional as F
[文档]
def masked_mse_loss(
    input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """
    Compute the masked MSE loss between input and target.

    The mask is cast to float and multiplied into both ``input`` and
    ``target`` before the squared error is summed, so positions where the
    mask is zero contribute nothing. The sum is then divided by ``mask.sum()``.

    Args:
        input: Predicted values.
        target: Reference values; must be broadcastable with ``input``.
        mask: Selects the positions that count towards the loss
            (presumably a boolean / 0-1 tensor shaped like ``input`` --
            confirm against callers).

    Returns:
        A scalar tensor holding the summed squared error divided by the mask
        sum, or ``0`` when the mask sums to zero.
    """
    mask = mask.float()
    loss = F.mse_loss(input * mask, target * mask, reduction="sum")
    denom = mask.sum()
    # An all-zero mask would otherwise give 0 / 0 = NaN and poison the whole
    # training step; replacing only an exactly-zero denominator with 1 leaves
    # every other case untouched and yields a loss of 0 with zero gradients.
    denom = torch.where(denom == 0, torch.ones_like(denom), denom)
    return loss / denom
[文档]
def criterion_neg_log_bernoulli(
    input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """
    Compute the masked negative log-likelihood of a Bernoulli distribution.

    ``input`` is used as the success probability of a Bernoulli
    distribution, and ``target`` is binarized with ``target > 0`` before its
    log-probability is evaluated. The per-element log-probabilities are
    multiplied by the float mask, summed, negated and divided by
    ``mask.sum()``.

    Args:
        input: Bernoulli probabilities (passed as ``probs=``).
        target: Observed values; any entry greater than zero counts as a
            positive outcome.
        mask: Selects the positions that count towards the loss
            (presumably a boolean / 0-1 tensor shaped like ``input`` --
            confirm against callers).

    Returns:
        A scalar tensor holding the masked negative log-likelihood divided by
        the mask sum, or ``0`` when the mask sums to zero.
    """
    mask = mask.float()
    bernoulli = torch.distributions.Bernoulli(probs=input)
    binary_target = (target > 0).float()
    masked_log_probs = bernoulli.log_prob(binary_target) * mask
    denom = mask.sum()
    # Guard the empty-mask case: 0 / 0 would be NaN. Only an exactly-zero
    # denominator is replaced, so all other inputs give the same result.
    denom = torch.where(denom == 0, torch.ones_like(denom), denom)
    return -masked_log_probs.sum() / denom
[文档]
def masked_relative_error(
    input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """
    Compute the masked relative error between input and target.

    Only the positions selected by ``mask`` are compared; for each of them the
    error is ``|input - target| / (|target| + 1e-6)`` and the mean over the
    selected positions is returned.

    Args:
        input: Predicted values.
        target: Reference values, indexed with the same mask as ``input``.
        mask: Selects the positions to evaluate. Any dtype is accepted;
            non-zero entries are treated as selected.

    Returns:
        A scalar tensor with the mean relative error over selected positions.

    Raises:
        AssertionError: If the mask selects no element.
    """
    # Indexing with an integer tensor gathers by index rather than masking,
    # so a 0/1 integer mask would silently pick elements 0 and 1. Casting to
    # bool gives real mask semantics and is a no-op for boolean masks.
    mask = mask.bool()
    # Raised explicitly (not via `assert`) so the check survives `python -O`;
    # without it an empty selection would silently produce a NaN mean.
    if not mask.any():
        raise AssertionError("mask must select at least one element")
    selected_input = input[mask]
    selected_target = target[mask]
    # |target| keeps the ratio non-negative for negative targets; the small
    # constant avoids division by zero when a target is exactly zero.
    loss = torch.abs(selected_input - selected_target) / (
        selected_target.abs() + 1e-6
    )
    return loss.mean()