import copy
import logging
import time
from contextlib import nullcontext

import pytest
import torch
import torch.distributed as dist
from optimizer.muon import Muon, get_default_muon_param_groups
from torch.distributed.tensor import DTensor, Replicate
from torch.profiler import ProfilerActivity, profile

from .utils import (ParallelDims, assert_params_equal, parallelize_motif,
                    parallelize_qk_logits)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def apply_muon_step(
    model: torch.nn.Module,
    parallel_dims: ParallelDims | None,
    grads: list[torch.Tensor],
    warmup_step: int,
    chunk_size: int,
    qk_logits: dict[int, torch.Tensor] | None = None,
    use_distributed_muon: bool = False,
    measure_perf: bool = False,
    do_profile: bool = False,
) -> tuple[torch.nn.Module, tuple[float, float] | None]:
    """Apply a single Muon step to ``model`` with optional QK clipping.

    Args:
        model: Model whose parameters are updated in place. Its ``config``
            must expose ``num_attention_heads`` and ``hidden_size`` (used to
            build the QK-clip configuration).
        parallel_dims: Parallel layout of ``model``. In this function it is
            only used to name the exported profiler trace; ``None`` for the
            sequential baseline.
        grads: One full (unsharded) gradient per parameter, in
            ``model.parameters()`` order.
        warmup_step: Forwarded to ``Muon(warmup_step=...)``.
        chunk_size: Forwarded to ``Muon(chunk_size=...)``.
        qk_logits: When given, QK clipping is enabled and these logits are
            passed to every ``optim.step`` call.
        use_distributed_muon: Forwarded to ``Muon(use_distributed_muon=...)``.
        measure_perf: If true, run one extra warm-up step and then time
            20 further steps.
        do_profile: If true (together with ``measure_perf``), record the
            timed steps with ``torch.profiler`` and export a Chrome trace on
            rank 0.

    Returns:
        ``(model, timing_result)``. ``timing_result`` is
        ``(avg_time_ms_per_step, peak_extra_memory_bytes)`` when
        ``measure_perf`` is set, otherwise ``None``.
    """
    # 1. Apply gradients to model parameters
    assert len(grads) == len(list(model.parameters()))
    for grad, param in zip(grads, model.parameters()):
        grad = grad.to(param.device)
        if isinstance(param.data, DTensor):
            # The supplied grad is the full tensor: wrap it as replicated on
            # the parameter's mesh, then redistribute it to the parameter's
            # own placements so grad and param are laid out identically.
            unsharded_grad = DTensor.from_local(
                grad,
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            sharded_grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements)
            param.grad = sharded_grad
        else:
            param.grad = grad

    # 2. Setup Muon optimizer
    params = get_default_muon_param_groups(model)
    num_heads = model.config.num_attention_heads
    clip_config = {
        "q_indices": list(range(num_heads)),
        "k_indices": list(range(num_heads)),
        "head_dim": model.config.hidden_size // num_heads,
        "threshold": 0.5,
    }
    optim = Muon(
        params=params,
        # QK clipping is only configured when logits are provided.
        clip_config=clip_config if qk_logits is not None else None,
        none_grad=False,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        use_distributed_muon=use_distributed_muon,
    )

    optim.step(qk_logits=qk_logits)

    timing_result: tuple[float, float] | None = None

    if measure_perf:
        # extra warm up
        optim.step(qk_logits=qk_logits)

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        num_iters = 20
        # max_memory_allocated() reports the peak since process start (or the
        # last reset). Reset it here so allocations from earlier steps/tests
        # in the same process do not leak into this measurement.
        torch.cuda.reset_peak_memory_stats()
        current_mem = torch.cuda.memory_allocated()

        if do_profile:
            context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True)
        else:
            context = nullcontext()

        with context as prof:
            # Record the start event only after the profiler has been entered
            # so its setup cost is not attributed to the optimizer steps.
            start.record()
            for _i in range(num_iters):
                optim.step(qk_logits=qk_logits)
            end.record()
            end.synchronize()

        # Export once the profiler context has exited. nullcontext() yields
        # None, so this is skipped when profiling is disabled.
        if prof is not None and dist.get_rank() == 0:
            date = time.strftime("%Y%m%d_%H%M%S", time.localtime())
            profile_name = "trace"
            profile_name += f"_{date}"
            profile_name += f"_{parallel_dims}"
            profile_name += f"_{chunk_size}"
            profile_name += f"_{warmup_step}"
            profile_name += f"_{qk_logits is not None}"
            profile_name += f"_{use_distributed_muon}"

            prof.export_chrome_trace(f"{profile_name}.json")

        # Extra memory used by the timed steps, relative to the baseline
        # sampled right after the reset above.
        peak_memory = torch.cuda.max_memory_allocated() - current_mem
        elapsed_time_ms = start.elapsed_time(end) / num_iters

        timing_result = (elapsed_time_ms, peak_memory)

    return model, timing_result


