himu1780 committed on
Commit
8ff6579
·
verified ·
1 Parent(s): cd3e6b8

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +18 -1
train.py CHANGED
@@ -83,7 +83,24 @@ def format_example(ex):
83
 
84
 
85
  def prepare_dataset(tokenizer, dataset_name):
86
- dataset = load_dataset(dataset_name, split="train")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  dataset = dataset.map(format_example, remove_columns=dataset.column_names)
88
 
89
  def tokenize(ex):
 
83
 
84
 
85
  def prepare_dataset(tokenizer, dataset_name):
86
+ """
87
+ Supports:
88
+ - gsm8k
89
+ - gsm8k:main
90
+ - any_dataset
91
+ """
92
+
93
+ # Auto-fix gsm8k without config
94
+ if dataset_name == "gsm8k":
95
+ dataset_name = "gsm8k:main"
96
+
97
+ # Handle dataset:config format
98
+ if ":" in dataset_name:
99
+ name, config = dataset_name.split(":", 1)
100
+ dataset = load_dataset(name, config, split="train")
101
+ else:
102
+ dataset = load_dataset(dataset_name, split="train")
103
+
104
  dataset = dataset.map(format_example, remove_columns=dataset.column_names)
105
 
106
  def tokenize(ex):