Spaces:
Runtime error
Runtime error
| """Define all losses. When possible, as inheriting from nn.Module | |
| To send predictions to target.device | |
| """ | |
| from random import random as rand | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import models | |
class GANLoss(nn.Module):
    """Adversarial loss (LSGAN or regular GAN) that builds its own target tensors."""

    def __init__(
        self,
        use_lsgan=True,
        target_real_label=1.0,
        target_fake_label=0.0,
        soft_shift=0.0,
        flip_prob=0.0,
        verbose=0,
    ):
        """Defines the GAN loss which uses either LSGAN or the regular GAN.

        When LSGAN is used, it is basically the same as MSELoss, but it
        abstracts away the need to create the target label tensor that has
        the same size as the input, and adds:

        * label smoothing: e.g. target_real_label=0.75, or soft_shift > 0
        * label flipping: flip_prob > 0.

        source: https://github.com/sangwoomo/instagan/blob
        /b67e9008fcdd6c41652f8805f0b36bcaa8b632d6/models/networks.py

        Args:
            use_lsgan (bool, optional): Use MSE (True) or BCE with logits
                (False). Defaults to True.
            target_real_label (float, optional): Value for the real target.
                Defaults to 1.0.
            target_fake_label (float, optional): Value for the fake target.
                Defaults to 0.0.
            soft_shift (float, optional): Upper bound of the uniform noise
                subtracted from the real label / added to the fake label.
                Defaults to 0.0.
            flip_prob (float, optional): Probability of flipping the label
                (use for real target in Discriminator only). Defaults to 0.0.
            verbose (int, optional): When > 0, print every sampled soft
                shift. Defaults to 0.
        """
        super().__init__()
        self.soft_shift = soft_shift
        self.verbose = verbose
        self.register_buffer("real_label", torch.tensor(target_real_label))
        self.register_buffer("fake_label", torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCEWithLogitsLoss()
        self.flip_prob = flip_prob

    def get_target_tensor(self, input, target_is_real):
        """Return a (possibly softened) label expanded to the shape of ``input``.

        Args:
            input (torch.Tensor): Prediction whose shape the target must match.
            target_is_real (bool): Select the real label (True) or the fake
                label (False).

        Returns:
            torch.Tensor: Expanded view of the label, on the label buffer's
            device (the caller moves it to the prediction's device).
        """
        label = self.real_label if target_is_real else self.fake_label
        # Sample the shift on the label's own device so that the arithmetic
        # below still works after the module has been moved with .to()/.cuda().
        soft_change = torch.empty(
            1, dtype=torch.float32, device=label.device
        ).uniform_(0, self.soft_shift)
        if self.verbose > 0:
            print("GANLoss sampled soft_change:", soft_change.item())
        if target_is_real:
            target_tensor = label - soft_change
        else:
            target_tensor = label + soft_change
        return target_tensor.expand_as(input)

    def __call__(self, input, target_is_real, *args, **kwargs):
        """Compute the GAN loss of a prediction or of a list of predictions.

        Args:
            input (torch.Tensor | list): A prediction, or a list of
                predictions. A list element that is itself a list is
                replaced by its last item.
            target_is_real (bool): Whether the predictions should be scored
                against the real label.

        Returns:
            The loss for a single tensor, or the average of the
            per-element losses for a list.
        """
        # Decide the flip ONCE per call: toggling inside the loop would
        # alternate the target between successive predictions.
        if rand() < self.flip_prob:
            target_is_real = not target_is_real

        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    pred_i = pred_i[-1]
                target_tensor = self.get_target_tensor(pred_i, target_is_real)
                loss += self.loss(pred_i, target_tensor.to(pred_i.device))
            return loss / len(input)

        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor.to(input.device))
class FeatMatchLoss(nn.Module):
    """L1 feature-matching loss between discriminator features of real and fake."""

    def __init__(self):
        super().__init__()
        self.criterionFeat = nn.L1Loss()

    def __call__(self, pred_real, pred_fake):
        """Accumulate the L1 distance between matching intermediate outputs.

        Args:
            pred_real: One sequence of outputs per discriminator (real input).
            pred_fake: One sequence of outputs per discriminator (fake input).

        Returns:
            Sum over discriminators and layers of the L1 loss, each term
            divided by the number of discriminators (0.0 when there is
            nothing to compare).
        """
        n_discriminators = len(pred_fake)
        total = 0.0
        for d_idx, fake_outputs in enumerate(pred_fake):
            # The last entry is the final prediction, not a feature map.
            for layer_idx in range(len(fake_outputs) - 1):
                layer_loss = self.criterionFeat(
                    fake_outputs[layer_idx],
                    pred_real[d_idx][layer_idx].detach(),  # no grad through real
                )
                total += layer_loss / n_discriminators
        return total
