| import torch |
| import torch.nn as nn |
| import os |
| import json |
| from transformers import LongformerModel, AutoModel, LongformerTokenizerFast, AutoTokenizer, PreTrainedModel |
|
|
class HarmFormer(PreTrainedModel):
    """Transformer encoder with one small classification head per harm category.

    Each of the ``config.num_classes`` heads maps the encoder's pooled output to
    ``config.num_risk_levels`` logits through
    ``Linear(hidden_size, 128) -> ReLU -> Linear(128, num_risk_levels)``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        self.num_risk_levels = config.num_risk_levels

        # Built from the config only (randomly initialised); trained weights are
        # loaded afterwards via ``load_state_dict`` in ``from_pretrained``.
        encoder = AutoModel.from_config(config)
        # Registered under the name ``base_model`` so checkpoint keys stay
        # ``base_model.*``.
        self.base_model = encoder

        hidden_size = encoder.config.hidden_size

        self.classifiers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, 128),
                nn.ReLU(),
                nn.Linear(128, self.num_risk_levels),
            )
            for _ in range(self.num_classes)
        ])

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Run the encoder and every head.

        Extra keyword arguments are accepted and ignored.

        Returns:
            A list of ``num_classes`` logit tensors, one per head, each with
            ``num_risk_levels`` entries in its last dimension.
        """
        # NOTE(review): ``PreTrainedModel`` defines a read-only ``base_model``
        # property, which can take precedence over the submodule registered
        # under that name. Reading ``_modules`` always yields the encoder
        # assigned in ``__init__``.
        encoder = self._modules["base_model"]
        outputs = encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Index 1 is presumably the encoder's pooler output (e.g. Longformer's
        # ``pooler_output``) -- TODO confirm for other base models.
        pooled_output = outputs[1]

        return [classifier(pooled_output) for classifier in self.classifiers]

    @staticmethod
    def _resolve_file(pretrained_model_name_or_path, filename):
        """Return a local path to ``filename``, downloading it from the Hub if needed.

        A file inside ``pretrained_model_name_or_path`` is used when it exists.
        Otherwise the argument is treated as a Hugging Face Hub repo id.
        """
        local_path = os.path.join(pretrained_model_name_or_path, filename)
        if os.path.exists(local_path):
            return local_path
        # Imported here so purely local loading does not require huggingface_hub.
        from huggingface_hub import hf_hub_download
        return hf_hub_download(repo_id=pretrained_model_name_or_path, filename=filename)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build a HarmFormer from a local directory or a Hub repo id.

        Reads ``config.json`` for ``model_name`` (base encoder, default
        ``"allenai/longformer-base-4096"``), ``num_classes`` (default 5),
        ``num_risk_levels`` (default 3) and ``architecture`` (default
        ``"SingleFC"``). It then loads ``pytorch_model.bin`` with a strict
        ``load_state_dict`` and returns the model in eval mode. Each file is
        looked up locally first and downloaded from the Hub otherwise.
        ``model_args`` and ``kwargs`` are accepted for signature compatibility
        and ignored.
        """
        from transformers import AutoConfig

        config_path = cls._resolve_file(pretrained_model_name_or_path, "config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            model_config = json.load(f)

        base_model_name = model_config.get("model_name", "allenai/longformer-base-4096")
        base_config = AutoConfig.from_pretrained(base_model_name)

        # Task-specific settings ride along on the base encoder's config object.
        base_config.num_classes = model_config.get("num_classes", 5)
        base_config.num_risk_levels = model_config.get("num_risk_levels", 3)
        base_config.architecture = model_config.get("architecture", "SingleFC")

        model = cls(base_config)

        checkpoint_path = cls._resolve_file(pretrained_model_name_or_path, "pytorch_model.bin")
        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources (consider
        # ``weights_only=True`` on torch versions that support it).
        state_dict = torch.load(checkpoint_path, map_location="cpu")

        model.load_state_dict(state_dict)
        model.eval()

        return model
|
|
def predict_batch(model, tokenizer, texts, batch_size=32, max_length=1024):
    """Return per-head risk-level probabilities for each text.

    Texts are tokenized in chunks of ``batch_size``, truncated and padded to
    ``max_length`` tokens, and run through ``model`` under ``torch.no_grad()``
    on the device of the model's first parameter.

    Args:
        model: Module whose forward returns a sequence of per-head logit
            tensors (as ``HarmFormer.forward`` does); presumably each is
            ``(batch, num_risk_levels)`` -- TODO confirm for other models.
        tokenizer: Callable tokenizer whose result supports ``.to(device)`` and
            is passed to ``model`` as keyword arguments.
        texts: Sliceable sequence of strings.
        batch_size: Number of texts per forward pass; must be >= 1.
        max_length: Token length each text is truncated/padded to.

    Returns:
        One entry per text. Each entry is a list with one item per head, and
        each item is a list of softmax probabilities rounded to 3 decimals.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    device = next(model.parameters()).device
    predictions = []

    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]
        inputs = tokenizer(
            batch_texts,
            add_special_tokens=True,
            max_length=max_length,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        # Stacking per-head logits on dim 1 gives (batch, num_heads, num_levels),
        # equivalent to stack(dim=0) followed by permute(1, 0, 2).
        probs = torch.softmax(torch.stack(outputs, dim=1), dim=-1)
        predictions.extend(
            [[round(p, 3) for p in head_probs] for head_probs in sample]
            for sample in probs.cpu().tolist()
        )

    return predictions
|
|