"""
Advanced Training Script for ULTRATHINK Model
Supports configuration files, all advanced features, and production deployment
"""
import os
import sys
import yaml
import argparse
from pathlib import Path
import logging
from typing import Dict, Any
# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
logger = logging.getLogger(__name__)
def load_config(config_path: str) -> Dict[str, Any]:
"""Load configuration from YAML file"""
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
return config
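

# A minimal, illustrative sketch of the YAML layout this script expects. The keys
# mirror the sections read in config_to_args below; anything omitted falls back to
# the defaults defined there.
#
#   model:
#     hidden_size: 1024
#     num_layers: 16
#   training:
#     batch_size: 8
#     learning_rate: 3.0e-5
#   output:
#     output_dir: ./outputs/ultrathink
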
def merge_configs(base_config: Dict, overrides: Dict) -> Dict:
"""Merge override config into base config"""
result = base_config.copy()
for key, value in overrides.items():
if isinstance(value, dict) and key in result and isinstance(result[key], dict):
result[key] = merge_configs(result[key], value)
else:
result[key] = value
return result
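
# Example of the recursive merge above: nested dicts are merged key by key, while any
# other value from the override replaces the base value wholesale, e.g.
#   merge_configs({'model': {'num_layers': 32, 'dropout': 0.0}}, {'model': {'dropout': 0.1}})
#   -> {'model': {'num_layers': 32, 'dropout': 0.1}}
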
def config_to_args(config: Dict) -> argparse.Namespace:
"""Convert config dict to argparse Namespace for compatibility"""
args = argparse.Namespace()
    # Helpers to coerce config values to the expected types
def to_int(val, default):
return int(val) if val is not None else default
def to_float(val, default):
return float(val) if val is not None else default
def to_bool(val, default):
if val is None:
return default
if isinstance(val, bool):
return val
return str(val).lower() in ('true', '1', 'yes')
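    # Examples of the coercion rules above:
    #   to_int('32', 1) -> 32           to_float('3e-5', 0.0) -> 3e-05
    #   to_bool('yes', False) -> True   to_bool(0, True) -> False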
# Model config
model = config.get('model', {})
args.vocab_size = to_int(model.get('vocab_size'), 100352)
args.hidden_size = to_int(model.get('hidden_size'), 4096)
args.num_layers = to_int(model.get('num_layers'), 32)
args.num_heads = to_int(model.get('num_heads'), 32)
args.num_kv_heads = to_int(model.get('num_kv_heads'), 8)
args.intermediate_size = to_int(model.get('intermediate_size'), 14336)
args.max_seq_length = to_int(model.get('max_seq_length'), 8192)
args.activation = model.get('activation', 'swiglu')
args.dropout = to_float(model.get('dropout'), 0.0)
args.attention_dropout = to_float(model.get('attention_dropout'), 0.0)
args.use_flash_attention = to_bool(model.get('use_flash_attention'), False)
args.gradient_checkpointing = to_bool(model.get('gradient_checkpointing'), False)
# Advanced features
advanced = config.get('advanced', {})
args.enable_moe = to_bool(advanced.get('enable_moe'), False)
args.enable_dre = to_bool(advanced.get('enable_dre'), False)
args.enable_constitutional = to_bool(advanced.get('enable_constitutional'), False)
args.enable_rlhf = to_bool(advanced.get('enable_rlhf'), False)
args.enable_multimodal = to_bool(advanced.get('enable_multimodal'), False)
args.dre_warmup_steps = to_int(advanced.get('dre_warmup_steps'), 0)
# MoE config
moe = config.get('moe', {})
args.num_knowledge_experts = to_int(moe.get('num_knowledge_experts'), 64)
args.num_skill_experts = to_int(moe.get('num_skill_experts'), 32)
args.num_meta_experts = to_int(moe.get('num_meta_experts'), 16)
args.num_safety_experts = to_int(moe.get('num_safety_experts'), 8)
args.moe_top_k = to_int(moe.get('moe_top_k'), 2)
args.expert_capacity = to_float(moe.get('expert_capacity'), 1.25)
# Multimodal config
multimodal = config.get('multimodal', {})
args.image_size = to_int(multimodal.get('image_size'), 224)
args.patch_size = to_int(multimodal.get('patch_size'), 14)
args.audio_sample_rate = to_int(multimodal.get('audio_sample_rate'), 16000)
# Training config
training = config.get('training', {})
args.batch_size = to_int(training.get('batch_size'), 32)
args.gradient_accumulation_steps = to_int(training.get('gradient_accumulation_steps'), 4)
args.learning_rate = to_float(training.get('learning_rate'), 3e-5)
args.weight_decay = to_float(training.get('weight_decay'), 0.01)
args.adam_beta1 = to_float(training.get('adam_beta1'), 0.9)
args.adam_beta2 = to_float(training.get('adam_beta2'), 0.999)
args.warmup_steps = to_int(training.get('warmup_steps'), 10000)
args.max_steps = to_int(training.get('max_steps'), 1000000)
args.num_epochs = to_int(training.get('num_epochs'), 3)
args.gradient_clipping = to_float(training.get('gradient_clipping'), 1.0)
args.use_amp = to_bool(training.get('use_amp'), False)
# Distributed config
distributed = config.get('distributed', {})
args.distributed = to_bool(distributed.get('enabled'), False)
args.use_4d_parallelism = to_bool(distributed.get('use_4d_parallelism'), False)
args.data_parallel_size = to_int(distributed.get('data_parallel_size'), 1)
args.tensor_parallel_size = to_int(distributed.get('tensor_parallel_size'), 1)
args.pipeline_parallel_size = to_int(distributed.get('pipeline_parallel_size'), 1)
args.expert_parallel_size = to_int(distributed.get('expert_parallel_size'), 1)
args.zero_stage = to_int(distributed.get('zero_stage'), 0)
args.deepspeed = distributed.get('deepspeed_config', None)
args.launcher = distributed.get('launcher', 'none')
# Data config
data = config.get('data', {})
args.dataset = data.get('dataset', 'wikitext')
args.mix_datasets = data.get('mix_datasets', None)
args.dataset_subset = data.get('dataset_subset', None)
args.data_path = data.get('data_path', None)
args.text_column = data.get('text_column', 'text')
args.tokenizer_name = data.get('tokenizer_name', 'gpt2')
    args.max_samples = to_int(data.get('max_samples'), None)
args.train_samples = to_int(data.get('train_samples'), 10000)
args.val_samples = to_int(data.get('val_samples'), 1000)
args.num_workers = to_int(data.get('num_workers'), 4)
args.streaming = to_bool(data.get('streaming'), False)
args.use_synthetic_data = to_bool(data.get('use_synthetic_data'), False)
args.synthetic_samples = to_int(data.get('synthetic_samples'), 5000)
# RLHF config
rlhf = config.get('rlhf', {})
args.rlhf_frequency = to_int(rlhf.get('rlhf_frequency'), 5)
args.rlhf_iterations = to_int(rlhf.get('rlhf_iterations'), 100)
args.rlhf_steps_per_iteration = to_int(rlhf.get('rlhf_steps_per_iteration'), 1000)
args.ppo_epochs = to_int(rlhf.get('ppo_epochs'), 4)
args.ppo_batch_size = to_int(rlhf.get('ppo_batch_size'), 32)
# Evaluation config
evaluation = config.get('evaluation', {})
args.eval_frequency = to_int(evaluation.get('eval_frequency'), 5)
# Logging config
logging_cfg = config.get('logging', {})
args.use_mlflow = to_bool(logging_cfg.get('use_mlflow'), False)
args.mlflow_tracking_uri = logging_cfg.get('mlflow_tracking_uri', 'file:./mlruns')
args.mlflow_experiment = logging_cfg.get('mlflow_experiment', 'UltraThinking-LLM-Training')
args.run_name = logging_cfg.get('run_name', 'ultrathink_training')
args.use_wandb = False # Deprecated
# Output config
output = config.get('output', {})
args.output_dir = output.get('output_dir', './outputs/ultrathink')
# Resume/init
args.init_from_model_dir = None
args.resume_checkpoint = None
args.continuous = False
return args
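
# Example: every field falls back to its default above when the corresponding key is
# missing from the config, e.g.
#   config_to_args({}).hidden_size -> 4096
#   config_to_args({'model': {'hidden_size': 512}}).hidden_size -> 512
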
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description='Advanced Training Script for ULTRATHINK')
parser.add_argument(
'--config',
type=str,
required=True,
help='Path to YAML configuration file (e.g., configs/train_small.yaml)'
)
parser.add_argument(
'--override',
type=str,
nargs='*',
help='Override config values (e.g., training.batch_size=16 model.hidden_size=512)'
)
parser.add_argument(
'--resume',
type=str,
default=None,
help='Path to checkpoint to resume from'
)
parser.add_argument(
'--init-from',
type=str,
default=None,
help='Path to pretrained model directory to initialize from'
)
parser.add_argument(
'--continuous',
action='store_true',
help='Train continuously until interrupted'
)
parser.add_argument(
'--run-name',
type=str,
default=None,
help='Override run name from config'
)
return parser.parse_args()
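
# Example invocation (config path and override values are illustrative):
#   python train_advanced.py --config configs/train_small.yaml \
#       --override training.batch_size=16 model.hidden_size=512 \
#       --run-name my_small_run
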
def apply_overrides(config: Dict, overrides: list) -> Dict:
"""Apply command-line overrides to config"""
if not overrides:
return config
for override in overrides:
if '=' not in override:
logger.warning(f"Invalid override format: {override}. Use key=value")
continue
key_path, value = override.split('=', 1)
keys = key_path.split('.')
# Try to convert value to appropriate type
try:
if value.lower() == 'true':
value = True
elif value.lower() == 'false':
value = False
            elif value.lower() in ('null', 'none'):
value = None
elif '.' in value:
value = float(value)
else:
try:
value = int(value)
except ValueError:
pass # Keep as string
except Exception:
pass # Keep as string
# Apply override
current = config
for key in keys[:-1]:
if key not in current:
current[key] = {}
current = current[key]
current[keys[-1]] = value
return config
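
# Example: dotted keys index into nested sections and values are coerced where possible, e.g.
#   apply_overrides({'training': {'batch_size': 32}}, ['training.batch_size=16', 'training.use_amp=true'])
#   -> {'training': {'batch_size': 16, 'use_amp': True}}
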
def main():
"""Main training function"""
# Parse arguments
args = parse_args()
# Load base configuration
logger.info(f"Loading configuration from {args.config}")
config = load_config(args.config)
# Apply overrides
if args.override:
logger.info(f"Applying overrides: {args.override}")
config = apply_overrides(config, args.override)
# Convert to argparse Namespace
train_args = config_to_args(config)
# Apply resume/init flags
if args.resume:
train_args.resume_checkpoint = args.resume
if args.init_from:
train_args.init_from_model_dir = args.init_from
if args.continuous:
train_args.continuous = True
if args.run_name:
train_args.run_name = args.run_name
# Create output directory
Path(train_args.output_dir).mkdir(parents=True, exist_ok=True)
# Setup logging
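    # Note: handlers are configured only here, once the output directory exists; logger
    # calls made earlier in main() run before this configuration takes effect.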
log_file = os.path.join(train_args.output_dir, 'training.log')
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler()
]
)
# Save effective configuration
config_save_path = os.path.join(train_args.output_dir, 'effective_config.yaml')
with open(config_save_path, 'w') as f:
yaml.dump(config, f, default_flow_style=False)
logger.info(f"Saved effective configuration to {config_save_path}")
# Log configuration
logger.info("=" * 80)
logger.info("ULTRATHINK Advanced Training")
logger.info("=" * 80)
logger.info(f"Configuration: {args.config}")
logger.info(f"Output directory: {train_args.output_dir}")
logger.info(f"Run name: {train_args.run_name}")
logger.info("")
logger.info("Model Configuration:")
logger.info(f" Hidden size: {train_args.hidden_size}")
logger.info(f" Layers: {train_args.num_layers}")
logger.info(f" Heads: {train_args.num_heads}")
logger.info(f" Sequence length: {train_args.max_seq_length}")
logger.info("")
logger.info("Advanced Features:")
logger.info(f" MoE: {train_args.enable_moe}")
logger.info(f" DRE: {train_args.enable_dre}")
logger.info(f" Constitutional AI: {train_args.enable_constitutional}")
logger.info(f" RLHF: {train_args.enable_rlhf}")
logger.info(f" Multimodal: {train_args.enable_multimodal}")
logger.info("")
logger.info("Training Configuration:")
logger.info(f" Batch size: {train_args.batch_size}")
logger.info(f" Gradient accumulation: {train_args.gradient_accumulation_steps}")
logger.info(f" Learning rate: {train_args.learning_rate}")
logger.info(f" Epochs: {train_args.num_epochs}")
logger.info("=" * 80)
# Import and run training
from train_ultrathink import UltraThinkTrainer
import mlflow
# Create trainer
trainer = UltraThinkTrainer(train_args)
# Start MLflow run if enabled
active_mlflow = False
if train_args.use_mlflow and trainer.is_main_process():
try:
mlflow.set_tracking_uri(train_args.mlflow_tracking_uri)
mlflow.set_experiment(train_args.mlflow_experiment)
mlflow.start_run(run_name=train_args.run_name)
# Log configuration as params
safe_params = {
k: (str(v) if not isinstance(v, (str, int, float, bool, type(None))) else v)
for k, v in vars(train_args).items()
}
mlflow.log_params(safe_params)
# Log config file as artifact
mlflow.log_artifact(args.config, artifact_path='config')
if os.path.exists(config_save_path):
mlflow.log_artifact(config_save_path, artifact_path='config')
active_mlflow = True
logger.info(f"MLflow tracking enabled: {train_args.mlflow_tracking_uri}")
except Exception as e:
logger.warning(f"Failed to start MLflow run: {e}")
active_mlflow = False
# Run training
try:
results = trainer.train()
# Log final results
logger.info("=" * 80)
logger.info("Training completed successfully!")
logger.info(f"Final results: {results}")
logger.info("=" * 80)
except KeyboardInterrupt:
logger.info("Training interrupted by user")
results = {'status': 'interrupted'}
except Exception as e:
logger.error(f"Training failed with error: {e}", exc_info=True)
results = {'status': 'failed', 'error': str(e)}
raise
finally:
# Cleanup MLflow
if train_args.use_mlflow and trainer.is_main_process() and active_mlflow:
try:
results_path = os.path.join(train_args.output_dir, 'evaluation_results.json')
if os.path.exists(results_path):
mlflow.log_artifact(results_path, artifact_path='evaluation')
finally:
try:
mlflow.end_run()
except Exception:
pass
if __name__ == "__main__":
main()