Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Tuple | |
| import torch | |
| from mmengine.dist import all_gather, broadcast, get_rank | |
| def batch_shuffle_ddp(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """Batch shuffle, for making use of BatchNorm. | |
| Args: | |
| x (torch.Tensor): Data in each GPU. | |
| Returns: | |
| Tuple[torch.Tensor, torch.Tensor]: Output of shuffle operation. | |
| - x_gather[idx_this]: Shuffled data. | |
| - idx_unshuffle: Index for restoring. | |
| """ | |
| # gather from all gpus | |
| batch_size_this = x.shape[0] | |
| x_gather = torch.cat(all_gather(x), dim=0) | |
| batch_size_all = x_gather.shape[0] | |
| num_gpus = batch_size_all // batch_size_this | |
| # random shuffle index | |
| idx_shuffle = torch.randperm(batch_size_all) | |
| # broadcast to all gpus | |
| broadcast(idx_shuffle, src=0) | |
| # index for restoring | |
| idx_unshuffle = torch.argsort(idx_shuffle) | |
| # shuffled index for this gpu | |
| gpu_idx = get_rank() | |
| idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] | |
| return x_gather[idx_this], idx_unshuffle | |
| def batch_unshuffle_ddp(x: torch.Tensor, | |
| idx_unshuffle: torch.Tensor) -> torch.Tensor: | |
| """Undo batch shuffle. | |
| Args: | |
| x (torch.Tensor): Data in each GPU. | |
| idx_unshuffle (torch.Tensor): Index for restoring. | |
| Returns: | |
| torch.Tensor: Output of unshuffle operation. | |
| """ | |
| # gather from all gpus | |
| batch_size_this = x.shape[0] | |
| x_gather = torch.cat(all_gather(x), dim=0) | |
| batch_size_all = x_gather.shape[0] | |
| num_gpus = batch_size_all // batch_size_this | |
| # restored index for this gpu | |
| gpu_idx = get_rank() | |
| idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] | |
| return x_gather[idx_this] | |