# GLEN-model / test_model_loading.py
# Author: QuanTH02 -- commit 6534252 ("Commit 15-06-v1")
#!/usr/bin/env python3
"""Smoke test: load a GLEN training checkpoint on the CPU and report its keys.

Run from the repository root: the default checkpoint path and the ``src``
entry appended to ``sys.path`` are both relative to the current working
directory. An optional first command-line argument overrides the checkpoint
path. Exits with status 0 when the checkpoint loads and 1 otherwise.
"""
import json
import os
import struct
import sys
import traceback

# Make ``src/`` importable (CWD-relative, like the checkpoint path below).
# Nothing in this script imports from it directly.
sys.path.append('src')

# Status markers, spelled as escapes so this source stays ASCII-only; the
# literal emoji previously used here had been mis-decoded into mojibake.
OK = "\N{WHITE HEAVY CHECK MARK}"
FAIL = "\N{CROSS MARK}"

# Checkpoint exercised when no path is given on the command line.
CKPT_PATH = "logs/test_glen_vault/GLEN_P2_test/checkpoint-7/model.safetensors"

# safetensors dtype tags -> attribute names of the matching ``torch`` dtypes.
_SAFETENSORS_DTYPES = {
    "F64": "float64",
    "F32": "float32",
    "F16": "float16",
    "BF16": "bfloat16",
    "I64": "int64",
    "I32": "int32",
    "I16": "int16",
    "I8": "int8",
    "U8": "uint8",
    "BOOL": "bool",
}


def _read_safetensors(path):
    """Read a ``.safetensors`` file into a ``{name: tensor}`` dict on the CPU.

    Minimal reader for the safetensors layout: an 8-byte little-endian
    unsigned header length, a UTF-8 JSON header mapping each tensor name to
    ``{"dtype", "shape", "data_offsets"}`` (byte offsets into the data
    section that follows the header), then the raw tensor bytes. The
    optional ``"__metadata__"`` header entry is skipped. Tensor bytes are
    interpreted in native byte order, so this assumes a little-endian host.

    Raises:
        ValueError: if the file is truncated, an entry has invalid offsets,
            or a tensor uses a dtype not listed in ``_SAFETENSORS_DTYPES``.
    """
    import torch

    with open(path, "rb") as f:
        prefix = f.read(8)
        if len(prefix) != 8:
            raise ValueError(f"{path}: too short to be a safetensors file")
        (header_len,) = struct.unpack("<Q", prefix)
        header_bytes = f.read(header_len)
        if len(header_bytes) != header_len:
            raise ValueError(f"{path}: truncated safetensors header")
        header = json.loads(header_bytes)
        data_start = 8 + header_len

        tensors = {}
        for name, info in header.items():
            if name == "__metadata__":
                continue
            dtype_name = _SAFETENSORS_DTYPES.get(info.get("dtype"))
            if dtype_name is None:
                raise ValueError(
                    f"{path}: unsupported dtype {info.get('dtype')!r} "
                    f"for tensor {name!r}"
                )
            dtype = getattr(torch, dtype_name)
            shape = tuple(info["shape"])
            begin, end = info["data_offsets"]
            if not 0 <= begin <= end:
                raise ValueError(
                    f"{path}: invalid data_offsets {info['data_offsets']!r} "
                    f"for tensor {name!r}"
                )
            # Read one tensor at a time instead of slurping the whole file,
            # so no second full-size copy of the data is held.
            f.seek(data_start + begin)
            chunk = bytearray(f.read(end - begin))
            if len(chunk) != end - begin:
                raise ValueError(f"{path}: truncated data for tensor {name!r}")
            if chunk:
                # The tensor shares memory with ``chunk``; a bytearray is
                # writable, which torch.frombuffer expects.
                tensors[name] = torch.frombuffer(chunk, dtype=dtype).reshape(shape)
            else:
                # torch.frombuffer rejects empty buffers; a zero-element
                # tensor only needs its shape and dtype.
                tensors[name] = torch.empty(shape, dtype=dtype)
    return tensors


def load_checkpoint(path):
    """Load the checkpoint at *path* onto the CPU and return its contents.

    ``.safetensors`` files are parsed by ``_read_safetensors`` because
    ``torch.load`` cannot read that format; every other file goes through
    ``torch.load``.

    Returns:
        A ``{name: tensor}`` dict for safetensors files; for other files,
        whatever object was saved.
    """
    if os.fspath(path).endswith(".safetensors"):
        return _read_safetensors(path)
    import torch

    # SECURITY: weights_only=False unpickles arbitrary objects and can run
    # code embedded in the file -- only point this at trusted checkpoints.
    return torch.load(path, map_location="cpu", weights_only=False)


def main(ckpt_path=CKPT_PATH):
    """Run the smoke test against *ckpt_path* and return an exit status.

    Returns 0 when the checkpoint exists and loads, 1 otherwise. Failures are
    reported on stdout (plus a traceback on stderr for exceptions) instead of
    being raised.
    """
    print("Testing model loading...")
    try:
        import torch

        print(f"{OK} PyTorch version: {torch.__version__}")

        print(f"Checking checkpoint: {ckpt_path}")
        if not os.path.exists(ckpt_path):
            print(f"{FAIL} Checkpoint file not found")
            return 1
        print(f"{OK} Checkpoint file exists")

        print("Testing checkpoint loading...")
        state_dict = load_checkpoint(ckpt_path)
        print(f"{OK} Checkpoint loaded successfully! Keys: {len(state_dict)}")

        # Some checkpoints nest the weights under a "state_dict" key; unwrap
        # it when present. Guard the membership test so a non-dict payload
        # is not probed as if it were a container.
        if isinstance(state_dict, dict) and "state_dict" in state_dict:
            print(f"{OK} Found 'state_dict' key")
            state_dict = state_dict["state_dict"]
        print(f"Final state dict keys: {len(state_dict)}")
    except Exception as e:
        # Top-level boundary of a diagnostic script: report and fail the run.
        print(f"{FAIL} Error: {e}")
        traceback.print_exc()
        return 1
    return 0


if __name__ == "__main__":
    # Optional first argument overrides the default checkpoint path.
    sys.exit(main(sys.argv[1] if len(sys.argv) > 1 else CKPT_PATH))