class CrossEntropy(nn.Module):
    """Cross-entropy loss that aligns the target's device and dtype first."""

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, target):
        """Return the cross-entropy of ``logits`` against ``target``.

        The target is moved to the logits' device and cast to ``long``
        before being handed to ``nn.CrossEntropyLoss``.
        """
        labels = target.to(logits.device).long()
        return self.loss(logits, labels)
class TravelLoss(nn.Module):
    """Cosine similarity between pairwise difference vectors of two sets."""

    def __init__(self, eps=1e-12):
        """
        Args:
            eps (float, optional): Lower bound applied to the L2 norms to
                avoid dividing by zero. Defaults to 1e-12.
        """
        super().__init__()
        self.eps = eps

    def cosine_loss(self, real, fake):
        """Sum the row-wise cosine similarities of two (N, D) matrices.

        Args:
            real (torch.Tensor): 2-D tensor, one vector per row.
            fake (torch.Tensor): 2-D tensor with the same shape as ``real``.

        Returns:
            torch.Tensor: 0-D tensor, sum over rows of cos(fake_i, real_i).
        """
        # eps guards the division; it must clamp the norm, not the
        # normalised components (which would erase every negative value).
        norm_real = torch.norm(real, p=2, dim=1)[:, None].clamp_min(self.eps)
        norm_fake = torch.norm(fake, p=2, dim=1)[:, None].clamp_min(self.eps)
        mat_real = real / norm_real
        mat_fake = fake / norm_fake
        # Row-wise dot product, i.e. only the diagonal of fake @ real.T.
        return torch.einsum("ij,ij->i", mat_fake, mat_real).sum()

    def __call__(self, S_real, S_fake):
        """Compare all pairwise differences S[i] - S[j] (j < i) of both sets.

        The difference vectors are also kept on the instance as
        ``v_real`` / ``v_fake`` (lists) and ``v_real_t`` / ``v_fake_t``
        (concatenated tensors).

        Args:
            S_real: Indexable collection of vectors.
            S_fake: Indexable collection of vectors, same length as S_real.
                NOTE(review): each element is assumed to be 1-D so the
                stacked differences are (N, D) -- confirm against callers.

        Returns:
            torch.Tensor: Result of ``cosine_loss`` on the stacked differences.
        """
        pairs = [(i, j) for i in range(len(S_real)) for j in range(i)]
        self.v_real = [(S_real[i] - S_real[j])[None, :] for i, j in pairs]
        self.v_fake = [(S_fake[i] - S_fake[j])[None, :] for i, j in pairs]
        self.v_real_t = torch.cat(self.v_real, dim=0)
        self.v_fake_t = torch.cat(self.v_fake, dim=0)
        return self.cosine_loss(self.v_real_t, self.v_fake_t)
class TVLoss(nn.Module):
    """Total Variational Regularization: penalizes differences between
    neighboring pixel values.

    source:
    https://github.com/jxgu1016/Total_Variation_Loss.pytorch/blob/master/TVLoss.py
    """

    def __init__(self, tvloss_weight=1):
        """
        Args:
            tvloss_weight (int, optional): Multiplicative weight (lambda)
                applied to the loss. Defaults to 1.
        """
        super(TVLoss, self).__init__()
        self.tvloss_weight = tvloss_weight

    def forward(self, x):
        """Return the weighted, size-normalised squared total variation of ``x``.

        Dimension 2 and dimension 3 of ``x`` are treated as the two spatial
        axes, dimension 0 as the batch.
        """
        n_samples = x.size(0)
        shifted_rows = x[:, :, 1:, :]
        shifted_cols = x[:, :, :, 1:]
        # Squared differences between each element and its previous neighbour.
        rows_tv = ((shifted_rows - x[:, :, :-1, :]) ** 2).sum()
        cols_tv = ((shifted_cols - x[:, :, :, :-1]) ** 2).sum()
        normalised = rows_tv / self._tensor_size(shifted_rows) + (
            cols_tv / self._tensor_size(shifted_cols)
        )
        return self.tvloss_weight * 2 * normalised / n_samples

    def _tensor_size(self, t):
        """Number of elements of one sample: product of dimensions 1, 2 and 3."""
        return t.size(1) * t.size(2) * t.size(3)
