Spaces:
Running
Running
Update train.py
Browse files
train.py
CHANGED
|
@@ -83,7 +83,24 @@ def format_example(ex):
|
|
| 83 |
|
| 84 |
|
| 85 |
def prepare_dataset(tokenizer, dataset_name):
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
dataset = dataset.map(format_example, remove_columns=dataset.column_names)
|
| 88 |
|
| 89 |
def tokenize(ex):
|
|
|
|
| 83 |
|
| 84 |
|
| 85 |
def prepare_dataset(tokenizer, dataset_name):
|
| 86 |
+
"""
|
| 87 |
+
Supports:
|
| 88 |
+
- gsm8k
|
| 89 |
+
- gsm8k:main
|
| 90 |
+
- any_dataset
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
# Auto-fix gsm8k without config
|
| 94 |
+
if dataset_name == "gsm8k":
|
| 95 |
+
dataset_name = "gsm8k:main"
|
| 96 |
+
|
| 97 |
+
# Handle dataset:config format
|
| 98 |
+
if ":" in dataset_name:
|
| 99 |
+
name, config = dataset_name.split(":", 1)
|
| 100 |
+
dataset = load_dataset(name, config, split="train")
|
| 101 |
+
else:
|
| 102 |
+
dataset = load_dataset(dataset_name, split="train")
|
| 103 |
+
|
| 104 |
dataset = dataset.map(format_example, remove_columns=dataset.column_names)
|
| 105 |
|
| 106 |
def tokenize(ex):
|