# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.fileio import list_from_file
from mmengine.runner import autocast
from mmengine.utils import is_seq_of

from mmpretrain.models.losses import convert_to_one_hot
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from .cls_head import ClsHead


class NormProduct(nn.Linear):
    """An enhanced linear layer with k clustering centers that computes the
    product between the normalized input and the normalized linear weight.

    Args:
        in_features (int): Size of each input sample.
        out_features (int): Size of each output sample.
        k (int): The number of clustering centers. Defaults to 1.
        bias (bool): Whether there is bias. If set to ``False``, the
            layer will not learn an additive bias. Defaults to ``False``.
        feature_norm (bool): Whether to normalize the input feature.
            Defaults to ``True``.
        weight_norm (bool): Whether to normalize the weight.
            Defaults to ``True``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 k: int = 1,
                 bias: bool = False,
                 feature_norm: bool = True,
                 weight_norm: bool = True):
        super().__init__(in_features, out_features * k, bias=bias)
        self.weight_norm = weight_norm
        self.feature_norm = feature_norm
        self.out_features = out_features
        self.k = k

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.feature_norm:
            input = F.normalize(input)
        if self.weight_norm:
            weight = F.normalize(self.weight)
        else:
            weight = self.weight
        cosine_all = F.linear(input, weight, self.bias)

        if self.k == 1:
            return cosine_all
        else:
            # Each class owns k sub-center weight vectors; keep only the
            # largest cosine among them for every class.
            cosine_all = cosine_all.view(-1, self.out_features, self.k)
            cosine, _ = torch.max(cosine_all, dim=2)
            return cosine
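

# A minimal usage sketch of ``NormProduct`` (illustrative, not part of the
# original file). With k=3 sub-centers the layer holds ``out_features * k``
# weight rows and returns the per-class maximum cosine:
#
#     >>> layer = NormProduct(in_features=8, out_features=4, k=3)
#     >>> x = torch.randn(2, 8)
#     >>> layer(x).shape   # cosine similarities in [-1, 1]
#     torch.Size([2, 4])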


@MODELS.register_module()
class ArcFaceClsHead(ClsHead):
    """ArcFace classifier head.

    A PyTorch implementation of paper `ArcFace: Additive Angular Margin Loss
    for Deep Face Recognition <https://arxiv.org/abs/1801.07698>`_ and
    `Sub-center ArcFace: Boosting Face Recognition by Large-Scale Noisy Web
    Faces <https://link.springer.com/chapter/10.1007/978-3-030-58621-8_43>`_

    Example:
        To use ArcFace in config files.

        1. use vanilla ArcFace

        .. code:: python

            model = dict(
                backbone = xxx,
                neck = xxxx,
                head=dict(
                    type='ArcFaceClsHead',
                    num_classes=5000,
                    in_channels=1024,
                    loss = dict(type='CrossEntropyLoss', loss_weight=1.0),
                    init_cfg=None),
            )

        2. use SubCenterArcFace with 3 sub-centers

        .. code:: python

            model = dict(
                backbone = xxx,
                neck = xxxx,
                head=dict(
                    type='ArcFaceClsHead',
                    num_classes=5000,
                    in_channels=1024,
                    num_subcenters=3,
                    loss = dict(type='CrossEntropyLoss', loss_weight=1.0),
                    init_cfg=None),
            )

        3. use SubCenterArcFace with CountPowerAdaptiveMargins

        .. code:: python

            model = dict(
                backbone = xxx,
                neck = xxxx,
                head=dict(
                    type='ArcFaceClsHead',
                    num_classes=5000,
                    in_channels=1024,
                    num_subcenters=3,
                    loss = dict(type='CrossEntropyLoss', loss_weight=1.0),
                    init_cfg=None),
            )

            custom_hooks = [dict(type='SetAdaptiveMarginsHook')]

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        num_subcenters (int): Number of sub-centers. Defaults to 1.
        scale (float): Scale factor of the output logit. Defaults to 64.0.
        margins (float | Sequence[float] | str): The penalty margin. Could be
            one of the following formats:

            - float: The margin, which would be the same for all categories.
            - Sequence[float]: The category-based margins list.
            - str: A '.txt' file path which contains a list. Each line
              represents the margin of a category, and the number in the
              i-th row indicates the margin of the i-th class.

            Defaults to 0.5.
        easy_margin (bool): Avoid theta + m >= PI. Defaults to False.
        loss (dict): Config of the classification loss. Defaults to
            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 num_subcenters: int = 1,
                 scale: float = 64.,
                 margins: Optional[Union[float, Sequence[float], str]] = 0.50,
                 easy_margin: bool = False,
                 loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
                 init_cfg: Optional[dict] = None):
        super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg)

        if not isinstance(loss, nn.Module):
            loss = MODELS.build(loss)
        self.loss_module = loss

        assert num_subcenters >= 1 and num_classes >= 0
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_subcenters = num_subcenters
        self.scale = scale
        self.easy_margin = easy_margin

        self.norm_product = NormProduct(in_channels, num_classes,
                                        num_subcenters)

        if isinstance(margins, float):
            margins = [margins] * num_classes
        elif isinstance(margins, str) and margins.endswith('.txt'):
            margins = [float(item) for item in list_from_file(margins)]
        else:
            assert is_seq_of(list(margins), (float, int)), (
                'the attribute `margins` in ``ArcFaceClsHead`` should be a '
                'float, a Sequence of float, or a ".txt" file path.')

        assert len(margins) == num_classes, \
            'The length of margins must be equal to num_classes.'

        self.register_buffer(
            'margins', torch.tensor(margins).float(), persistent=False)
        # To keep `phi` monotonically decreasing, refer to
        # https://github.com/deepinsight/insightface/issues/108
        sinm_m = torch.sin(math.pi - self.margins) * self.margins
        threshold = torch.cos(math.pi - self.margins)
        self.register_buffer('sinm_m', sinm_m, persistent=False)
        self.register_buffer('threshold', threshold, persistent=False)
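
    # Worked example of the monotonicity fix above (illustrative): with the
    # default margin m = 0.5, threshold = cos(pi - m) ≈ -0.878 and
    # sinm_m = sin(pi - m) * m ≈ 0.240. Once cosine <= threshold,
    # cos(theta + m) would start increasing again as theta grows, so
    # ``_get_logit_with_margin`` falls back to the linear penalty
    # cosine - sinm_m, which keeps the logit decreasing in theta.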

    def set_margins(self, margins: Union[Sequence[float], float]) -> None:
        """Set the margins of the ArcFace head.

        Args:
            margins (Union[Sequence[float], float]): The margins.
        """
        if isinstance(margins, float):
            margins = [margins] * self.num_classes

        assert is_seq_of(
            list(margins), float) and (len(margins) == self.num_classes), (
                f'margins must be a Sequence[float] with length num_classes, '
                f'got {margins}')

        self.margins = torch.tensor(
            margins, device=self.margins.device, dtype=torch.float32)
        self.sinm_m = torch.sin(self.margins) * self.margins
        self.threshold = -torch.cos(self.margins)
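
    # Usage sketch (illustrative): per-class margins can be swapped at
    # runtime, e.g. by the ``SetAdaptiveMarginsHook`` mentioned in the class
    # docstring:
    #
    #     >>> head.set_margins([0.5] * head.num_classes)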

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensors, and each tensor is the
        feature of a backbone stage. In ``ArcFaceClsHead``, we just obtain
        the feature of the last stage.
        """
        # The ArcFaceClsHead doesn't have any other modules, just return the
        # feature after unpacking.
        return feats[-1]

    def _get_logit_with_margin(self, pre_logits, target):
        """Add the arc margin to the cosine at the target index.

        The target must be in index format.
        """
        assert target.dim() == 1 or (
            target.dim() == 2 and target.shape[1] == 1), \
            'The target must be in index format.'
        cosine = self.norm_product(pre_logits)
        phi = torch.cos(torch.acos(cosine) + self.margins)

        if self.easy_margin:
            # when cosine > 0, choose phi
            # when cosine <= 0, choose cosine
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # when cosine > threshold, choose phi
            # when cosine <= threshold, choose cosine - sinm_m
            phi = torch.where(cosine > self.threshold, phi,
                              cosine - self.sinm_m)

        target = convert_to_one_hot(target, self.num_classes)
        output = target * phi + (1 - target) * cosine
        return output
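
    # In short (illustrative summary of the method above): for a sample with
    # ground-truth class y, the unscaled logit is
    #
    #     logit_j = cos(theta_y + m_y)   if j == y
    #     logit_j = cos(theta_j)         otherwise
    #
    # where theta_j is the angle between the normalized feature and the j-th
    # class weight, and m_y is the per-class margin.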

    def forward(self,
                feats: Tuple[torch.Tensor],
                target: Optional[torch.Tensor] = None) -> torch.Tensor:
        """The forward process."""
        # Disable AMP to keep the trigonometric ops in full precision.
        with autocast(enabled=False):
            pre_logits = self.pre_logits(feats)

            if target is None:
                # In evaluation, the logit is the cosine between W and
                # pre_logits; cos(theta_yj) = (x / ||x||) * (W / ||W||)
                logit = self.norm_product(pre_logits)
            else:
                # In training, add a margin to the logits where the target is
                # True; the logit is then the cosine between W and the new
                # pre_logits.
                logit = self._get_logit_with_margin(pre_logits, target)

        return self.scale * logit
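
    # Forward usage sketch (illustrative; shapes are assumptions):
    #
    #     >>> head = ArcFaceClsHead(num_classes=10, in_channels=16)
    #     >>> feats = (torch.randn(4, 16),)
    #     >>> head(feats).shape                   # eval: scaled cosine only
    #     torch.Size([4, 10])
    #     >>> head(feats, torch.tensor([1, 0, 3, 9])).shape   # train: margin
    #     torch.Size([4, 10])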

    def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, in_channels)``.
            data_samples (List[DataSample]): The annotation data of
                every sample.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # Unpack data samples and pack targets
        label_target = torch.cat([i.gt_label for i in data_samples])
        if 'gt_score' in data_samples[0]:
            # Batch augmentation may convert labels to one-hot format scores.
            target = torch.stack([i.gt_score for i in data_samples])
        else:
            target = label_target

        # The index-format target is always used to place the margin.
        cls_score = self(feats, label_target)

        # compute loss
        losses = dict()
        loss = self.loss_module(
            cls_score, target, avg_factor=cls_score.size(0), **kwargs)
        losses['loss'] = loss

        return losses
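
    # Loss usage sketch (illustrative; ``DataSample.set_gt_label`` is the
    # standard mmpretrain API for attaching index labels):
    #
    #     >>> from mmpretrain.structures import DataSample
    #     >>> head = ArcFaceClsHead(num_classes=10, in_channels=16)
    #     >>> feats = (torch.randn(2, 16),)
    #     >>> samples = [DataSample().set_gt_label(i) for i in (0, 3)]
    #     >>> head.loss(feats, samples)['loss']   # scalar loss tensor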