class MinentLoss(nn.Module):
    """
    Loss for the minimization of the entropy map
    Source for version 1: https://github.com/valeoai/ADVENT
    Version 2 adds a squared-deviation (variance-like) term of the entropy
    map, weighted by ``lambda_var``.
    """

    def __init__(self, version=1, lambda_var=0.1):
        super().__init__()
        self.version = version
        self.lambda_var = lambda_var

    def __call__(self, pred):
        """Return the entropy loss of a 4-D prediction of size (n, c, h, w).

        The entropy map is ``-pred * log2(pred + 1e-30) / log2(c)`` and all
        sums are divided by ``n * h * w``.
        """
        assert pred.dim() == 4
        n, c, h, w = pred.size()
        n_positions = n * h * w
        entropy = -(pred * torch.log2(pred + 1e-30)) / np.log2(c)

        if self.version == 1:
            return torch.sum(entropy) / n_positions

        # Any other version: add the weighted squared deviation from the
        # (sum / n*h*w) reference value.
        centered = entropy - torch.sum(entropy) / n_positions
        squared_dev = centered * centered
        return torch.sum(entropy + self.lambda_var * squared_dev) / n_positions
class MSELoss(nn.Module):
    """
    Criterion measuring the mean squared error (squared L2 norm) between
    each element of the prediction and of the target.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.MSELoss()

    def __call__(self, prediction, target):
        """Apply ``self.loss`` after moving the target to the prediction's device."""
        target_on_device = target.to(prediction.device)
        return self.loss(prediction, target_on_device)
class L1Loss(MSELoss):
    """
    Criterion measuring the mean absolute error (MAE) between each element
    of the prediction and of the target.

    Inherits the call logic of ``MSELoss`` and only swaps the criterion.
    """

    def __init__(self):
        super(L1Loss, self).__init__()
        # Replace the criterion installed by the parent constructor.
        self.loss = nn.L1Loss()
class SIMSELoss(nn.Module):
    """Scale invariant MSE Loss"""

    def __init__(self):
        super().__init__()

    def __call__(self, prediction, target):
        """Return mean(d^2) - mean(d)^2 with d = prediction - target."""
        residual = prediction - target
        mean_residual = torch.mean(residual)
        mean_square = torch.mean(residual * residual)
        return mean_square - mean_residual * mean_residual
