# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Sequence

import torch
from mmengine.model import BaseModel
from mmengine.structures import BaseDataElement


class BaseClassifier(BaseModel, metaclass=ABCMeta):
    """Base class for classifiers.

    Args:
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing input
            data. If None, it will use "BaseDataPreprocessor" as type, see
            :class:`mmengine.model.BaseDataPreprocessor` for more details.
            Defaults to None.

    Attributes:
        init_cfg (dict): Initialization config dict.
        data_preprocessor (:obj:`mmengine.model.BaseDataPreprocessor`): An
            extra data pre-processing module, which processes data from the
            dataloader into the format accepted by :meth:`forward`.
    """

    def __init__(self,
                 init_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None):
        super(BaseClassifier, self).__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

    @property
    def with_neck(self) -> bool:
        """Whether the classifier has a neck."""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_head(self) -> bool:
        """Whether the classifier has a head."""
        return hasattr(self, 'head') and self.head is not None

    @abstractmethod
    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[BaseDataElement]] = None,
                mode: str = 'tensor'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return a tensor or a tuple
          of tensors without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`BaseDataElement`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method doesn't handle back propagation or optimizer
        updates, which are done in :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape (N, C, ...)
                in general.
            data_samples (List[BaseDataElement], optional): The annotation
                data of every sample. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Which kind of value to return. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensors.
            - If ``mode="predict"``, return a list of
              :obj:`mmengine.structures.BaseDataElement`.
            - If ``mode="loss"``, return a dict of tensors.
        """
        pass

    def extract_feat(self, inputs: torch.Tensor):
        """Extract features from the input tensor with shape (N, C, ...).

        Subclasses are recommended to implement this method to extract
        features from the backbone and neck.

        Args:
            inputs (Tensor): A batch of inputs. The shape of it should be
                ``(num_samples, num_channels, *img_shape)``.
        """
        raise NotImplementedError

    def extract_feats(self, multi_inputs: Sequence[torch.Tensor],
                      **kwargs) -> list:
        """Extract features from a sequence of input tensors.

        Args:
            multi_inputs (Sequence[torch.Tensor]): A sequence of input
                tensors. It can be used in augmented inference.
            **kwargs: Other keyword arguments accepted by :meth:`extract_feat`.

        Returns:
            list: Features of every input tensor.
        """
        assert isinstance(multi_inputs, Sequence), \
            '`extract_feats` is used for a sequence of input tensors. If ' \
            'you want to extract features from a single tensor, use ' \
            '`extract_feat`.'
        return [self.extract_feat(inputs, **kwargs) for inputs in multi_inputs]
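

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal concrete
# subclass showing one plausible way to dispatch the "tensor" / "predict" /
# "loss" modes of ``forward`` and to implement ``extract_feat``. The
# ``TinyClassifier`` name, its layers, and the ``gt_label`` / ``pred_score``
# fields on the data samples are assumptions made for demonstration only,
# not APIs defined by this file.
# ---------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class TinyClassifier(BaseClassifier):
    """A toy classifier used only to illustrate the abstract interface."""

    def __init__(self, num_classes: int = 10, **kwargs):
        super().__init__(**kwargs)
        # A trivial backbone/head pair; real models would build these from
        # registered configs instead.
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.neck = None
        self.head = nn.Linear(8, num_classes)

    def extract_feat(self, inputs: torch.Tensor):
        # Backbone features, optionally refined by a neck.
        feats = self.backbone(inputs)
        if self.with_neck:
            feats = self.neck(feats)
        return feats

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[BaseDataElement]] = None,
                mode: str = 'tensor'):
        logits = self.head(self.extract_feat(inputs))
        if mode == 'tensor':
            # Raw network outputs, no post-processing.
            return logits
        if mode == 'predict':
            # Pack per-sample scores into data elements.
            scores = F.softmax(logits, dim=1)
            return [BaseDataElement(pred_score=score) for score in scores]
        if mode == 'loss':
            # ``gt_label`` is an assumed scalar class-index field on each
            # annotated data sample.
            labels = torch.stack([sample.gt_label for sample in data_samples])
            return {'loss': F.cross_entropy(logits, labels)}
        raise ValueError(f'Invalid mode "{mode}".')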
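

# ---------------------------------------------------------------------------
# Hypothetical usage sketch: ``extract_feats`` maps ``extract_feat`` over
# several augmented views of the same batch (e.g. for test-time
# augmentation); the batch shapes below are arbitrary assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = TinyClassifier(num_classes=10)
    batch = torch.rand(2, 3, 32, 32)
    views = [batch, torch.flip(batch, dims=[3])]  # original + horizontal flip
    feats = model.extract_feats(views)            # one feature tensor per view
    assert [tuple(f.shape) for f in feats] == [(2, 8), (2, 8)]
    preds = model(batch, mode='predict')          # list of BaseDataElement
    assert len(preds) == 2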