# Copyright (c) OpenMMLab. All rights reserved.
# migrate from mmdetection with modifications
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmpretrain.registry import MODELS

from .utils import weight_reduce_loss


def seesaw_ce_loss(cls_score,
                   labels,
                   weight,
                   cum_samples,
                   num_classes,
                   p,
                   q,
                   eps,
                   reduction='mean',
                   avg_factor=None):
    """Calculate the Seesaw CrossEntropy loss.

    Args:
        cls_score (torch.Tensor): The prediction with shape (N, C),
            where C is the number of classes.
        labels (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor): Sample-wise loss weight.
        cum_samples (torch.Tensor): Cumulative samples for each category.
        num_classes (int): The number of classes.
        p (float): The ``p`` in the mitigation factor.
        q (float): The ``q`` in the compensation factor.
        eps (float): The minimal value of the divisor that smooths the
            computation of the compensation factor.
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    assert cls_score.size(-1) == num_classes
    assert len(cum_samples) == num_classes

    onehot_labels = F.one_hot(labels, num_classes)
    seesaw_weights = cls_score.new_ones(onehot_labels.size())
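    # The combined seesaw weight for sample i and negative class j is
    # S_{ij} = M_{ij} * C_{ij}; both factors below default to 1.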
    # mitigation factor: M_{ij} = (N_j / N_i)^p if class j has fewer
    # accumulated samples than class i, else 1; it reduces the negative
    # gradient that rare classes receive from samples of frequent classes
    if p > 0:
        sample_ratio_matrix = cum_samples[None, :].clamp(
            min=1) / cum_samples[:, None].clamp(min=1)
        index = (sample_ratio_matrix < 1.0).float()
        sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index
                                                               )  # M_{ij}
        mitigation_factor = sample_weights[labels.long(), :]
        seesaw_weights = seesaw_weights * mitigation_factor
    # compensation factor: C_{ij} = (sigma_j / sigma_i)^q if the predicted
    # probability of negative class j exceeds that of the true class i,
    # else 1; it re-amplifies the penalty on misclassified negatives
    if q > 0:
        scores = F.softmax(cls_score.detach(), dim=1)
        self_scores = scores[
            torch.arange(0, len(scores)).to(scores.device).long(),
            labels.long()]
        score_matrix = scores / self_scores[:, None].clamp(min=eps)
        index = (score_matrix > 1.0).float()
        compensation_factor = score_matrix.pow(q) * index + (1 - index)
        seesaw_weights = seesaw_weights * compensation_factor
    # adding log(S_{ij}) to the negative-class logits multiplies their
    # exponentials by S_{ij} inside the softmax of the cross entropy
    cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))

    loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')

    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
    return loss
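

# A minimal sketch of calling ``seesaw_ce_loss`` directly, with illustrative
# toy shapes and values (not part of the original module):
#
#     cls_score = torch.randn(4, 10)                      # (N=4, C=10) logits
#     labels = torch.tensor([0, 1, 1, 9])                 # ground-truth labels
#     cum_samples = torch.tensor([100.] * 5 + [10.] * 5)  # per-class counts
#     loss = seesaw_ce_loss(cls_score, labels, None, cum_samples,
#                           num_classes=10, p=0.8, q=2.0, eps=1e-2)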


@MODELS.register_module()
class SeesawLoss(nn.Module):
    """Implementation of seesaw loss.

    Refers to `Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021)
    <https://arxiv.org/abs/2008.10032>`_

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid or softmax.
            Only False is supported. Defaults to False.
        p (float): The ``p`` in the mitigation factor.
            Defaults to 0.8.
        q (float): The ``q`` in the compensation factor.
            Defaults to 2.0.
        num_classes (int): The number of classes.
            Defaults to 1000 for the ImageNet dataset.
        eps (float): The minimal value of the divisor that smooths the
            computation of the compensation factor. Defaults to 1e-2.
        reduction (str): The method that reduces the loss to a scalar.
            Options are "none", "mean" and "sum". Defaults to "mean".
        loss_weight (float): The weight of the loss. Defaults to 1.0.
    """
    def __init__(self,
                 use_sigmoid=False,
                 p=0.8,
                 q=2.0,
                 num_classes=1000,
                 eps=1e-2,
                 reduction='mean',
                 loss_weight=1.0):
        super(SeesawLoss, self).__init__()
        assert not use_sigmoid, '`use_sigmoid` is not supported'
        self.use_sigmoid = False
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.cls_criterion = seesaw_ce_loss

        # cumulative samples for each category
        self.register_buffer('cum_samples',
                             torch.zeros(self.num_classes, dtype=torch.float))

    def forward(self,
                cls_score,
                labels,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The prediction with shape (N, C).
            labels (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum". Defaults to None.

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            f'The `reduction_override` should be one of (None, "none", ' \
            f'"mean", "sum"), but got "{reduction_override}".'
        assert cls_score.size(0) == labels.view(-1).size(0), \
            f'Expected `labels` shape [{cls_score.size(0)}], ' \
            f'but got {list(labels.size())}'
        reduction = (
            reduction_override if reduction_override else self.reduction)
        assert cls_score.size(-1) == self.num_classes, \
            f'The channel number of output ({cls_score.size(-1)}) does ' \
            f'not match the `num_classes` of seesaw loss ({self.num_classes}).'

        # accumulate the samples seen so far for each category
        unique_labels = labels.unique()
        for u_l in unique_labels:
            inds_ = labels == u_l.item()
            self.cum_samples[u_l] += inds_.sum()

        if weight is not None:
            weight = weight.float()
        else:
            weight = labels.new_ones(labels.size(), dtype=torch.float)

        # calculate the weighted classification loss
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score, labels, weight, self.cum_samples, self.num_classes,
            self.p, self.q, self.eps, reduction, avg_factor)
        return loss_cls
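

# A minimal usage sketch of ``SeesawLoss`` (illustrative shapes; in OpenMMLab
# configs the loss is normally built through the ``MODELS`` registry):
#
#     loss_fn = SeesawLoss(num_classes=10)
#     cls_score = torch.randn(4, 10)
#     labels = torch.randint(0, 10, (4,))
#     loss = loss_fn(cls_score, labels)  # scalar tensor (reduction='mean')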