# Performance Optimizations (vs. main)

Summary of optimizations on branch `perf/pipelined-distributed-muon-clean` relative to `main`.

---

## 1. Batched Momentum (`core.py`)

**Before:** Per-param `update_g()` — one `torch.add` + optional `torch.add_` per parameter.

**After:** `_batch_pre_ortho()` — `_foreach_mul_`, `_foreach_add_` on lists of local tensors (unwrapped from DTensor). Single fused kernel per batch instead of N individual kernels.

**Impact:** Eliminates per-param Python-loop overhead plus N small kernel launches. Savings scale with parameter count.

---

## 2. Pipeline Buffer Packing (`pipeline.py`)

### Gather send buffer

**Before:** Per-param `.to(COMM_DTYPE).contiguous()`, a per-destination `append` to a list, then `torch.cat` on each per-dst list.

**After:** Collect all grad slices in destination order in a single pass, then one `torch.cat` call. Avoids intermediate per-destination lists and redundant dtype conversions.

### Scatter send buffer

**Before:** Per-param, per-destination-rank: index `u_full[indices].flatten()`, append to a per-dst list, then flatten+cat.

**After:** Cache `u_full` conversions (avoiding a redundant `.to()` per dst_rank). Collect all slices in dst order in one pass, then a single `torch.cat`.

**Impact:** Fewer kernel launches, less Python overhead, fewer intermediate allocations.

---

## 3. Zero-Copy Scatter (`pipeline.py`)

**Before:** `_launch_scatter` pre-allocates `torch.empty_like(p.to_local())` for every param. `_complete_scatter` copies from recv_buf into these pre-allocated tensors via `copy_()`.

**After:** `_complete_scatter` assigns **views** into `recv_buf` directly (via `recv_buf.narrow(...).view_as(...)`). No pre-allocation, no copy. The recv_buf storage stays alive through the views until `_update_params` consumes them.

**Impact:** Eliminates N `empty_like` allocations and N `copy_` kernel launches per scatter stage.

---
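The batched momentum update in section 1 can be sketched as follows. This is a hypothetical simplification: the real `_batch_pre_ortho()` also unwraps DTensors, and the exact momentum formula used by Muon may differ.

```python
import torch

def batch_pre_ortho(momentum_bufs, grads, beta, nesterov=True):
    """Sketch of a batched momentum update via torch._foreach_* ops.

    Replaces a per-param Python loop (N small kernel launches) with a
    handful of fused multi-tensor kernels over the whole list.
    """
    # bufs <- beta * bufs + grads, one fused call each instead of N loops
    torch._foreach_mul_(momentum_bufs, beta)
    torch._foreach_add_(momentum_bufs, grads)
    if nesterov:
        # updates <- grads + beta * bufs, again a single batched call
        return torch._foreach_add(grads, momentum_bufs, alpha=beta)
    return [b.clone() for b in momentum_bufs]
```

The key point is that each `_foreach_*` call dispatches once for the entire parameter list, so launch overhead no longer grows with N.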
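The view-based completion in section 3 can be sketched like this. The function name and bookkeeping are simplified; the real `_complete_scatter` derives offsets from the all-to-all layout.

```python
import torch

def complete_scatter(recv_buf, params, numels):
    """Sketch of zero-copy scatter completion.

    Instead of allocating empty_like tensors and copy_()-ing into them,
    each param's update is a narrow() view into the flat receive buffer.
    recv_buf's storage stays alive through these views until consumed.
    """
    updates, offset = [], 0
    for p, n in zip(params, numels):
        updates.append(recv_buf.narrow(0, offset, n).view_as(p))
        offset += n
    return updates
```

Because the views alias `recv_buf`, no per-param allocation or copy kernel is launched at all.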
## 4. Batched Parameter Update (`pipeline.py`)

**Before:** Per-param loop calling `update_p()` (which unwraps DTensor, applies weight decay, and applies the update individually).

**After:** Batched using `_foreach_mul_` (weight decay) and `_foreach_add_` (Muon update), grouped by `adjusted_lr` to preserve float32 alpha precision. Single kernel per group instead of per param.

**Impact:** Reduces N per-param kernel launches to 1-2 batched kernel launches.

---

## 5. Parallel Metadata Caching (`muon.py`)

**Before:** `init_state_and_assign_params()` called every step — sorts params by FLOP cost, assigns ownership via round-robin, precomputes per-rank indices/numels for all-to-all.

**After:** `_parallel_cache` keyed by `tuple(names)`. The first call computes and caches `ordered_names`, `name_to_state`, `rank`, and `chunk_size`. Subsequent calls reuse the cached metadata, only rebuilding `param_to_state` with current `id(p)` keys (param objects are stable, but ids may change after QK clip updates).

**Impact:** Eliminates repeated sorting, mesh construction, and index precomputation on every step.

---

## 6. Expert Param Expansion Caching (`muon.py`)

**Before:** `_expand_expert_params()` called every step — for each expert param `(E, out, in)`, creates E `nn.Parameter` wrappers (triggers `aten::detach`), indexes data and grad (`aten::select`), and wraps in DTensor for TP.

**After:** `_expert_expand_cache` keyed by `tuple(id(p) for p in params)`. The cold path runs `_expand_expert_params` once and caches:

- `expanded_names` / `expanded_params` — the nn.Parameter wrappers with stable data views
- `grad_info` — per-expert-group metadata (orig param index, num experts, expanded start index, DTensor flag, TP mesh/placements)

The hot path reuses the cached nn.Parameter objects (data views are stable since optimizer updates happen in-place on the same storage) and only updates `.grad` on each cached expert param by slicing the current step's gradient.
**Eliminated on hot path:**

- `nn.Parameter()` construction — removes `aten::detach`
- `local_data[i]` data slicing — removes half of the `aten::select` + `aten::as_strided` calls
- `DTensor.from_local()` for data — now only needed for grads
- `is_expert_param()` name matching per step

**Still required per step:**

- `local_grad[i]` — the grad tensor changes each step (nesterov)
- `DTensor.from_local(slice_grad, ...)` — for TP expert grads
- `p.grad = None` — frees the original 3D grad storage

**Impact:** ~8 ms CPU overhead reduction per step at production scale (64 GPUs, 48 local experts).

---

## 7. Newton-Schulz Compile + CUDA Graph (`newton_schulz.py`)

**Before:** `_zeropower_via_newtonschulz5()` called directly every time.

**After:** `zeropower_via_newtonschulz5()` wrapper with per-shape `torch.compile` caching plus CUDA graphs (`triton.cudagraphs=True`). Each unique shape gets its own compiled function stored in `_ns_per_shape`. Toggled via `set_ns_compile(enabled)`.

**Impact:** After warmup, NS iterations run as CUDA graphs — eliminates per-step compilation overhead and CPU-GPU synchronization.

---

## 8. Removed `small_param_numel_threshold` (`muon.py`)

**Before:** Small sharded DTensors (below the threshold, default 65536) fell back to `distributed_muon()`, which used per-param `full_tensor()` + redistribute.

**After:** All sharded DTensors go to `parallel()`. `distributed_muon()` is retained as a test-only reference implementation. Uneven shard splits (e.g., MoE gate weights with fewer rows than shard ranks) are handled inline via a `full_tensor()` fallback within the batched distributed_muon path.

**Impact:** Simpler routing, no silent fallback to a slower path.
---

## Summary Table

| Optimization | Location | Category | Kernel Launches Saved |
|---|---|---|---|
| Batched momentum | `core.py` | CPU + GPU | N per-param → 2-3 batched |
| Buffer packing (gather) | `pipeline.py` | CPU + GPU | N cat+cast → 1 cat+cast |
| Buffer packing (scatter) | `pipeline.py` | CPU + GPU | N cat → 1 cat |
| Zero-copy scatter | `pipeline.py` | GPU memory | N alloc+copy → 0 |
| Batched param update | `pipeline.py` | CPU + GPU | N update → 1-2 batched |
| Parallel metadata cache | `muon.py` | CPU | Sort+index per step → once |
| Expert expand cache | `muon.py` | CPU | N detach+select → grad-only |
| NS compile + CUDA graph | `newton_schulz.py` | GPU | JIT warmup → graph replay |
| Remove `small_param_numel_threshold` | `muon.py` | Routing | Simpler, unified path |
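As a closing illustration, the `id()`-keyed caching pattern shared by sections 5 and 6 can be sketched as follows (names and structure simplified; the real caches store the richer metadata described above):

```python
# Hypothetical sketch of the cache-by-param-identity pattern: the
# expensive expansion runs once per unique param list; subsequent
# steps reuse the cached result and only refresh per-step state
# (e.g. `.grad` slices).
_expert_expand_cache = {}

def expand_expert_params_cached(params, expand_fn):
    key = tuple(id(p) for p in params)  # stable while param objects live
    cached = _expert_expand_cache.get(key)
    if cached is None:
        cached = expand_fn(params)  # cold path: wrappers, views, metadata
        _expert_expand_cache[key] = cached
    return cached  # hot path: dict lookup only
```

Keying on `id(p)` rather than names means the cache invalidates itself automatically if the optimizer is handed new parameter objects.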