Spaces:
Running
Running
| """ | |
| Model Loader | |
| ============ | |
| Responsible for loading and initializing models (Single Responsibility) | |
| """ | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer | |
| from typing import Tuple | |
| import logging | |
| from app.models.phobert_model import PhoBERTFineTuned | |
| from app.core.config import settings | |
| from app.core.exceptions import ModelNotLoadedException | |
| logger = logging.getLogger(__name__) | |
class ModelLoader:
    """
    Model loader service.

    Responsibilities:
    - Load tokenizer
    - Load base model
    - Load fine-tuned weights
    - Initialize model on correct device

    The loaded objects are cached on the instance. ``model()``,
    ``tokenizer()`` and ``device()`` return them and raise
    ``ModelNotLoadedException`` while the corresponding object is unset.
    """

    def __init__(self):
        # All three stay None until load() completes successfully.
        self._model: PhoBERTFineTuned | None = None
        self._tokenizer: AutoTokenizer | None = None
        self._device: torch.device | None = None

    def load(self) -> Tuple[PhoBERTFineTuned, AutoTokenizer, torch.device]:
        """
        Load model, tokenizer, and set device.

        Everything is built in local variables and stored on the instance
        only after every step has succeeded, so a failed load never leaves
        a half-initialized model (e.g. one without fine-tuned weights)
        reachable through ``model()``.

        Returns:
            model: Loaded model, moved to the device and in eval mode
            tokenizer: Loaded tokenizer
            device: Device built from ``settings.DEVICE``

        Raises:
            ModelNotLoadedException: If any loading step fails. The
                underlying error is attached as ``__cause__``.
        """
        try:
            # Set device
            device = torch.device(settings.DEVICE)
            logger.info("Using device: %s", device)

            # Load tokenizer
            logger.info("Loading tokenizer: %s", settings.MODEL_NAME)
            tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)

            # Load base model
            logger.info("Loading base model: %s", settings.MODEL_NAME)
            phobert = AutoModel.from_pretrained(settings.MODEL_NAME)

            # Initialize fine-tuned model.
            # NOTE(review): load_state_dict() below uses its default strict
            # matching, so these hyperparameters presumably have to match
            # the ones the checkpoint was trained with -- confirm before
            # changing any of them.
            logger.info("Initializing fine-tuned model")
            model = PhoBERTFineTuned(
                embedding_model=phobert,
                hidden_dim=768,
                dropout=0.3,
                num_classes=2,
                num_layers_to_finetune=4,
                pooling='mean',
            )

            # Load weights.
            # NOTE(security): torch.load() unpickles the file; only point
            # settings.MODEL_PATH at checkpoints from a trusted source
            # (consider weights_only=True where the torch version allows).
            logger.info("Loading weights from: %s", settings.MODEL_PATH)
            state_dict = torch.load(
                settings.MODEL_PATH,
                map_location=device,
            )
            model.load_state_dict(state_dict)

            # Move to device and set eval mode
            model = model.to(device)
            model.eval()
        except Exception as e:
            # logger.exception keeps the traceback; "from e" keeps the cause.
            logger.exception("Failed to load model: %s", e)
            raise ModelNotLoadedException() from e

        # Commit only now that every step above has succeeded.
        self._model = model
        self._tokenizer = tokenizer
        self._device = device
        logger.info("Model loaded successfully")
        return model, tokenizer, device

    def model(self) -> PhoBERTFineTuned:
        """
        Get loaded model.

        Raises:
            ModelNotLoadedException: If no model is stored on the loader.
        """
        if self._model is None:
            raise ModelNotLoadedException()
        return self._model

    def tokenizer(self) -> AutoTokenizer:
        """
        Get loaded tokenizer.

        Raises:
            ModelNotLoadedException: If no tokenizer is stored on the loader.
        """
        if self._tokenizer is None:
            raise ModelNotLoadedException()
        return self._tokenizer

    def device(self) -> torch.device:
        """
        Get device.

        Raises:
            ModelNotLoadedException: If no device is stored on the loader.
        """
        if self._device is None:
            raise ModelNotLoadedException()
        return self._device

    def is_loaded(self) -> bool:
        """Check if model, tokenizer and device are all set."""
        return (
            self._model is not None
            and self._tokenizer is not None
            and self._device is not None
        )
# Shared module-level loader instance. Constructing it only initializes
# empty state; nothing is loaded until load() is called on it.
model_loader: ModelLoader = ModelLoader()