[Feature] Add support for the MUSA device #1453

Merged · 6 commits · Jan 11, 2024
10 changes: 6 additions & 4 deletions mmengine/device/__init__.py
@@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
is_dipu_available, is_mlu_available, is_mps_available,
is_npu_available, is_npu_support_full_precision)
from .utils import (get_device, get_max_cuda_memory, get_max_musa_memory,
is_cuda_available, is_dipu_available, is_mlu_available,
is_mps_available, is_musa_available, is_npu_available,
is_npu_support_full_precision)

__all__ = [
'get_max_cuda_memory', 'get_device', 'is_cuda_available',
'is_mlu_available', 'is_mps_available', 'is_npu_available',
'is_dipu_available', 'is_npu_support_full_precision'
'is_dipu_available', 'get_max_musa_memory', 'is_musa_available',
'is_npu_support_full_precision'
]
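A minimal sanity check of the expanded public API (not part of the diff); `is_musa_available()` simply returns `False` when torch_musa is absent:

```python
from mmengine.device import get_max_musa_memory, is_musa_available  # noqa: F401

print(is_musa_available())  # False unless torch_musa is importable
```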
38 changes: 37 additions & 1 deletion mmengine/device/utils.py
@@ -22,6 +22,12 @@
except Exception:
IS_DIPU_AVAILABLE = False

try:
import torch_musa # noqa: F401
IS_MUSA_AVAILABLE = True
except Exception:
IS_MUSA_AVAILABLE = False


def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
@@ -73,6 +79,34 @@ def is_dipu_available() -> bool:
return IS_DIPU_AVAILABLE


def get_max_musa_memory(device: Optional[torch.device] = None) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
a given device. By default, this returns the peak allocated memory since
the beginning of this program.

Args:
device (torch.device, optional): selected device. Returns
statistic for the current device, given by
:func:`~torch.musa.current_device`, if ``device`` is None.
Defaults to None.

Returns:
int: The maximum GPU memory occupied by tensors in megabytes
for a given device.
"""
mem = torch.musa.max_memory_allocated(device=device)
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
dtype=torch.int,
device=device)
# TODO(haowen.han@mthreads.com): This function is not supported by musa yet.
# torch.musa.reset_peak_memory_stats()
return int(mem_mb.item())


def is_musa_available() -> bool:
return IS_MUSA_AVAILABLE


def is_npu_support_full_precision() -> bool:
"""Returns True if npu devices support full precision training."""
version_of_support_full_precision = 220
@@ -91,12 +125,14 @@ def is_npu_support_full_precision() -> bool:
DEVICE = 'mps'
elif is_dipu_available():
DEVICE = 'dipu'
elif is_musa_available():
DEVICE = 'musa'


def get_device() -> str:
"""Returns the currently existing device type.

Returns:
str: cuda | npu | mlu | mps | cpu.
str: cuda | npu | mlu | mps | musa | cpu.
"""
return DEVICE
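Usage sketch for the new helpers, assuming torch_musa is installed (the same calls fall back to whatever `get_device()` reports otherwise):

```python
import torch

from mmengine.device import get_device, get_max_musa_memory, is_musa_available

device = get_device()                    # 'musa' when torch_musa is importable
x = torch.randn(1024, 1024, device=device)
y = x @ x                                # allocate some memory on that device

if is_musa_available():
    # Peak allocated memory in MB, mirroring get_max_cuda_memory(); note the
    # hunk above cannot reset the peak yet (reset_peak_memory_stats is a TODO).
    print(f'max MUSA memory: {get_max_musa_memory()} MB')
```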
17 changes: 17 additions & 0 deletions mmengine/dist/dist.py
@@ -415,12 +415,16 @@ def _broadcast_object_list(object_list: List[Any],
current_device = torch.device('cpu')
is_hccl_backend = group_backend == 'hccl'
is_cncl_backend = group_backend == 'cncl'
is_mccl_backend = group_backend == 'mccl'
if is_hccl_backend:
current_device = torch.device('npu', torch.npu.current_device())
object_sizes_tensor = object_sizes_tensor.to(current_device)
elif is_cncl_backend:
current_device = torch.device('mlu', torch.mlu.current_device())
object_sizes_tensor = object_sizes_tensor.to(current_device)
elif is_mccl_backend:
current_device = torch.device('musa', torch.musa.current_device())
object_sizes_tensor = object_sizes_tensor.to(current_device)
elif is_nccl_backend:
# See note about using torch.cuda.current_device() here in
# docstring. We cannot simply use my_rank since rank == device is
@@ -624,13 +628,21 @@ def _all_gather_object(object_list: List[Any],
group_backend = get_backend(group)
current_device = torch.device('cpu')
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
is_mccl_backend = group_backend == 'mccl'
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device('cuda', torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
elif is_mccl_backend:
# See note about using torch.musa.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device('musa', torch.musa.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and
# index until the correct size when deserializing the tensors.
group_size = get_world_size(group=group)
@@ -776,10 +788,15 @@ def _gather_object(obj: Any,
group_backend = get_backend(group)
current_device = torch.device('cpu')
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
is_mccl_backend = group_backend == 'mccl'
if is_nccl_backend:
current_device = torch.device('cuda', torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
elif is_mccl_backend:
current_device = torch.device('musa', torch.musa.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and
# index until the correct size when deserializing the tensors.
group_size = get_world_size(group=group)
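A hedged usage sketch of the patched object collectives. It assumes the process group was initialised with the 'mccl' backend and that mmengine.dist exposes the public wrappers (`all_gather_object`, `broadcast_object_list`) around the private helpers above:

```python
from mmengine import dist

payload = {'rank': dist.get_rank(), 'msg': 'hello from a MUSA rank'}

# Each rank contributes one object; tensors are staged on the musa device
# because get_backend(group) == 'mccl' in the helpers above.
gathered = dist.all_gather_object(payload)

objs = [payload if dist.get_rank() == 0 else None]
dist.broadcast_object_list(objs)  # every rank now holds rank 0's dict
```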
14 changes: 13 additions & 1 deletion mmengine/dist/utils.py
@@ -11,7 +11,8 @@
from torch import Tensor
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup
from mmengine.device import is_mlu_available, is_npu_available
from mmengine.device import (is_mlu_available, is_npu_available,
is_musa_available)

from collections.abc import Iterable, Mapping

@@ -116,6 +117,14 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
elif is_musa_available():
import torch_musa # noqa: F401
torch.musa.set_device(rank)
torch_dist.init_process_group(
backend='mccl',
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
else:
# LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
local_rank = int(os.environ['LOCAL_RANK'])
@@ -528,6 +537,9 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
return torch.device('mlu', torch.mlu.current_device())
elif backend == 'smddp':
return torch.device('cuda', torch.cuda.current_device())
elif backend == 'mccl':
import torch_musa
return torch.device('musa', torch_musa.current_device())
else:
# GLOO and MPI backends use cpu device by default
return torch.device('cpu')
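For reference, the MUSA branch of `_init_dist_pytorch` boils down to the sketch below — a hypothetical standalone helper, assuming the usual RANK / WORLD_SIZE environment variables set by torch.distributed.launch:

```python
import os

import torch
import torch.distributed as torch_dist


def init_musa_process_group(**kwargs):
    """Hypothetical standalone helper mirroring the new branch above."""
    import torch_musa  # noqa: F401  # registers the torch.musa namespace
    rank = int(os.environ['RANK'])
    torch.musa.set_device(rank)            # bind this process to one MUSA card
    torch_dist.init_process_group(
        backend='mccl',                    # MUSA collective communication backend
        rank=rank,
        world_size=int(os.environ['WORLD_SIZE']),
        **kwargs)
```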
16 changes: 13 additions & 3 deletions mmengine/hooks/empty_cache_hook.py
@@ -4,6 +4,7 @@
import torch

from mmengine.registry import HOOKS
from ..device import is_cuda_available, is_musa_available
from .hook import Hook

DATA_BATCH = Optional[Union[dict, tuple, list]]
@@ -49,7 +50,10 @@ def _after_iter(self,
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._do_after_iter:
torch.cuda.empty_cache()
if is_cuda_available():
torch.cuda.empty_cache()
elif is_musa_available():
torch.musa.empty_cache()

def _before_epoch(self, runner, mode: str = 'train') -> None:
"""Empty cache before an epoch.
@@ -59,7 +63,10 @@ def _before_epoch(self, runner, mode: str = 'train') -> None:
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._do_before_epoch:
torch.cuda.empty_cache()
if is_cuda_available():
torch.cuda.empty_cache()
elif is_musa_available():
torch.musa.empty_cache()

def _after_epoch(self, runner, mode: str = 'train') -> None:
"""Empty cache after an epoch.
@@ -69,4 +76,7 @@ def _after_epoch(self, runner, mode: str = 'train') -> None:
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._do_after_epoch:
torch.cuda.empty_cache()
if is_cuda_available():
torch.cuda.empty_cache()
elif is_musa_available():
torch.musa.empty_cache()
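With the hook now dispatching on whichever backend is present, enabling it stays a one-line config change. A hedged config fragment, keyword names taken from the hook's existing signature and assumed unchanged:

```python
# Empty the cache of the available backend (CUDA or MUSA) after each epoch;
# the other switches keep their defaults.
custom_hooks = [
    dict(type='EmptyCacheHook', after_epoch=True)
]
```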
46 changes: 31 additions & 15 deletions mmengine/logging/logger.py
@@ -398,22 +398,38 @@ def _get_device_id():
except ImportError:
return 0
else:
local_rank = int(os.getenv('LOCAL_RANK', '0'))
# TODO: return device id of npu and mlu.
if not torch.cuda.is_available():
return local_rank
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
if cuda_visible_devices is None:
num_device = torch.cuda.device_count()
cuda_visible_devices = list(range(num_device))
else:
cuda_visible_devices = cuda_visible_devices.split(',')
MUSA_AVAILABLE = False
try:
return int(cuda_visible_devices[local_rank])
except ValueError:
# handle case for Multi-Instance GPUs
# see #1148 for details
return cuda_visible_devices[local_rank]
import torch_musa
MUSA_AVAILABLE = True
except ImportError:
pass
if MUSA_AVAILABLE:
local_rank = int(os.getenv('LOCAL_RANK', '0'))
musa_visible_devices = os.getenv('MUSA_VISIBLE_DEVICES', None)
if musa_visible_devices is None:
num_device = torch_musa.device_count()
musa_visible_devices = list(range(num_device))
else:
musa_visible_devices = musa_visible_devices.split(',')
return int(musa_visible_devices[local_rank])
else:
local_rank = int(os.getenv('LOCAL_RANK', '0'))
# TODO: return device id of npu and mlu.
if not torch.cuda.is_available():
return local_rank
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
if cuda_visible_devices is None:
num_device = torch.cuda.device_count()
cuda_visible_devices = list(range(num_device))
else:
cuda_visible_devices = cuda_visible_devices.split(',')
try:
return int(cuda_visible_devices[local_rank])
except ValueError:
# handle case for Multi-Instance GPUs
# see #1148 for details
return cuda_visible_devices[local_rank]


def _get_host_info() -> str:
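The MUSA branch of `_get_device_id` follows the same visible-devices indirection as the CUDA branch; a standalone sketch (hypothetical helper, for illustration only):

```python
import os


def _musa_device_id(device_count: int) -> int:
    """Map LOCAL_RANK through MUSA_VISIBLE_DEVICES, as in the hunk above."""
    local_rank = int(os.getenv('LOCAL_RANK', '0'))
    visible = os.getenv('MUSA_VISIBLE_DEVICES', None)
    if visible is None:
        visible = list(range(device_count))
    else:
        visible = visible.split(',')
    return int(visible[local_rank])


# e.g. MUSA_VISIBLE_DEVICES=2,3 with LOCAL_RANK=1 resolves to physical card 3.
```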
15 changes: 15 additions & 0 deletions mmengine/model/base_model/base_model.py
@@ -222,6 +222,21 @@ def cuda(
self._set_device(torch.device(device))
return super().cuda(device)

def musa(
self,
device: Optional[Union[int, str, torch.device]] = None,
) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.musa`
additionally.

Returns:
nn.Module: The model itself.
"""
if device is None or isinstance(device, int):
device = torch.device('musa', index=device)
self._set_device(torch.device(device))
return super().musa(device)

def mlu(
self,
device: Union[int, str, torch.device, None] = None,
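Usage sketch for the new `BaseModel.musa()`; assumes torch_musa is installed, and `ToyModel` is a stand-in subclass written only for illustration:

```python
import torch
import torch.nn as nn
import torch_musa  # noqa: F401  # registers the 'musa' device type

from mmengine.model import BaseModel


class ToyModel(BaseModel):
    """Minimal BaseModel subclass, for illustration only."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 2)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        return self.linear(inputs)


model = ToyModel().musa()        # moves parameters and the data preprocessor
out = model(torch.randn(4, 8, device='musa'), mode='tensor')
```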
9 changes: 9 additions & 0 deletions mmengine/model/base_model/data_preprocessor.py
@@ -113,6 +113,15 @@ def cuda(self, *args, **kwargs) -> nn.Module:
self._device = torch.device(torch.cuda.current_device())
return super().cuda()

def musa(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`

Returns:
nn.Module: The model itself.
"""
self._device = torch.device(torch.musa.current_device())
return super().musa()

def npu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`

7 changes: 4 additions & 3 deletions mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -6,7 +6,7 @@
import torch.nn as nn

from mmengine.device import (is_cuda_available, is_mlu_available,
is_npu_available)
is_musa_available, is_npu_available)
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
@@ -74,8 +74,9 @@ def __init__(self,
assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
'`torch.cuda.amp` is only available when pytorch version >= 1.6')
assert is_cuda_available() or is_npu_available() or is_mlu_available(
), ('``AmpOptimizerWrapper`` is only available training '
'on gpu, npu or mlu')
) or is_musa_available(), (
'``AmpOptimizerWrapper`` is only available training '
'on gpu, npu, mlu or musa')
super().__init__(**kwargs)
self._scale_update_param = None

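With the assertion widened, the existing AMP setup keeps working on a MUSA machine. A hedged config fragment in the usual mmengine style (the SGD settings are placeholders):

```python
optim_wrapper = dict(
    type='AmpOptimWrapper',                     # AMP now accepted on musa too
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9))
```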
8 changes: 7 additions & 1 deletion mmengine/runner/amp.py
@@ -135,7 +135,13 @@ def autocast(device_type: Optional[str] = None,

elif device_type == 'npu':
pass

elif device_type == 'musa':
if dtype is None:
dtype = torch.get_autocast_gpu_dtype()
with torch.musa.amp.autocast(
enabled=enabled, dtype=dtype, cache_enabled=cache_enabled):
yield
return
else:
# Device like MPS does not support fp16 training or testing.
# If an inappropriate device is set and fp16 is enabled, an error
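Usage sketch for the new autocast branch, assuming torch_musa is installed and that `autocast` is re-exported from mmengine.runner as it is for the CUDA path:

```python
import torch
import torch_musa  # noqa: F401  # registers the 'musa' device type

from mmengine.runner import autocast

model = torch.nn.Linear(8, 2).to('musa')
inputs = torch.randn(4, 8, device='musa')

with autocast(device_type='musa'):    # forwarded to torch.musa.amp.autocast
    out = model(inputs)               # runs in the autocast dtype (fp16 by default)
```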
12 changes: 9 additions & 3 deletions mmengine/runner/log_processor.py
@@ -9,7 +9,8 @@
import numpy as np
import torch

from mmengine.device import get_max_cuda_memory, is_cuda_available
from mmengine.device import (get_max_cuda_memory, get_max_musa_memory,
is_cuda_available, is_musa_available)
from mmengine.registry import LOG_PROCESSORS


@@ -226,11 +227,13 @@ def get_log_after_iter(self, runner, batch_idx: int,
log_tag.pop('time')
log_tag.pop('data_time')

# If cuda is available, the max memory occupied should be calculated.
if is_cuda_available():
# If cuda/musa is available,
# the max memory occupied should be calculated.
if is_cuda_available() or is_musa_available():
max_memory = self._get_max_memory(runner)
log_str += f'memory: {max_memory} '
tag['memory'] = max_memory

# Loop left keys to fill `log_str`.
if mode in ('train', 'val'):
log_items = []
@@ -498,6 +501,9 @@ def _get_max_memory(self, runner) -> int:
"""

device = getattr(runner.model, 'output_device', None)

if is_musa_available():
return get_max_musa_memory(device)
return get_max_cuda_memory(device)

def _get_iter(self, runner, batch_idx: int) -> int:
6 changes: 5 additions & 1 deletion mmengine/runner/utils.py
@@ -7,6 +7,7 @@
import torch
from torch.utils.data import DataLoader

from mmengine.device import is_cuda_available, is_musa_available
from mmengine.dist import get_rank, sync_random_seed
from mmengine.logging import print_log
from mmengine.utils import digit_version, is_list_of
@@ -69,7 +70,10 @@ def set_random_seed(seed: Optional[int] = None,
np.random.seed(seed)
torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if is_cuda_available():
torch.cuda.manual_seed_all(seed)
elif is_musa_available():
torch.musa.manual_seed_all(seed)
# os.environ['PYTHONHASHSEED'] = str(seed)
if deterministic:
if torch.backends.cudnn.benchmark:
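Finally, seeding stays backend-agnostic for callers; on a MUSA machine the call below now reaches torch.musa.manual_seed_all instead of the CUDA variant:

```python
from mmengine.runner.utils import set_random_seed

set_random_seed(2024, deterministic=False)
```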