Instructions to use epfl-ml4ed/MCQBert with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use epfl-ml4ed/MCQBert with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("epfl-ml4ed/MCQBert", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from transformers import BertModel | |
| import torch | |
| from .configuration_mcqbert import MCQBertConfig | |
class MCQBert(BertModel):
    """BERT encoder with a one-output prediction head and optional student embeddings.

    ``config.integration_strategy`` selects how student embeddings are used:

    * ``None``  -- MCQBert: student embeddings are ignored; the head reads the
      hidden state at position 0 only.
    * ``"cat"`` -- MCQStudentBertCat: the projected student embedding is
      concatenated to the position-0 hidden state before the head.
    * ``"sum"`` -- MCQStudentBertSum: the projected student embedding is added
      to every token embedding before the encoder runs.
    """

    config_class = MCQBertConfig

    def __init__(self, config: MCQBertConfig):
        """Build the encoder, the optional student projection, and the head.

        Args:
            config: Model configuration. Reads ``integration_strategy``,
                ``student_embedding_size``, ``hidden_size`` and
                ``cls_hidden_size``.
        """
        super().__init__(config)

        if config.integration_strategy is not None:
            # Projects a student embedding into the encoder's hidden size so
            # it can be concatenated with / added to BERT representations.
            self.student_embedding_layer = torch.nn.Linear(
                config.student_embedding_size, config.hidden_size
            )

        # With "cat" the head receives [hidden state ; student projection],
        # i.e. twice the hidden size; otherwise just the hidden size.
        cls_input_dim_multiplier = 2 if config.integration_strategy == "cat" else 1
        cls_input_dim = self.config.hidden_size * cls_input_dim_multiplier

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(cls_input_dim, config.cls_hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(config.cls_hidden_size, 1),
        )

    def _project_student_embeddings(self, student_embeddings, batch_size):
        """Project student embeddings and normalise them to 2-D.

        Accepts a single embedding ``(emb,)``, a batch ``(batch, emb)`` or a
        batch with a singleton middle axis ``(batch, 1, emb)``.

        Args:
            student_embeddings: Tensor whose last dimension is
                ``config.student_embedding_size``.
            batch_size: Number of sequences in the current batch.

        Returns:
            Tensor of shape ``(rows, hidden_size)`` where ``rows`` is either 1
            (shared by the whole batch) or ``batch_size``.

        Raises:
            ValueError: If ``student_embeddings`` is ``None`` or its number of
                rows is neither 1 nor ``batch_size``.
        """
        if student_embeddings is None:
            raise ValueError(
                "student_embeddings is required when integration_strategy is "
                f"{self.config.integration_strategy!r}"
            )

        projected = self.student_embedding_layer(student_embeddings)
        # Collapse every leading axis so all accepted input shapes end up as
        # (rows, hidden_size).
        projected = projected.reshape(-1, projected.size(-1))

        if projected.size(0) not in (1, batch_size):
            raise ValueError(
                f"student_embeddings provides {projected.size(0)} embedding(s) "
                f"but the batch contains {batch_size} sequence(s)"
            )
        return projected

    def forward(self, input_ids, student_embeddings=None, attention_mask=None):
        """Run the encoder and return the head's output.

        Args:
            input_ids: Token ids passed to the BERT encoder.
            student_embeddings: Student embedding(s); required for the
                ``"cat"`` and ``"sum"`` strategies, ignored when the strategy
                is ``None``. Either one embedding shared by the whole batch or
                one embedding per sequence.
            attention_mask: Optional mask forwarded to the BERT encoder (use
                it when the batch contains padding). ``None`` keeps the
                encoder's default behaviour.

        Returns:
            Tensor of shape ``(batch, 1)`` produced by ``self.classifier``.

        Raises:
            ValueError: If ``config.integration_strategy`` is unknown, or if
                student embeddings are missing / mis-sized for a strategy
                that needs them.
        """
        strategy = self.config.integration_strategy

        if strategy is None:
            # Don't consider student embeddings if there is no integration
            # strategy (MCQBert).
            output = super().forward(input_ids, attention_mask=attention_mask)
            return self.classifier(output.last_hidden_state[:, 0, :])

        if strategy == "cat":
            # MCQStudentBertCat
            output = super().forward(input_ids, attention_mask=attention_mask)
            first_token_state = output.last_hidden_state[:, 0, :]
            batch_size = first_token_state.size(0)
            student = self._project_student_embeddings(student_embeddings, batch_size)
            # A single shared embedding is expanded to every row of the batch;
            # a per-sequence batch already has the right shape.
            student = student.expand(batch_size, -1)
            return self.classifier(torch.cat((first_token_state, student), dim=1))

        if strategy == "sum":
            # MCQStudentBertSum
            # NOTE(review): self.embeddings is applied here and presumably
            # again inside BertModel.forward on `inputs_embeds`; kept as-is
            # because released weights presumably rely on this exact
            # computation -- confirm before changing.
            input_embeddings = self.embeddings(input_ids)
            student = self._project_student_embeddings(
                student_embeddings, input_embeddings.size(0)
            )
            # (rows, 1, hidden) broadcasts over the sequence axis (and over
            # the batch axis when rows == 1).
            combined_embeddings = input_embeddings + student.unsqueeze(1)
            output = super().forward(
                inputs_embeds=combined_embeddings, attention_mask=attention_mask
            )
            return self.classifier(output.last_hidden_state[:, 0, :])

        raise ValueError(f"{strategy} is not a known integration_strategy")