Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional | |
| import torch | |
| from mmengine.model import BaseModule | |
| from mmpretrain.registry import MODELS | |
| class MIMHead(BaseModule): | |
| """Pre-training head for Masked Image Modeling. | |
| Args: | |
| loss (dict): Config dict for module of loss functions. | |
| """ | |
| def __init__(self, loss: dict) -> None: | |
| super().__init__() | |
| self.loss_module = MODELS.build(loss) | |
| def loss(self, | |
| pred: torch.Tensor, | |
| target: torch.Tensor, | |
| mask: Optional[torch.Tensor] = None) -> torch.Tensor: | |
| """Forward head. | |
| Args: | |
| pred (torch.Tensor): Predictions with shape B x L x C. | |
| target (torch.Tensor): Targets with shape B x L x C. | |
| mask (torch.Tensor): Mask with shape B x L. | |
| Returns: | |
| torch.Tensor: The loss tensor. | |
| """ | |
| loss = self.loss_module(pred, target, mask) | |
| return loss | |