Spaces:
Runtime error
Runtime error
from dataclasses import dataclass, field
from typing import Optional

from .. import models
@dataclass
class RetroDataModelArguments:
    """Common base for every Retro-Reader argument group.

    Declares no fields itself; it only gives the data and model argument
    groups below a shared ancestor.
    """


@dataclass
class DataArguments(RetroDataModelArguments):
    """Data-side options of the Retro-Reader configuration.

    Judging by the names, these cover tokenization windowing (``max_seq_length``,
    ``doc_stride``) and answer post-processing (``n_best_size``,
    ``null_score_diff_threshold``, ``rear_threshold``). This is inferred from the
    names only; confirm against the code that consumes them.

    NOTE(review): every ``help`` string is empty. These fields are presumably
    fed to an argparse-style parser (e.g. transformers' ``HfArgumentParser``),
    whose ``--help`` output will then be blank. Fill them in once the meaning
    of each option is confirmed.
    """

    max_seq_length: int = field(
        default=512,
        metadata={"help": ""},
    )
    max_answer_length: int = field(
        default=30,
        metadata={"help": ""},
    )
    doc_stride: int = field(
        default=128,
        metadata={"help": ""},
    )
    return_token_type_ids: bool = field(
        default=True,
        metadata={"help": ""},
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={"help": ""},
    )
    preprocessing_num_workers: int = field(
        default=5,
        metadata={"help": ""},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": ""},
    )
    version_2_with_negative: bool = field(
        default=True,
        metadata={"help": ""},
    )
    null_score_diff_threshold: float = field(
        default=0.0,
        metadata={"help": ""},
    )
    rear_threshold: float = field(
        default=0.0,
        metadata={"help": ""},
    )
    n_best_size: int = field(
        default=20,
        metadata={"help": ""},
    )
    use_choice_logits: bool = field(
        default=False,
        metadata={"help": ""},
    )
    start_n_top: int = field(
        default=-1,
        metadata={"help": ""},
    )
    end_n_top: int = field(
        default=-1,
        metadata={"help": ""},
    )
    beta1: int = field(
        default=1,
        metadata={"help": ""},
    )
    beta2: int = field(
        default=1,
        metadata={"help": ""},
    )
    best_cof: int = field(
        default=1,
        metadata={"help": ""},
    )


@dataclass
class ModelArguments(RetroDataModelArguments):
    """Options shared by both model stages (sketch and intensive)."""

    use_auth_token: bool = field(
        default=False,
        metadata={"help": ""},
    )


@dataclass
class SketchModelArguments(ModelArguments):
    """Checkpoint, tokenizer and architecture selection for the sketch model.

    ``sketch_architectures`` names a class that ``RetroArguments`` looks up on
    the ``models`` package. ``sketch_tokenizer_name`` falls back to
    ``sketch_model_name`` when left as ``None``.
    """

    sketch_revision: str = field(
        default="main",
        metadata={"help": ""},
    )
    sketch_model_name: str = field(
        default="monologg/koelectra-small-v3-discriminator",
        metadata={"help": ""},
    )
    # None means "use sketch_model_name" (resolved in RetroArguments.__post_init__).
    sketch_tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": ""},
    )
    sketch_architectures: str = field(
        default="ElectraForSequenceClassification",
        metadata={"help": ""},
    )


@dataclass
class IntensiveModelArguments(ModelArguments):
    """Checkpoint, tokenizer and architecture selection for the intensive model.

    ``intensive_architectures`` names a class that ``RetroArguments`` looks up
    on the ``models`` package. ``intensive_tokenizer_name`` falls back to
    ``intensive_model_name`` when left as ``None``.
    """

    intensive_revision: str = field(
        default="main",
        metadata={"help": ""},
    )
    intensive_model_name: str = field(
        default="monologg/koelectra-small-v3-discriminator",
        metadata={"help": ""},
    )
    # None means "use intensive_model_name" (resolved in RetroArguments.__post_init__).
    intensive_tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": ""},
    )
    intensive_architectures: str = field(
        default="ElectraForQuestionAnsweringAVPool",
        metadata={"help": ""},
    )


@dataclass
class RetroArguments(
    DataArguments,
    SketchModelArguments,
    IntensiveModelArguments,
):
    """All Retro-Reader options merged into a single dataclass.

    After the generated ``__init__`` runs, ``__post_init__`` derives:

    * ``sketch_model_cls`` / ``intensive_model_cls``: the class named by
      ``sketch_architectures`` / ``intensive_architectures``, looked up as an
      attribute of the ``models`` package.
    * ``sketch_model_type`` / ``intensive_model_type``: that class's
      ``model_type`` attribute.
    * ``sketch_tokenizer_name`` / ``intensive_tokenizer_name``: set to the
      matching ``*_model_name`` when left as ``None``.

    Raises:
        AttributeError: if an architecture name is not an attribute of
            ``models``, or the resolved class has no ``model_type``.
    """

    def __post_init__(self):
        # Sketch
        self.sketch_model_cls = self._resolve_model_cls(self.sketch_architectures)
        self.sketch_model_type = self.sketch_model_cls.model_type
        if self.sketch_tokenizer_name is None:
            self.sketch_tokenizer_name = self.sketch_model_name
        # Intensive
        self.intensive_model_cls = self._resolve_model_cls(self.intensive_architectures)
        self.intensive_model_type = self.intensive_model_cls.model_type
        if self.intensive_tokenizer_name is None:
            self.intensive_tokenizer_name = self.intensive_model_name

    @staticmethod
    def _resolve_model_cls(architecture):
        """Return the attribute of ``models`` named *architecture*.

        Raises:
            AttributeError: if ``models`` has no such attribute (or it is None).
        """
        model_cls = getattr(models, architecture, None)
        if model_cls is None:
            raise AttributeError(
                f"Architecture {architecture!r} was not found in the `models` package."
            )
        return model_cls