Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # pyre-unsafe | |
| from typing import Optional, Union | |
| import torch | |
# Anything accepted as a device specification: a device string such as
# "cpu" / "cuda:0", or an already-constructed torch.device.
Device = Union[str, torch.device]


def make_device(device: Device) -> torch.device:
    """
    Normalize a device specification into a concrete torch.device.

    String inputs are parsed with torch.device(...); torch.device inputs are
    passed through unchanged. A CUDA device given without an explicit index
    (e.g. plain "cuda") is resolved to the index reported by
    torch.cuda.current_device().

    Args:
        device: Device (as str or torch.device)

    Returns:
        A matching torch.device object
    """
    resolved = device if not isinstance(device, str) else torch.device(device)

    needs_index = resolved.type == "cuda" and resolved.index is None
    if not needs_index:
        return resolved

    # A bare "cuda" means "whatever CUDA device is current"; pin it to that
    # index so the returned device is explicit.
    return torch.device(f"cuda:{torch.cuda.current_device()}")
def get_device(x, device: Optional[Device] = None) -> torch.device:
    """
    Choose the device for a value, with an optional explicit override.

    Resolution order:
      1. If `device` is given, it is normalized via make_device and returned.
      2. Otherwise, if `x` is a tensor, its own `.device` is returned.
      3. Otherwise, the CPU device is returned.

    Args:
        x: a torch.Tensor to get the device from or another type
        device: Device (as str or torch.device) to fall back to

    Returns:
        A matching torch.device object
    """
    if device is None:
        # No override: tensors report their own device; anything else is CPU.
        return x.device if torch.is_tensor(x) else torch.device("cpu")
    return make_device(device)