TLDR:
This is an ensemble optimization adapted to standard models: by batching ensemble members through carefully traced, staged vmap structures, it increases throughput for inference and training alike and yields large speedups.
https://github.com/AbstractEyes/pytorch-parallel-compiler
The supported layer list is still incomplete, so this is a preliminary look at what this structure can do once it is fully fleshed out.
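For context, here is a minimal sketch of the general ensemble-vmap pattern the benchmarks below refer to, using the stock torch.func API (stack_module_state + functional_call + vmap). The MLP class and helper names are illustrative, not the repo's actual interface:

```python
import copy
import torch
import torch.nn as nn
from torch.func import stack_module_state, functional_call

class MLP(nn.Module):
    # Illustrative stand-in for a "uniformly staged" ensemble member.
    def __init__(self, dim=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
    def forward(self, x):
        return self.net(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
models = [MLP().to(device) for _ in range(100)]      # N=100 ensemble members

# Stack every member's parameters/buffers into tensors with a leading N dim.
params, buffers = stack_module_state(models)

# Stateless "template" module; its own weights are never used.
base = copy.deepcopy(models[0]).to("meta")

def call_one(p, b, x):
    return functional_call(base, (p, b), (x,))

# vmap over the stacked parameter dim; the same input batch is shared by all members.
ensemble_forward = torch.vmap(call_one, in_dims=(0, 0, None))

x = torch.randn(32, 128, device=device)              # batch=32
out = ensemble_forward(params, buffers, x)           # shape: (100, 32, 128)
```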
MLP (N=100, batch=32, CUDA):
Eager: 2-3x speedup
Compiled: 35-40x speedup

ResBlock (N=20, batch=8, CUDA):
Eager: ~5x speedup
Compiled: ~10x speedup

This is early testing, and so far the results indicate that WIDENING your model with adjacent, shared, batched vmaps over uniformly staged members yields considerably higher inference throughput, at the cost of additional hardware utilization.
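The "Compiled" rows presumably correspond to wrapping the vmapped forward in torch.compile; a hedged continuation of the sketch above (the exact compile settings are an assumption, not taken from the repo):

```python
# Assumption: a recent PyTorch 2.x where torch.compile can trace through torch.vmap.
compiled_forward = torch.compile(ensemble_forward)

# The first call triggers compilation; later calls reuse the compiled graph.
out = compiled_forward(params, buffers, x)           # shape: (100, 32, 128)
```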
This is akin to lining up all your models side by side and passing the inputs they need through a single shared, frozen representation gate in one uniform pass.
Training is neither tested nor supported yet; use at your own risk.