@pytest.fixture(scope="session")
def sequential_muon_result(
    skip_verify,  # from conftest.py
    inputs,  # from conftest.py
) -> dict[bool, torch.nn.Module] | None:
    """Run one Muon step on an unsharded copy of the model as a baseline.

    Returns:
        A dict keyed by whether QK clipping was applied, each value being the
        updated model moved to CPU; ``None`` when ``skip_verify`` is set.
    """
    if skip_verify:
        logger.info("Skipping verification tests as per user request")
        return None

    model, grads, qk_logits = inputs

    def _run_sequential(logits):
        # Fresh model copy per run so the two baselines are independent.
        return apply_muon_step(
            model=copy.deepcopy(model).cuda(),
            parallel_dims=None,
            grads=grads,
            warmup_step=-1,
            chunk_size=-1,
            qk_logits=logits,
        )[0].cpu()

    result = _run_sequential(None)
    # Deepcopy the qk logits (as test_parallel_muon does) so any in-place
    # modification does not alter the shared session-scoped inputs.
    result_qk_clip = _run_sequential(copy.deepcopy(qk_logits))

    return {
        False: result,
        True: result_qk_clip,
    }


OVERLAP_STEPS = [5]
CHUNK_SIZES = [8]


@pytest.mark.parametrize("parallel_dims", [
    pytest.param(ParallelDims(8, 1, 1), id="base"),
    pytest.param(ParallelDims(1, 8, 1), id="fsdp"),
    pytest.param(ParallelDims(2, 4, 1), id="hsdp"),
    pytest.param(ParallelDims(1, 1, 8), id="tp"),
    pytest.param(ParallelDims(2, 2, 2), id="hsdp+tp"),
    pytest.param(ParallelDims(1, 2, 4), id="fsdp+tp"),
])
@pytest.mark.parametrize("apply_qk_clip", [False, True])
@pytest.mark.parametrize("use_distributed_muon", [False])
@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
@pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
def test_parallel_muon(
    request,
    sequential_muon_result: dict[bool, torch.nn.Module] | None,
    parallel_dims: ParallelDims,
    apply_qk_clip: bool,
    use_distributed_muon: bool,
    warmup_step: int,
    chunk_size: int,
    inputs: tuple[torch.nn.Module, list[torch.Tensor],
                  dict[int, torch.Tensor]],  # from conftest.py
    measure_perf,  # from conftest.py
    do_profile,  # from conftest.py
) -> None:
    """Check a parallelized Muon step against the sequential baseline.

    Optionally logs timing / peak-memory numbers when ``measure_perf`` is set,
    in which case the correctness comparison is skipped.
    """
    if use_distributed_muon and chunk_size != CHUNK_SIZES[0]:
        pytest.skip("Distributed Muon does not effected by chunk size")
    if use_distributed_muon and warmup_step != OVERLAP_STEPS[0]:
        pytest.skip("Distributed Muon does not effected by warmup step")

    model, grads, qk_logits = inputs

    if not apply_qk_clip:
        qk_logits = None

    # Deepcopy the model to avoid in-place modification
    model = copy.deepcopy(model).cuda()

    parallelized_model = parallelize_motif(model, parallel_dims)

    if qk_logits is not None:
        # Deepcopy the qk logits to avoid in-place modification
        qk_logits = copy.deepcopy(qk_logits)
        qk_logits = parallelize_qk_logits(qk_logits, parallel_dims)

    parallelized_model, timing_result = apply_muon_step(
        model=parallelized_model,
        parallel_dims=parallel_dims,
        grads=grads,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        qk_logits=qk_logits,
        use_distributed_muon=use_distributed_muon,
        measure_perf=measure_perf,
        do_profile=do_profile,
    )

    if measure_perf:
        assert timing_result is not None
        avg_time_ms, peak_memory = timing_result
        logger.info(
            f"\nParallel dims: {parallel_dims}, "
            f"\nUse distributed Muon: {use_distributed_muon}, "
            f"\nApply QK clip: {apply_qk_clip} => "
            f"\nChunk Size, Warmup Step, Avg Time (ms), Peak Memory (MB):"
            f"\n{chunk_size}, {warmup_step}, {avg_time_ms:.2f}, {peak_memory / (1024**2):.2f},"
        )

    if sequential_muon_result is None:
        logger.info("Skipping correctness check as sequential result is None")
    elif measure_perf:
        logger.info("Skipping correctness check as timing is enabled")
    else:
        assert_params_equal(parallelized_model,
                            sequential_muon_result[apply_qk_clip])