class SIGMLoss(nn.Module):
    """loss from MiDaS paper

    MiDaS did not specify how the gradients were computed but we use Sobel
    filters which approximate the derivative of an image.
    """

    # Lower bound for the scale estimates, so that a constant map does not
    # produce a division by zero (NaN loss).
    _MIN_SCALE = 1e-12

    def __init__(self, gmweight=0.5, scale=4, device="cuda"):
        """
        Args:
            gmweight (float, optional): Weight of the gradient matching
                term. Defaults to 0.5.
            scale (int, optional): Number of resolutions (1, 1/2, 1/4, ...)
                at which gradients are matched. Defaults to 4.
            device (optional): Preferred device for the Sobel kernels.
                When a CUDA device is requested but CUDA is unavailable the
                kernels stay on CPU; they are moved to the prediction's
                device at call time in any case. Defaults to "cuda".
        """
        super(SIGMLoss, self).__init__()
        self.gmweight = gmweight
        if (
            device is not None
            and torch.device(device).type == "cuda"
            and not torch.cuda.is_available()
        ):
            device = "cpu"
        self.sobelx = torch.tensor(
            [[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32
        ).to(device)
        self.sobely = torch.tensor(
            [[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32
        ).to(device)
        self.scale = scale

    def __call__(self, prediction, target):
        """Return the residual term plus the weighted gradient matching term.

        Args:
            prediction (torch.Tensor): Predicted map. The convolution below
                uses single-input-channel kernels.
                NOTE(review): assumed to be (batch, 1, H, W) -- confirm.
            target (torch.Tensor): Ground truth, same shape as prediction.

        Returns:
            torch.Tensor: 0-D loss.
        """
        target = target.to(prediction.device)

        # Align both the prediction and the ground truth to have zero
        # translation (median) and unit scale (mean absolute deviation).
        # Both statistics are taken over the whole tensor.
        t_pred = torch.median(prediction)
        t_targ = torch.median(target)
        s_pred = torch.mean(torch.abs(prediction - t_pred)).clamp_min(self._MIN_SCALE)
        s_targ = torch.mean(torch.abs(target - t_targ)).clamp_min(self._MIN_SCALE)
        pred = (prediction - t_pred) / s_pred
        targ = (target - t_targ) / s_targ
        R = pred - targ

        # Gradient maps with Sobel filters, on the residual's device/dtype.
        batch_size = prediction.size()[0]
        num_pix = prediction.size()[-1] * prediction.size()[-2]
        # NOTE(review): expanding to `batch_size` output channels replicates
        # the same filter response batch_size times, so the summed gradient
        # term grows with the batch size. Kept as-is to preserve loss scale.
        sobelx = self.sobelx.to(device=R.device, dtype=R.dtype).expand(
            (batch_size, 1, -1, -1)
        )
        sobely = self.sobely.to(device=R.device, dtype=R.dtype).expand(
            (batch_size, 1, -1, -1)
        )

        gmLoss = 0  # gradient matching term, accumulated over resolutions
        for k in range(self.scale):
            R_ = F.interpolate(R, scale_factor=1 / 2 ** k)
            Rx = F.conv2d(R_, sobelx, stride=1)
            Ry = F.conv2d(R_, sobely, stride=1)
            gmLoss += torch.sum(torch.abs(Rx) + torch.abs(Ry))
        gmLoss = self.gmweight / num_pix * gmLoss

        # Absolute residual term, normalised by the number of pixels.
        simseLoss = 0.5 / num_pix * torch.sum(torch.abs(R))
        loss = simseLoss + gmLoss
        return loss
class ContextLoss(nn.Module):
    """
    Masked L1 loss on non-water
    """

    def __call__(self, input, target, mask):
        """Mean absolute error weighted by ``1 - mask`` (mean over all elements)."""
        outside_mask = 1 - mask
        weighted_error = (input - target) * outside_mask
        return weighted_error.abs().mean()
class ReconstructionLoss(nn.Module):
    """
    Masked L1 loss on water
    """

    def __call__(self, input, target, mask):
        """Mean absolute error weighted by ``mask`` (mean over all elements)."""
        weighted_error = (input - target) * mask
        return weighted_error.abs().mean()
| ################################################################################## | |
| # VGG network definition | |
| ################################################################################## | |
| # Source: https://github.com/NVIDIA/pix2pixHD | |
class Vgg19(nn.Module):
    """Pretrained VGG19 feature extractor split into five consecutive stages."""

    def __init__(self, requires_grad=False):
        """
        Args:
            requires_grad (bool, optional): When False, every parameter of
                the network is frozen. Defaults to False.
        """
        super(Vgg19, self).__init__()
        backbone = models.vgg19(pretrained=True).features

        # (attribute name, first layer index, one-past-last layer index)
        stages = (
            ("slice1", 0, 2),
            ("slice2", 2, 7),
            ("slice3", 7, 12),
            ("slice4", 12, 21),
            ("slice5", 21, 30),
        )
        for name, _, _ in stages:
            setattr(self, name, nn.Sequential())
        for name, start, stop in stages:
            stage = getattr(self, name)
            for layer_idx in range(start, stop):
                # Sub-modules keep their index in the original network as name.
                stage.add_module(str(layer_idx), backbone[layer_idx])

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the list of the five stage outputs, in order."""
        outputs = []
        hidden = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            hidden = stage(hidden)
            outputs.append(hidden)
        return outputs
| # Source: https://github.com/NVIDIA/pix2pixHD | |
class VGGLoss(nn.Module):
    """Weighted L1 distance between the VGG19 features of two inputs."""

    def __init__(self, device):
        """
        Args:
            device: Device on which the (eval-mode) VGG19 network is placed.
        """
        super().__init__()
        self.vgg = Vgg19().to(device).eval()
        self.criterion = nn.L1Loss()
        # One weight per VGG stage output.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return sum_i weights[i] * L1(vgg(x)[i], vgg(y)[i]); y's features are detached."""
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        loss = 0
        for weight, feat_x, feat_y in zip(self.weights, feats_x, feats_y):
            loss += weight * self.criterion(feat_x, feat_y.detach())
        return loss
def get_losses(opts, verbose, device=None):
    """Build the loss functions used by G and D as specified in the opts.

    Only the entries of the tasks listed in ``opts.tasks`` are created:

    losses = {
        "G": {
            "a": {},
            "p": {"gan", "dm", "vgg", "tv", "context", "reconstruction",
                  "featmatch"},
            "tasks": {
                "d": depth loss,
                "s": {"crossent", "minent", "advent"},
                "m": {"bce", "minent", "tv", "advent", "gi"},
            },
        },
        "D": {"default": {}, "advent": ..., "p": same object as G/p/gan},
        "C": {},
    }

    Args:
        opts: Options object; read through attribute access.
        verbose: Not used by this function.
        device (optional): Device handed to ``VGGLoss``. Defaults to None.

    Returns:
        dict: The nested dictionary of losses described above.
    """
    gen_losses = {"a": {}, "p": {}, "tasks": {}}
    dis_losses = {"default": {}, "advent": {}}
    losses = {"G": gen_losses, "D": dis_losses, "C": {}}

    # ------------------------------
    # -----  Generator Losses  -----
    # ------------------------------

    # painter losses
    if "p" in opts.tasks:
        if opts.gen.p.loss == "hinge":
            painter_gan = HingeLoss()
        else:
            painter_gan = GANLoss(
                use_lsgan=False,
                soft_shift=opts.dis.soft_shift,
                flip_prob=opts.dis.flip_prob,
            )
        painter = gen_losses["p"]
        painter["gan"] = painter_gan
        painter["dm"] = MSELoss()
        painter["vgg"] = VGGLoss(device)
        painter["tv"] = TVLoss()
        painter["context"] = ContextLoss()
        painter["reconstruction"] = ReconstructionLoss()
        painter["featmatch"] = FeatMatchLoss()

    # depth losses
    if "d" in opts.tasks:
        if opts.gen.d.classify.enable:
            depth_func = CrossEntropy()
        elif opts.gen.d.loss == "dada":
            depth_func = DADADepthLoss()
        else:
            depth_func = SIGMLoss(opts.train.lambdas.G.d.gml)
        gen_losses["tasks"]["d"] = depth_func

    # segmentation losses
    if "s" in opts.tasks:
        segmenter = {}
        gen_losses["tasks"]["s"] = segmenter
        segmenter["crossent"] = CrossEntropy()
        segmenter["minent"] = MinentLoss()
        segmenter["advent"] = ADVENTAdversarialLoss(
            opts, gan_type=opts.dis.s.gan_type
        )

    # masker losses
    if "m" in opts.tasks:
        masker = {}
        gen_losses["tasks"]["m"] = masker
        masker["bce"] = nn.BCEWithLogitsLoss()
        if opts.gen.m.use_minent_var:
            masker["minent"] = MinentLoss(
                version=2, lambda_var=opts.train.lambdas.advent.ent_var
            )
        else:
            masker["minent"] = MinentLoss()
        masker["tv"] = TVLoss()
        masker["advent"] = ADVENTAdversarialLoss(opts, gan_type=opts.dis.m.gan_type)
        masker["gi"] = GroundIntersectionLoss()

    # ----------------------------------
    # -----  Discriminator Losses  -----
    # ----------------------------------
    if "p" in opts.tasks:
        # The painter's discriminator shares the generator's GAN criterion.
        dis_losses["p"] = gen_losses["p"]["gan"]
    if "m" in opts.tasks or "s" in opts.tasks:
        dis_losses["advent"] = ADVENTAdversarialLoss(opts)
    return losses
class GroundIntersectionLoss(nn.Module):
    """
    Penalize areas in ground seg but not in flood mask
    """

    def __call__(self, pred, pseudo_ground):
        """Return the fraction of elements where ``pseudo_ground - pred > 0.5``."""
        exceeds = (pseudo_ground - pred) > 0.5
        # Multiplying by 1.0 turns the boolean map into floats before averaging.
        return (1.0 * exceeds).mean()
def prob_2_entropy(prob):
    """
    Convert probabilistic prediction maps to weighted self-information maps.

    ``prob`` must be 4-D, of size (n, c, h, w); the result is
    ``-prob * log2(prob + 1e-30) / log2(c)``, element-wise.
    """
    _, c, _, _ = prob.size()
    self_information = prob * torch.log2(prob + 1e-30)
    return -self_information / np.log2(c)
class CustomBCELoss(nn.Module):
    """
    The first argument is a tensor and the second argument is an int.
    There is no need to take sigmoid before calling this function.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss()

    def __call__(self, prediction, target):
        """Return BCE-with-logits of ``prediction`` against a constant target.

        Args:
            prediction (torch.Tensor): Raw (pre-sigmoid) predictions.
            target (int | float): Value every element is compared against.

        Returns:
            torch.Tensor: 0-D loss.
        """
        # Build the target next to the prediction: `Tensor.get_device()`
        # returns -1 for CPU tensors, which is not a valid `.to()` argument.
        target_tensor = torch.empty_like(prediction).fill_(target)
        return self.loss(prediction, target_tensor)
class ADVENTAdversarialLoss(nn.Module):
    """
    The class is for calculating the advent loss.
    It is used to indirectly shrink the domain gap between sim and real

    _call_ function:
    prediction: torch.tensor with shape of [bs,c,h,w]
    target: int; domain label: 0 (sim) or 1 (real)
    discriminator: the discriminator model tells if a tensor is from sim or real

    output: the loss value of GANLoss
    """

    # gan_type values that all map to the Wasserstein-style criterion
    _WGAN_TYPES = ("WGAN", "WGAN_gp", "WGAN_norm")

    def __init__(self, opts, gan_type="GAN"):
        """
        Args:
            opts: Options object, read at call time to detect the
                OmniDiscriminator architecture.
            gan_type (str, optional): "GAN" (BCE with logits) or one of
                "WGAN", "WGAN_gp", "WGAN_norm". Defaults to "GAN".

        Raises:
            NotImplementedError: If ``gan_type`` is none of the above.
        """
        super().__init__()
        self.opts = opts
        if gan_type == "GAN":
            self.loss = CustomBCELoss()
        elif gan_type in self._WGAN_TYPES:
            # A membership test is required here: `x == "A" or "B"` is
            # always truthy and would accept any gan_type.
            self.loss = self._wgan_loss
        else:
            raise NotImplementedError(f"Unsupported gan_type: {gan_type!r}")

    @staticmethod
    def _wgan_loss(x, y):
        """Return -mean(y * x + (1 - y) * (1 - x)) for output x and label y."""
        return -torch.mean(y * x + (1 - y) * (1 - x))

    def __call__(self, prediction, target, discriminator, depth_preds=None):
        """
        Compute the GAN loss from the Advent Discriminator given
        normalized (softmaxed) predictions (=pixel-wise class probabilities),
        and a domain label (target).

        Args:
            prediction (torch.Tensor): pixel-wise probability distribution
                over classes, of size (n, c, h, w)
            target: domain label forwarded unchanged to the criterion
            discriminator (torch.nn.Module): Discriminator to get the loss
            depth_preds (torch.Tensor, optional): When given, the entropy
                map is multiplied by it before going through the
                discriminator. Defaults to None.

        Returns:
            torch.Tensor: float 0-D loss
        """
        d_out = prob_2_entropy(prediction)
        if depth_preds is not None:
            d_out = d_out * depth_preds
        d_out = discriminator(d_out)
        # NOTE(review): the masker's ("m") architecture option is consulted
        # whichever task this loss is used for -- confirm this is intended.
        if self.opts.dis.m.architecture == "OmniDiscriminator":
            d_out = multiDiscriminatorAdapter(d_out, self.opts)
        loss_ = self.loss(d_out, target)
        return loss_
def multiDiscriminatorAdapter(d_out: list, opts: dict) -> torch.tensor:
    """
    Unwrap the output of an OmniDiscriminator, which returns a list of
    tensors rather than a tensor.

    Since there is no multilevel masker, only a one-element list is
    supported: its single item is returned when
    ``opts.dis.p.get_intermediate_features`` is set, otherwise the first
    entry of that item is returned.

    Raises:
        Exception: If ``d_out`` is not a list of length 1.
    """
    is_single_scale = isinstance(d_out, list) and len(d_out) == 1
    if not is_single_scale:
        raise Exception(
            "Check the setting of OmniDiscriminator! "
            + "For now, we don't support multi-scale OmniDiscriminator."
        )

    # adapt the multi-scale OmniDiscriminator
    only_scale = d_out[0]
    if opts.dis.p.get_intermediate_features:
        return only_scale
    return only_scale[0]
class HingeLoss(nn.Module):
    """
    Adapted from https://github.com/NVlabs/SPADE/blob/master/models/networks/loss.py
    for the painter
    """

    def __init__(self, tensor=torch.FloatTensor):
        """
        Args:
            tensor (optional): Constructor used to create the cached zero
                tensor. Defaults to torch.FloatTensor.
        """
        super().__init__()
        self.zero_tensor = None
        self.Tensor = tensor

    def get_zero_tensor(self, input):
        """Return zeros expanded to ``input``'s shape, on ``input``'s device.

        The one-element zero tensor is created on first use and cached.
        """
        if self.zero_tensor is None:
            zeros = self.Tensor(1).fill_(0)
            zeros.requires_grad_(False)
            self.zero_tensor = zeros
        self.zero_tensor = self.zero_tensor.to(input.device)
        return self.zero_tensor.expand_as(input)

    def loss(self, input, target_is_real, for_discriminator=True):
        """Hinge loss of one prediction tensor.

        Discriminator: ``-mean(min(s - 1, 0))`` with ``s = input`` for a
        real target and ``s = -input`` for a fake one.
        Generator: ``-mean(input)``; the target must be real.
        """
        if not for_discriminator:
            assert target_is_real, "The generator's hinge loss must be aiming for real"
            return -torch.mean(input)

        signed = input if target_is_real else -input
        minval = torch.min(signed - 1, self.get_zero_tensor(input))
        return -torch.mean(minval)

    def __call__(self, input, target_is_real, for_discriminator=True):
        """Compute the hinge loss of a tensor, or the average over a list.

        |input| may not be a tensor but a list of tensors in the case of a
        multiscale discriminator; a list element that is itself a list is
        replaced by its last item.
        """
        if not isinstance(input, list):
            return self.loss(input, target_is_real, for_discriminator)

        final_preds = (
            pred[-1] if isinstance(pred, list) else pred for pred in input
        )
        total = sum(
            self.loss(pred, target_is_real, for_discriminator)
            for pred in final_preds
        )
        return total / len(input)
class DADADepthLoss:
    """Defines the reverse Huber loss from DADA paper for depth prediction
    - Samples with larger residuals are penalized more by l2 term
    - Samples with smaller residuals are penalized more by l1 term
    From https://github.com/valeoai/DADA/blob/master/dada/utils/func.py
    """

    def loss_calc_depth(self, pred, label):
        """Return the reverse Huber (berHu) loss averaged over all elements.

        The threshold is 20% of the largest absolute residual of the call:
        residuals at or below it contribute ``|r|``, residuals above it
        contribute ``(r^2 + t^2) / (2 * t)``.

        Args:
            pred (torch.Tensor): Prediction of size (n, 1, h, w).
            label (torch.Tensor): Ground truth; compared to ``pred`` after
                both have been squeezed.

        Returns:
            torch.Tensor: 0-D loss.
        """
        n, c, h, w = pred.size()
        assert c == 1

        pred = pred.squeeze()
        label = label.squeeze()

        adiff = torch.abs(pred - label)
        # .item() makes the threshold a plain constant (no gradient).
        batch_max = 0.2 * torch.max(adiff).item()
        if batch_max == 0:
            # Every residual is zero: only the L1 branch applies. Returning
            # here avoids the 0 / 0 of the L2 branch, which made the loss NaN.
            return torch.sum(adiff) / pred.numel()

        t1_mask = adiff.le(batch_max).float()
        t2_mask = adiff.gt(batch_max).float()
        t1 = adiff * t1_mask
        t2 = (adiff * adiff + batch_max * batch_max) / (2 * batch_max)
        t2 = t2 * t2_mask
        return (torch.sum(t1) + torch.sum(t2)) / pred.numel()

    def __call__(self, pred, label):
        """Shortcut for ``loss_calc_depth(pred, label)``."""
        return self.loss_calc_depth(pred, label)