# /// script
# dependencies = [
#     "datasets",
#     "transformers>=4.46.0",
#     "accelerate>=0.24.0",
#     "trackio",
# ]
# ///
"""Fine-tune ``google/t5gemma-2-270m`` on ``mindchain/container-status-de``.

The classification task is framed as sequence-to-sequence generation: the
encoder reads ``"Status: <text>"`` and the decoder is trained to emit the
example's label as text. The trained model and tokenizer are pushed to
``mindchain/t5gemma-270m-container-status`` on the Hugging Face Hub.
"""
from datasets import ClassLabel, load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

MODEL_ID = "google/t5gemma-2-270m"
DATASET_ID = "mindchain/container-status-de"
HUB_MODEL_ID = "mindchain/t5gemma-270m-container-status"
MAX_SOURCE_LENGTH = 256


def preprocess(batch, tokenizer, label_to_text):
    """Tokenize a batch of examples into encoder inputs and decoder labels.

    Args:
        batch: Batched example dict with ``"text"`` and ``"label"`` columns.
        tokenizer: Tokenizer of the model being fine-tuned.
        label_to_text: Converts one raw label value to its target string.

    Returns:
        The tokenizer output for the prompts, plus a ``"labels"`` entry holding
        the token ids of each target string.
    """
    prompts = [f"Status: {text}" for text in batch["text"]]
    targets = [label_to_text(label) for label in batch["label"]]
    model_inputs = tokenizer(prompts, max_length=MAX_SOURCE_LENGTH, truncation=True)
    # Targets are short label strings and are deliberately not truncated:
    # a cut-off label would silently teach the model a wrong answer.
    model_inputs["labels"] = tokenizer(text_target=targets)["input_ids"]
    return model_inputs


def main():
    """Load data and model, fine-tune with per-epoch eval, and push to the Hub."""
    dataset = load_dataset(DATASET_ID, split="train")
    split = dataset.train_test_split(test_size=0.15, seed=42)

    # ClassLabel integers are turned into their names so the decoder learns
    # meaningful words rather than bare indices; other label types go via str().
    label_feature = dataset.features["label"]
    if isinstance(label_feature, ClassLabel):
        label_to_text = label_feature.int2str
    else:
        label_to_text = str

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # The Auto class lets the checkpoint's config select its model class;
    # T5ForConditionalGeneration is the original T5 architecture, not T5Gemma.
    # NOTE(review): T5Gemma 2 needs a recent transformers release; the pin
    # above only guarantees `processing_class` / `eval_strategy` — confirm
    # the minimum version that ships T5Gemma 2 support.
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

    tokenized = split.map(
        preprocess,
        batched=True,
        fn_kwargs={"tokenizer": tokenizer, "label_to_text": label_to_text},
        remove_columns=dataset.column_names,
    )

    args = Seq2SeqTrainingArguments(
        output_dir="out",
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=3e-4,
        logging_steps=5,
        # Without this the held-out split is passed in but never evaluated.
        eval_strategy="epoch",
        report_to="trackio",
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["test"],
        # Pads inputs, and pads labels with -100 so padding is ignored by the loss.
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
        processing_class=tokenizer,
    )
    trainer.train()
    trainer.push_to_hub()
    print('DONE')


if __name__ == "__main__":
    main()