#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import random

import torch
import torch.distributed as dist

from yolox.utils import synchronize
class DataPrefetcher:
    """
    DataPrefetcher is inspired by code of following file:
    https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
    It could speedup your pytorch dataloader. For more information, please check
    https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789.

    It overlaps host-to-device copies of the *next* batch (done on a side CUDA
    stream) with computation on the current batch (default stream). Call
    ``next()`` once per iteration; it returns ``(None, None)`` when the
    underlying loader is exhausted.
    """

    def __init__(self, loader):
        # loader: any iterable yielding (input, target, _, _) 4-tuples
        # — assumed to be a torch DataLoader; TODO confirm against callers.
        self.loader = iter(loader)
        # Dedicated side stream on which the prefetch copies are enqueued.
        self.stream = torch.cuda.Stream()
        # Indirection hooks so subclasses/variants could swap the copy and
        # record-stream strategies; here they handle plain image tensors.
        self.input_cuda = self._input_cuda_for_image
        self.record_stream = DataPrefetcher._record_stream_for_image
        # Kick off the first async copy immediately.
        self.preload()

    def preload(self):
        """Fetch the next batch and start copying it to the GPU asynchronously."""
        try:
            # Only the first two elements of the 4-tuple are kept.
            self.next_input, self.next_target, _, _ = next(self.loader)
        except StopIteration:
            # Exhausted: signal end-of-data with None sentinels.
            self.next_input = None
            self.next_target = None
            return

        # Enqueue the H2D copies on the side stream so they overlap with
        # compute running on the default stream. non_blocking=True requires
        # the source tensors to be in pinned memory to truly be async
        # — presumably the DataLoader uses pin_memory=True; verify.
        with torch.cuda.stream(self.stream):
            self.input_cuda()
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        """Return the prefetched (input, target) pair and prefetch the next one."""
        # Make the default stream wait until the side-stream copies finish.
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if input is not None:
            # Tell the CUDA caching allocator these tensors are now used by
            # the default stream, so their memory is not reused too early.
            self.record_stream(input)
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        # Start copying the following batch while the caller consumes this one.
        self.preload()
        return input, target

    def _input_cuda_for_image(self):
        # Async H2D copy of the image batch (runs inside the side-stream ctx).
        self.next_input = self.next_input.cuda(non_blocking=True)

    def _record_stream_for_image(input):
        # NOTE: deliberately has no `self` — it is accessed via the class
        # (`DataPrefetcher._record_stream_for_image`) in __init__, which in
        # Python 3 yields the plain function, so `input` receives the tensor.
        input.record_stream(torch.cuda.current_stream())
def random_resize(data_loader, exp, epoch, rank, is_distributed):
    """Choose a new multi-scale training resolution and apply it to the loader.

    Rank 0 picks the size (a random multiple of 32, or the fixed
    ``exp.input_size`` during the last 10 epochs) and, in the distributed
    case, broadcasts it so every rank resizes identically.

    Returns whatever ``data_loader.change_input_dim`` returns (the new
    input size).
    """
    size_tensor = torch.LongTensor(1).cuda()

    if is_distributed:
        # Line everyone up before rank 0 decides.
        synchronize()

    if rank == 0:
        if epoch > exp.max_epoch - 10:
            # Freeze the resolution for the final 10 epochs.
            # NOTE(review): looks like this expects exp.input_size to be a
            # scalar multiple (fill_ takes a number) — confirm against Exp.
            chosen = exp.input_size
        else:
            chosen = int(32 * random.randint(*exp.random_size))
        size_tensor.fill_(chosen)

    if is_distributed:
        synchronize()
        # Share rank 0's choice with every other rank.
        dist.broadcast(size_tensor, 0)

    return data_loader.change_input_dim(
        multiple=size_tensor.item(), random_range=None
    )