# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Ways to make the model stronger."""
import random

import torch

def power_iteration(m, niters=1, bs=1):
    """This is the power method. The batch size `bs` is used to try multiple starting points in parallel."""
    assert m.dim() == 2
    assert m.shape[0] == m.shape[1]
    dim = m.shape[0]
    b = torch.randn(dim, bs, device=m.device, dtype=m.dtype)

    for _ in range(niters):
        n = m.mm(b)
        # Renormalize each column so the iterate does not blow up; for a PSD matrix
        # the column norms converge to the largest eigenvalue.
        norm = n.norm(dim=0, keepdim=True)
        b = n / (1e-10 + norm)

    return norm.mean()
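

# Example (a sketch, not part of the original file): estimating the squared largest
# singular value of a weight matrix `w` by running the power method on its Gram
# matrix. The names `w`, `gram` and `sigma2` are illustrative only.
#
#     w = torch.randn(64, 256)
#     gram = w.mm(w.t()) if w.shape[0] < w.shape[1] else w.t().mm(w)
#     sigma2 = power_iteration(gram, niters=10, bs=4)  # top eigenvalue of the Gram matrix,
#                                                      # i.e. the squared top singular value of w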

# We need a shared RNG to make sure all the distributed workers skip the penalty
# together, as otherwise we wouldn't get any speed-up.
penalty_rng = random.Random(1234)

def svd_penalty(model, min_size=0.1, dim=1, niters=2, powm=False, convtr=True,
                proba=1, conv_only=False, exact=False, bs=1):
    """
    Penalty on the largest singular value for a layer.
    Args:
        - model: model to penalize.
        - min_size: minimum size in MB of a layer to penalize.
        - dim: projection dimension for `torch.svd_lowrank`. Higher is better but slower.
        - niters: number of iterations in the algorithm used by `torch.svd_lowrank`.
        - powm: use the power method instead of lowrank SVD; in my experience
          it is both slower and less stable.
        - convtr: when True, differentiate between Conv and Transposed Conv.
          This is kept for compatibility with older experiments.
        - proba: probability to apply the penalty.
        - conv_only: only apply to conv and conv transposed, not LSTM
          (might not be reliable for models other than Demucs).
        - exact: use exact SVD (slow but useful at validation).
        - bs: batch size for the power method.
    """
    total = 0
    if penalty_rng.random() > proba:
        # Stochastically skip the penalty; the 1 / proba scaling below keeps it
        # unbiased in expectation.
        return 0.

    for m in model.modules():
        for name, p in m.named_parameters(recurse=False):
            # numel / 2**18 is the size in MB assuming 4-byte (float32) parameters.
            if p.numel() / 2**18 < min_size:
                continue
            if convtr:
                if isinstance(m, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)):
                    if p.dim() in [3, 4]:
                        # ConvTranspose stores weights as (in_channels, out_channels, ...);
                        # put out_channels first so the flattening below matches regular Conv.
                        p = p.transpose(0, 1).contiguous()
            if p.dim() == 3:
                p = p.view(len(p), -1)
            elif p.dim() == 4:
                p = p.view(len(p), -1)
            elif p.dim() == 1:
                # Skip biases and other 1D parameters.
                continue
            elif conv_only:
                continue
            assert p.dim() == 2, (name, p.shape)
            if exact:
                # Squared largest singular value from a full SVD.
                estimate = torch.svd(p, compute_uv=False)[1].pow(2).max()
            elif powm:
                # Power iteration on the smaller Gram matrix; its top eigenvalue is
                # the squared top singular value of p.
                a, b = p.shape
                if a < b:
                    n = p.mm(p.t())
                else:
                    n = p.t().mm(p)
                estimate = power_iteration(n, niters, bs)
            else:
                # Randomized low-rank SVD estimate of the top singular value.
                estimate = torch.svd_lowrank(p, dim, niters)[1][0].pow(2)
            total += estimate
    return total / proba
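

# Minimal usage sketch (not part of the original file): applying the penalty to a
# toy model. In training code one would typically add `weight * svd_penalty(model,
# proba=...)` to the task loss before calling backward(); the names `demo` and the
# layer sizes below are placeholders.
if __name__ == "__main__":
    demo = torch.nn.Sequential(
        torch.nn.Conv1d(4, 64, kernel_size=8),
        torch.nn.ReLU(),
        torch.nn.ConvTranspose1d(64, 4, kernel_size=8),
    )
    # min_size=0 so even these tiny layers are penalized; exact=True uses a full SVD.
    penalty = svd_penalty(demo, min_size=0, proba=1, exact=True)
    print("sum of squared top singular values:", float(penalty))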