Spaces:
Sleeping
Sleeping
fix syntax errors
Browse files
app.py
CHANGED
|
@@ -14,44 +14,44 @@ select = st.selectbox('Which model would you like to evaluate?',
|
|
| 14 |
('Bart', 'mBart'))
|
| 15 |
|
| 16 |
def get_datasets():
|
| 17 |
-
if select == 'Bart'
|
| 18 |
all_datasets = ["Communication Networks: unseen questions", "Communication Networks: unseen answers"]
|
| 19 |
-
if select == 'mBart'
|
| 20 |
all_datasets = ["Micro Job: unseen questions", "Micro Job: unseen answers", "Legal Domain: unseen questions", "Legal Domain: unseen answers"]
|
| 21 |
return all_datasets
|
| 22 |
|
| 23 |
all_datasets = get_datasets()
|
| 24 |
|
| 25 |
def get_split(dataset_name):
|
| 26 |
-
if dataset_name == "Communication Networks: unseen questions"
|
| 27 |
split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_questions")
|
| 28 |
-
if dataset_name == "Communication Networks: unseen answers"
|
| 29 |
split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_answers")
|
| 30 |
-
if dataset_name == "Micro Job: unseen questions"
|
| 31 |
split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_questions")
|
| 32 |
-
if dataset_name == "Micro Job: unseen answers"
|
| 33 |
split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_answers")
|
| 34 |
-
if dataset_name == "Legal Domain: unseen questions"
|
| 35 |
split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_questions")
|
| 36 |
-
if dataset_name == "Legal Domain: unseen answers"
|
| 37 |
split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_answers")
|
| 38 |
return split
|
| 39 |
|
| 40 |
def get_model(datasetname):
|
| 41 |
-
if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers"
|
| 42 |
model = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
|
| 43 |
-
if datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers"
|
| 44 |
model = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
|
| 45 |
-
if datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers"
|
| 46 |
model = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
|
| 47 |
return model
|
| 48 |
|
| 49 |
def get_tokenizer(datasetname):
|
| 50 |
-
if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers"
|
| 51 |
tokenizer = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
|
| 52 |
-
if datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers"
|
| 53 |
tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
|
| 54 |
-
if datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers"
|
| 55 |
tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
|
| 56 |
return tokenizer
|
| 57 |
|
|
@@ -212,7 +212,7 @@ def load_data():
|
|
| 212 |
predicted_labels = extract_labels(predictions)
|
| 213 |
|
| 214 |
reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
|
| 215 |
-
|
| 216 |
|
| 217 |
rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
|
| 218 |
bleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
|
|
|
|
| 14 |
('Bart', 'mBart'))
|
| 15 |
|
| 16 |
def get_datasets():
|
| 17 |
+
if select == 'Bart':
|
| 18 |
all_datasets = ["Communication Networks: unseen questions", "Communication Networks: unseen answers"]
|
| 19 |
+
if select == 'mBart':
|
| 20 |
all_datasets = ["Micro Job: unseen questions", "Micro Job: unseen answers", "Legal Domain: unseen questions", "Legal Domain: unseen answers"]
|
| 21 |
return all_datasets
|
| 22 |
|
| 23 |
all_datasets = get_datasets()
|
| 24 |
|
| 25 |
def get_split(dataset_name):
|
| 26 |
+
if dataset_name == "Communication Networks: unseen questions":
|
| 27 |
split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_questions")
|
| 28 |
+
if dataset_name == "Communication Networks: unseen answers":
|
| 29 |
split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_answers")
|
| 30 |
+
if dataset_name == "Micro Job: unseen questions":
|
| 31 |
split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_questions")
|
| 32 |
+
if dataset_name == "Micro Job: unseen answers":
|
| 33 |
split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_answers")
|
| 34 |
+
if dataset_name == "Legal Domain: unseen questions":
|
| 35 |
split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_questions")
|
| 36 |
+
if dataset_name == "Legal Domain: unseen answers":
|
| 37 |
split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_answers")
|
| 38 |
return split
|
| 39 |
|
| 40 |
def get_model(datasetname):
|
| 41 |
+
if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":
|
| 42 |
model = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
|
| 43 |
+
if datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers":
|
| 44 |
model = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
|
| 45 |
+
if datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers":
|
| 46 |
model = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
|
| 47 |
return model
|
| 48 |
|
| 49 |
def get_tokenizer(datasetname):
|
| 50 |
+
if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":
|
| 51 |
tokenizer = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
|
| 52 |
+
if datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers":
|
| 53 |
tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
|
| 54 |
+
if datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers":
|
| 55 |
tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
|
| 56 |
return tokenizer
|
| 57 |
|
|
|
|
| 212 |
predicted_labels = extract_labels(predictions)
|
| 213 |
|
| 214 |
reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
|
| 215 |
+
reference_labels = [x.split('Feedback:', 1)[0].strip() for x in labels]
|
| 216 |
|
| 217 |
rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
|
| 218 |
bleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
|