Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional | |
| import torch | |
| from mmengine.model import BaseModule | |
| from torch import nn | |
| from mmpretrain.registry import MODELS | |
class CosineSimilarityLoss(BaseModule):
    """Cosine similarity loss function.

    Compute the cosine similarity between two features along their last
    dimension and turn it into a loss of the form
    ``shift_factor - scale_factor * cos_sim(pred, target)``.

    Args:
        shift_factor (float): The shift factor of cosine similarity.
            Default: 0.0.
        scale_factor (float): The scale factor of cosine similarity.
            Default: 1.0.
    """

    def __init__(self,
                 shift_factor: float = 0.0,
                 scale_factor: float = 1.0) -> None:
        super().__init__()
        self.shift_factor = shift_factor
        self.scale_factor = scale_factor

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function of cosine similarity loss.

        Args:
            pred (torch.Tensor): The predicted features. The similarity is
                taken over the last dimension.
            target (torch.Tensor): The target features, matched against
                ``pred`` element-wise.
            mask (torch.Tensor, optional): Per-element weights multiplied
                with the un-reduced loss (i.e. the loss after the last
                dimension has been summed out). When given, the result is
                the weighted sum divided by ``mask.sum()``; when ``None``,
                the plain mean is returned. Presumably a 0/1 mask with
                the shape of ``pred`` minus its last dimension -- confirm
                against callers. Defaults to None.

        Returns:
            torch.Tensor: The cosine similarity loss, reduced to a scalar.
            If ``mask`` sums to zero the loss is 0 rather than NaN.
        """
        pred_norm = nn.functional.normalize(pred, dim=-1)
        target_norm = nn.functional.normalize(target, dim=-1)
        # Dot product of unit vectors == cosine similarity.
        loss = self.shift_factor - self.scale_factor * (
            pred_norm * target_norm).sum(dim=-1)

        if mask is None:
            return loss.mean()

        denom = mask.sum()
        # An all-zero mask would give 0 / 0 = NaN and poison the whole
        # backward pass. Swap a zero denominator for 1 so the result is a
        # clean (graph-connected) 0; non-zero denominators are untouched.
        # Done with tensor ops to avoid a host sync.
        denom = torch.where(denom == 0, torch.ones_like(denom), denom)
        return (loss * mask).sum() / denom