mindchain's picture
Add training script
4f8d4a7 verified
# /// script
# dependencies = ["trl>=0.12.0", "transformers>=4.36.0", "accelerate>=0.24.0", "trackio"]
# ///
from datasets import load_dataset
from transformers import AutoTokenizer, T5ForConditionalGeneration
from trl import SFTTrainer, SFTConfig
# Fetch the "train" split of the container-status dataset from the Hub.
dataset = load_dataset(
    path="mindchain/container-status-de",
    split="train",
)

# Hold out 15% of the rows for evaluation; the fixed seed keeps the
# partition reproducible between runs.
split = dataset.train_test_split(
    seed=42,
    test_size=0.15,
)
def fmt(ex) -> dict:
    """Build the training record for a single dataset row.

    Args:
        ex: Mapping that provides a ``"text"`` and a ``"label"`` entry.

    Returns:
        A new dict whose ``"text"`` is the row's text prefixed with
        ``"Status: "`` and whose ``"label"`` is passed through unchanged.
        Any other keys of ``ex`` are not carried over.

    Raises:
        KeyError: If ``ex`` lacks ``"text"`` or ``"label"``.
    """
    return {"text": f"Status: {ex['text']}", "label": ex["label"]}
# Apply `fmt` to both partitions, dropping each partition's original columns
# so that only the fields produced by `fmt` remain.
train_ds, eval_ds = (
    split[part].map(fmt, remove_columns=split[part].column_names)
    for part in ("train", "test")
)
# Tokenizer and model weights are loaded from the same Hub checkpoint.
# NOTE(review): the checkpoint id reads like a T5Gemma model while the class
# used below is the classic T5 one — confirm the architectures actually match.
_CHECKPOINT = "google/t5gemma-2-270m"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT)
# Training configuration for the SFT run.
config = SFTConfig(
    # Where results go: local directory plus the target Hub repository.
    output_dir="out",
    push_to_hub=True,
    hub_model_id="mindchain/t5gemma-270m-container-status",
    # Optimisation schedule. 2 samples per device x 4 accumulation steps
    # gives 8 samples per device for every optimizer update.
    learning_rate=3e-4,
    num_train_epochs=5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    # Sequence length limit passed to the trainer.
    max_length=256,
    # Metrics are logged every 5 steps and reported to trackio.
    logging_steps=5,
    report_to="trackio",
)
# Wire the model, tokenizer, datasets and config into the TRL trainer.
# The tokenizer is passed as `processing_class`: TRL 0.12 renamed the old
# `tokenizer` keyword to this, and newer releases no longer accept the old
# name, so this spelling works across the declared `trl>=0.12.0` range.
# NOTE(review): the config sets no eval_strategy, so eval_ds is presumably
# never evaluated during training — confirm whether that is intended.
trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    processing_class=tokenizer,
)

# Run the fine-tuning loop, then upload the result to the Hub repository
# named in the config.
trainer.train()
trainer.push_to_hub()
print('DONE')