Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| from mmengine.model import BaseModule | |
| from mmpretrain.registry import MODELS | |
class SparKPretrainHead(BaseModule):
    """Pre-training head for SparK.

    The head splits both the prediction and the target image into
    non-overlapping square patches, optionally normalizes every target
    patch, and hands both to the configured loss together with a mask that
    selects the non-active patches.

    Args:
        loss (dict): Config of loss, built through the ``MODELS`` registry.
        norm_pix (bool): Whether or not normalize target. Defaults to True.
        patch_size (int): Patch size, equal to downsample ratio of backbone.
            Defaults to 32.
    """

    def __init__(self,
                 loss: dict,
                 norm_pix: bool = True,
                 patch_size: int = 32) -> None:
        super().__init__()
        self.norm_pix = norm_pix
        self.patch_size = patch_size
        self.loss = MODELS.build(loss)

    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
        """Split images into non-overlapped patches.

        Args:
            imgs (torch.Tensor): A batch of images, of shape B x C x H x W.
                Both H and W must be divisible by ``self.patch_size``.

        Returns:
            torch.Tensor: Patchified images. The shape is B x L x D, where
            L = (H / p) * (W / p) and D = p * p * C for patch size p.

        Raises:
            AssertionError: If ``imgs`` is not 4-dimensional or its spatial
                size is not divisible by the patch size.
        """
        p = self.patch_size
        assert (len(imgs.shape) == 4 and imgs.shape[2] % p == 0
                and imgs.shape[3] % p == 0), (
                    f'Expected a (B, C, H, W) tensor with H and W divisible '
                    f'by patch_size={p}, but got shape {tuple(imgs.shape)}.')

        B, C, ori_h, ori_w = imgs.shape
        h = ori_h // p
        w = ori_w // p

        # Expose the patch grid: (B, C, h, p, w, p).
        x = imgs.reshape(shape=(B, C, h, p, w, p))
        # Move the grid axes forward and the channel axis last so that every
        # patch is laid out as (row, col, channel): (B, h, w, p, p, C).
        x = torch.einsum('bchpwq->bhwpqc', x)
        # (B, h*w, downsample_ratio*downsample_ratio*C)
        x = x.reshape(shape=(B, h * w, p**2 * C))
        return x

    def construct_target(self, target: torch.Tensor) -> torch.Tensor:
        """Construct the reconstruction target.

        In addition to splitting images into tokens, this module will also
        normalize the image according to ``norm_pix``.

        Args:
            target (torch.Tensor): Image with the shape of B x 3 x H x W

        Returns:
            torch.Tensor: Tokenized images with the shape of B x L x C
        """
        target = self.patchify(target)
        if self.norm_pix:
            # Normalize every patch independently over its last dimension.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            # The small epsilon keeps the division finite for flat patches.
            target = (target - mean) / (var + 1.e-6)**.5

        return target

    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                active_mask: torch.Tensor) -> torch.Tensor:
        """Forward function of SparK head.

        Args:
            pred (torch.Tensor): The reconstructed image.
            target (torch.Tensor): The target image.
            active_mask (torch.Tensor): The mask of the target image.

        Returns:
            torch.Tensor: The reconstruction loss.
        """
        # (B, C, H, W) -> (B, L, C) and perform normalization
        target = self.construct_target(target)

        # (B, C, H, W) -> (B, L, C)
        pred = self.patchify(pred)

        # (B, 1, f, f) -> (B, L)
        # ``reshape`` instead of ``view``: the integer copy keeps the strides
        # of ``active_mask``, so ``view`` would raise for a non-contiguous
        # mask while ``reshape`` falls back to a copy.
        non_active_mask = active_mask.logical_not().int().reshape(
            active_mask.shape[0], -1)

        # The loss is evaluated with the inverted (non-active) patch mask.
        loss = self.loss(pred, target, non_active_mask)
        return loss