""" PhoBERT Model ============= Model architecture definition (Single Responsibility) """ import torch import torch.nn as nn from typing import Tuple, Optional class PhoBERTFineTuned(nn.Module): """ Fine-tuned PhoBERT model for toxic text classification Responsibilities: - Define model architecture - Forward pass computation """ def __init__( self, embedding_model: nn.Module, hidden_dim: int = 768, dropout: float = 0.3, num_classes: int = 2, num_layers_to_finetune: int = 4, pooling: str = 'mean' ): super(PhoBERTFineTuned, self).__init__() self.embedding = embedding_model self.pooling = pooling self.num_layers_to_finetune = num_layers_to_finetune # Freeze all parameters for param in self.embedding.parameters(): param.requires_grad = False # Unfreeze last N layers if num_layers_to_finetune > 0: total_layers = len(self.embedding.encoder.layer) layers_to_train = list(range( total_layers - num_layers_to_finetune, total_layers )) for layer_idx in layers_to_train: for param in self.embedding.encoder.layer[layer_idx].parameters(): param.requires_grad = True if hasattr(self.embedding, 'pooler') and self.embedding.pooler is not None: for param in self.embedding.pooler.parameters(): param.requires_grad = True # Classification head self.dropout = nn.Dropout(dropout) self.fc1 = nn.Linear(hidden_dim, 256) self.fc2 = nn.Linear(256, num_classes) self.relu = nn.ReLU() self.layer_norm = nn.LayerNorm(hidden_dim) def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, return_embeddings: bool = False ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Forward pass Args: input_ids: Input token IDs attention_mask: Attention mask return_embeddings: Whether to return embeddings Returns: logits: Classification logits embeddings: Hidden states (if return_embeddings=True) """ # Get embeddings outputs = self.embedding(input_ids, attention_mask=attention_mask) embeddings = outputs.last_hidden_state # Pooling if self.pooling == 'cls': pooled = embeddings[:, 0, :] elif self.pooling == 'mean': mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.size()).float() sum_embeddings = torch.sum(embeddings * mask_expanded, 1) sum_mask = mask_expanded.sum(1) pooled = sum_embeddings / sum_mask else: raise ValueError(f"Unknown pooling method: {self.pooling}") # Classification pooled = self.layer_norm(pooled) out = self.dropout(pooled) out = self.relu(self.fc1(out)) out = self.dropout(out) logits = self.fc2(out) if return_embeddings: return logits, embeddings return logits, None