"""Pytest command-line options, hooks and session-scoped fixtures."""

import logging

import pytest
import torch
import torch.distributed as dist
from packaging import version
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Seed applied to the CPU and CUDA generators before building any model so
# that weights and random gradients are reproducible across runs.
SEED = 0xdeadbeef


def pytest_addoption(parser):
    """Register the custom command-line flags.

    Args:
        parser: The pytest option parser the flags are added to.
    """
    parser.addoption(
        "--measure-perf",
        action="store_true",
        default=False,
        help="Measure execution time and peak memory usage during "
        "optimizer step.",
    )
    parser.addoption(
        "--do-profile",
        action="store_true",
        default=False,
        help="Enable profiling during tests.",
    )
    parser.addoption(
        "--skip-verify",
        action="store_true",
        default=False,
        help="Skip verification of optimizer step correctness with "
        "sequential implementation.\n"
        "This can be useful when GPU memory is limited.",
    )


def pytest_configure(config):
    """Reject flag combinations that cannot work together.

    Args:
        config: The pytest config object holding the parsed options.

    Raises:
        pytest.UsageError: If ``--do-profile`` is given without
            ``--measure-perf``.
    """
    if config.getoption(
            "--do-profile") and not config.getoption("--measure-perf"):
        raise pytest.UsageError(
            "--do-profile requires --measure-perf. Please enable both flags.")


@pytest.fixture(scope="session")
def measure_perf(request):
    """Return the value of the ``--measure-perf`` flag."""
    return request.config.getoption("--measure-perf")


@pytest.fixture(scope="session")
def do_profile(request):
    """Return the value of the ``--do-profile`` flag."""
    return request.config.getoption("--do-profile")


@pytest.fixture(scope="session")
def skip_verify(request):
    """Return the value of the ``--skip-verify`` flag."""
    return request.config.getoption("--skip-verify")


@pytest.fixture(scope="session", autouse=True)
def init_dist(request):
    """Initialize the NCCL process group for the whole test session.

    The session is skipped when torch is older than 2.8, when the process
    group (or the CUDA device selection) cannot be set up, or when the
    group does not contain exactly 8 processes. Whenever a process group
    was created it is destroyed again, both on normal teardown and when
    the fixture skips after initialization.
    """
    # pytest.skip() raises, so no explicit return is needed after it.
    if version.parse(torch.__version__) < version.parse("2.8"):
        pytest.skip("torch>=2.8.0 is required for parallel muon")

    try:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        print(f"Failed to initialize torch.distributed: {e}")
        # The group may already exist if only the device selection failed;
        # tear it down so the skip does not leak it.
        if dist.is_initialized():
            dist.destroy_process_group()
        pytest.skip("Failed to initialize torch.distributed")

    # From here on a process group exists; the finally clause guarantees it
    # is destroyed on the world-size skip as well as on regular teardown.
    try:
        if dist.get_world_size() != 8:
            pytest.skip("Need 8 processes in dist group. "
                        "You can run with `torchrun --nproc-per-node=8 "
                        "--local-ranks-filter 0 -m pytest "
                        "test_rms_norm_sequence_parallel.py`. "
                        "To run with less than 8 gpus, modify "
                        "the test cases accordingly.")
        yield
    finally:
        dist.destroy_process_group()


@pytest.fixture(scope="session")
def inputs():
    """Load the test model and generate random gradients and QK logits.

    The model is loaded from the
    ``Motif-Technologies/Motif-2.6B-4layer-random`` checkpoint after
    seeding the torch generators with ``SEED``.

    Returns:
        list: A three-element list ``[model, grads, qk_logits]``:
            - model (torch.nn.Module): The loaded causal LM.
            - grads (list[torch.Tensor]): One random tensor per model
              parameter, matching that parameter's shape, dtype and
              device.
            - qk_logits (dict[int, torch.Tensor]): Maps each hidden-layer
              index to a random bfloat16 tensor with
              ``num_attention_heads`` elements.
    """
    model_name = "Motif-Technologies/Motif-2.6B-4layer-random"

    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    # NOTE(review): trust_remote_code=True lets the model repository's own
    # Python code run locally; keep the model name pinned to a trusted repo.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    # The count is the number of parameter tensors, not of scalar weights.
    logger.info("Loaded model %s. (%d parameters)", model_name,
                len(list(model.parameters())))

    grads: list[torch.Tensor] = [
        torch.randn_like(param, device=param.device, dtype=param.dtype)
        for param in model.parameters()
    ]

    qk_logits: dict[int, torch.Tensor] = {
        i:
        torch.randn(model.config.num_attention_heads,
                    device=model.device,
                    dtype=torch.bfloat16)
        for i in range(model.config.num_hidden_layers)
    }

    return [model, grads, qk_logits]


def _create_moe_model(num_experts=8, top_k=2, n_layers=4):
    """Create a torchtitan Llama4 MoE model with random gradients.

    Args:
        num_experts: Number of routed experts passed to ``MoEArgs``.
        top_k: ``top_k`` value passed to ``MoEArgs``.
        n_layers: Number of transformer layers.

    Returns:
        list: A two-element list ``[model, grads]`` where ``grads`` holds
        one random tensor per model parameter, matching that parameter's
        shape, dtype and device.
    """
    # Imported lazily so torchtitan is only required when an MoE fixture
    # is actually requested.
    from torchtitan.models.llama4.model.args import TransformerModelArgs
    from torchtitan.models.llama4.model.model import Transformer
    from torchtitan.models.moe import MoEArgs

    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    moe_args = MoEArgs(
        num_experts=num_experts,
        num_shared_experts=1,
        top_k=top_k,
        score_func="sigmoid",
    )
    model_args = TransformerModelArgs(
        dim=2048,
        n_layers=n_layers,
        n_heads=16,
        n_kv_heads=8,
        vocab_size=32000,
        norm_eps=1e-5,
        rope_theta=10000,
        max_seq_len=4096,
        moe_args=moe_args,
        interleave_moe_layer_step=1,
    )

    model = Transformer(model_args)
    model.init_weights()

    # The count is the number of parameter tensors, not of scalar weights.
    logger.info(
        "Created torchtitan Llama4 MoE model "
        "(num_experts=%s, n_layers=%s, %d parameters)", num_experts,
        n_layers, len(list(model.parameters())))

    grads = [
        torch.randn_like(param, device=param.device, dtype=param.dtype)
        for param in model.parameters()
    ]

    return [model, grads]


@pytest.fixture(scope="session")
def moe_inputs():
    """MoE model with 8 experts (standard config)."""
    return _create_moe_model(num_experts=8, top_k=2)


@pytest.fixture(scope="session")
def moe_inputs_few_experts():
    """MoE model with 2 experts (triggers EFSDP Shard(1) mode)."""
    return _create_moe_model(num_experts=2, top_k=1)