# Zenyx-Base-220M: High-Density Nano Foundation Model
Zenyx-Base-220M is a 220-million-parameter causal language model built from scratch with JAX/Flax on a Kaggle TPU v5e-8.
Unlike typical small models trained on limited data, Zenyx-Base was trained on ~153 billion tokens, far exceeding the Chinchilla-optimal budget for this parameter count. This "over-training" strategy was employed to maximize the information density and logic capabilities of the weights, creating a robust foundation for reasoning tasks.
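To put "far exceeding" in perspective, the quick arithmetic below compares the training budget with the ~20-tokens-per-parameter Chinchilla rule of thumb (illustrative only; the exact optimum depends on the scaling-law fit used):

```python
# Rough scale comparison (illustrative arithmetic, not from the training logs).
params = 220e6
tokens = 153e9

chinchilla_budget = 20 * params        # ~4.4B tokens for a 220M model
tokens_per_param = tokens / params     # ~695 tokens per parameter

print(f"Chinchilla-optimal budget: ~{chinchilla_budget / 1e9:.1f}B tokens")
print(f"Actual budget: {tokens / 1e9:.0f}B tokens (~{tokens_per_param:.0f} tokens/param, ~35x over)")
```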
## Model Description
- Architecture: Custom Llama-style Transformer (RoPE, SwiGLU, RMSNorm, Grouped Query Attention); a SwiGLU sketch follows this list.
- Tokenizer: Qwen 2.5 Tokenizer (151,646 vocab size) for high compression efficiency.
- Context Window: 2048 Tokens.
- Training Hardware: TPU v5e-8.
- Final Validation Loss: ~2.38 (Exceptional convergence for 220M).
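As a concrete illustration of one of the blocks listed above, here is a minimal Flax sketch of a SwiGLU feed-forward layer at the listed dimensions (768 hidden, 3072 MLP). The class name `SwiGLUMLP` and its exact wiring are a generic textbook formulation, not the repository's actual module code:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class SwiGLUMLP(nn.Module):
    """Generic SwiGLU feed-forward block: down(silu(gate(x)) * up(x))."""
    hidden_dim: int = 768
    mlp_dim: int = 3072

    @nn.compact
    def __call__(self, x):
        gate = nn.Dense(self.mlp_dim, use_bias=False)(x)  # gating projection
        up = nn.Dense(self.mlp_dim, use_bias=False)(x)    # linear projection
        return nn.Dense(self.hidden_dim, use_bias=False)(nn.silu(gate) * up)

x = jnp.ones((1, 8, 768))                      # (batch, seq_len, hidden)
block = SwiGLUMLP()
y = block.apply(block.init(jax.random.PRNGKey(0), x), x)
print(y.shape)                                 # (1, 8, 768)
```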
## Technical Specifications
| Hyperparameter | Value |
|---|---|
| Layers | 12 |
| Hidden Dim | 768 |
| MLP Dim | 3072 |
| Attention Heads | 12 |
| KV Heads | 4 (GQA) |
| Vocab Size | 151,646 |
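To sanity-check the headline 220M figure against these hyperparameters, the back-of-the-envelope calculation below reproduces it. It assumes tied input/output embeddings, a SwiGLU MLP with separate gate/up/down projections, and a head dimension of 64; those assumptions are mine and are not stated in the table:

```python
# Back-of-the-envelope parameter count from the table above.
# Assumes tied embeddings, SwiGLU (gate + up + down), and GQA with head_dim = 64.
vocab, d_model, mlp_dim, layers = 151_646, 768, 3072, 12
n_heads, n_kv_heads = 12, 4
head_dim = d_model // n_heads                      # 64

embeddings = vocab * d_model                       # ~116.5M (tied with the LM head)
attn = d_model * d_model                           # Q projection
attn += 2 * d_model * (n_kv_heads * head_dim)      # K and V (GQA: 4 heads of 64)
attn += d_model * d_model                          # output projection
mlp = 3 * d_model * mlp_dim                        # gate, up, down projections
total = embeddings + layers * (attn + mlp)
print(f"~{total / 1e6:.0f}M parameters")           # ~220M
```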
## Training Curriculum (The "Omni-Mix")
The model was trained using a rigorous four-stage curriculum designed to layer capabilities sequentially:
- Phase 1: Fundamentals (FineWeb-Edu)
  - Focus on high-quality educational English text to establish linguistic baselines.
- Phase 2: Logic & Structure (StarCoder - Python)
  - Introduction of code data to enforce logical indentation, syntax, and structured thinking.
- Phase 3: Multilingualism (FineWeb-2)
  - Exposure to six major languages (Hindi, Chinese, Russian, Japanese, French, Spanish) to expand the semantic embedding space.
- Phase 4: The Infinite Polish (Omni-Mix)
  - A weighted interleaving of all previous datasets plus OpenWebMath to converge the model's logic and language capabilities (a sketch of such an interleaving follows this list).
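The mixture weights for the Omni-Mix phase are not listed here, so the snippet below is only a hypothetical sketch of how a weighted interleaving of streaming datasets can be assembled with the `datasets` library; the dataset IDs, column handling, and probabilities are illustrative assumptions, not the actual training configuration:

```python
from datasets import load_dataset, interleave_datasets

# Hypothetical sketch only: dataset IDs, weights, and column handling are
# illustrative assumptions, not the actual Omni-Mix configuration.
web = load_dataset("HuggingFaceFW/fineweb-edu", split="train", streaming=True)
math = load_dataset("open-web-math/open-web-math", split="train", streaming=True)

# Keep a single shared column so the sources interleave cleanly;
# the code and multilingual sources would be added the same way.
web, math = web.select_columns(["text"]), math.select_columns(["text"])

omni_mix = interleave_datasets(
    [web, math],
    probabilities=[0.8, 0.2],            # hypothetical mixture weights
    seed=0,
    stopping_strategy="all_exhausted",   # keep sampling until every source is used up
)
print(next(iter(omni_mix))["text"][:200])
```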
## Usage
This model is a raw JAX/Flax checkpoint saved in `.safetensors` format. It uses a custom architecture definition and requires `jax`, `flax`, `safetensors`, and `transformers` (for the tokenizer) to run.
### Loading with JAX/Flax
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from huggingface_hub import hf_hub_download
from safetensors.flax import load_file
from transformers import AutoTokenizer

# 1. Define the architecture (must match the training config)
class TransformerLM(nn.Module):
    vocab_size: int
    embed_dim: int = 768
    num_layers: int = 12
    num_heads: int = 12
    num_kv_heads: int = 4
    mlp_dim: int = 3072
    max_length: int = 2048
    dropout_rate: float = 0.0
    # ... (Insert full model class definition here from the training script) ...

# 2. Load resources
repo_id = "Arko007/Zenyx_Base_220M"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=True)

# 3. Initialize & load weights
model = TransformerLM(vocab_size=len(tokenizer))
dummy_input = jnp.ones((1, 1), dtype=jnp.int32)
params = model.init(jax.random.PRNGKey(0), dummy_input)["params"]

# Download and load the safetensors checkpoint
ckpt_path = hf_hub_download(repo_id, filename="model.safetensors")
loaded_params = load_file(ckpt_path)
# Note: load_file returns a flat {name: array} dict; depending on how the
# checkpoint was saved, it may need to be restructured (e.g. with
# flax.traverse_util.unflatten_dict) into the nested params tree before use.
print("Weights loaded successfully!")
```
## Limitations
- Size: At 220M parameters, the model's knowledge retrieval capacity is limited compared to 7B+ models.
- Base Model: This is a pre-trained base. It has not been fine-tuned for chat or instruction following (see Zenyx-DeepSeek-220M for the instruct version).
- Hallucinations: While logically consistent, it may generate factually incorrect statements.
## Citation
```bibtex
@misc{ZenyxBase220M,
  title     = {Zenyx-Base-220M: High-Density Foundation Model},
  author    = {Arko007},
  year      = {2025},
  publisher = {HuggingFace},
  url       = {https://huggingface.co/Arko007/Zenyx_Base_220M}
}
```