[Feature] Add support for the MUSA device (#1453)
hanhaowen-mt authored Jan 11, 2024
1 parent 109cd44 commit 3d8a611
Showing 22 changed files with 253 additions and 43 deletions.
10 changes: 6 additions & 4 deletions mmengine/device/__init__.py
@@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
is_dipu_available, is_mlu_available, is_mps_available,
is_npu_available, is_npu_support_full_precision)
from .utils import (get_device, get_max_cuda_memory, get_max_musa_memory,
is_cuda_available, is_dipu_available, is_mlu_available,
is_mps_available, is_musa_available, is_npu_available,
is_npu_support_full_precision)

__all__ = [
'get_max_cuda_memory', 'get_device', 'is_cuda_available',
'is_mlu_available', 'is_mps_available', 'is_npu_available',
'is_dipu_available', 'is_npu_support_full_precision'
'is_dipu_available', 'get_max_musa_memory', 'is_musa_available',
'is_npu_support_full_precision'
]
38 changes: 37 additions & 1 deletion mmengine/device/utils.py
@@ -22,6 +22,12 @@
except Exception:
IS_DIPU_AVAILABLE = False

try:
import torch_musa # noqa: F401
IS_MUSA_AVAILABLE = True
except Exception:
IS_MUSA_AVAILABLE = False


def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
@@ -73,6 +79,34 @@ def is_dipu_available() -> bool:
return IS_DIPU_AVAILABLE


def get_max_musa_memory(device: Optional[torch.device] = None) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
a given device. By default, this returns the peak allocated memory since
the beginning of this program.
Args:
device (torch.device, optional): selected device. Returns
statistic for the current device, given by
:func:`~torch.musa.current_device`, if ``device`` is None.
Defaults to None.
Returns:
int: The maximum GPU memory occupied by tensors in megabytes
for a given device.
"""
mem = torch.musa.max_memory_allocated(device=device)
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
dtype=torch.int,
device=device)
# TODO:haowen.han@mthreads.com: This function is not supported by musa yet.
# torch.musa.reset_peak_memory_stats()
return int(mem_mb.item())


def is_musa_available() -> bool:
return IS_MUSA_AVAILABLE


def is_npu_support_full_precision() -> bool:
"""Returns True if npu devices support full precision training."""
version_of_support_full_precision = 220
@@ -91,12 +125,14 @@ def is_npu_support_full_precision():
DEVICE = 'mps'
elif is_dipu_available():
DEVICE = 'dipu'
elif is_musa_available():
DEVICE = 'musa'


def get_device() -> str:
"""Returns the currently existing device type.
Returns:
str: cuda | npu | mlu | mps | cpu.
str: cuda | npu | mlu | mps | musa | cpu.
"""
return DEVICE
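
A minimal usage sketch (not part of this commit) of the new device helpers; peak_memory_mb is a hypothetical wrapper:

from mmengine.device import (get_device, get_max_cuda_memory,
                             get_max_musa_memory, is_cuda_available,
                             is_musa_available)


def peak_memory_mb() -> int:
    # Hypothetical convenience wrapper: report the peak tensor memory in MB
    # on whichever accelerator is present, or 0 on a CPU-only host.
    if is_musa_available():
        return get_max_musa_memory()
    if is_cuda_available():
        return get_max_cuda_memory()
    return 0


print(get_device())       # 'musa' on a Moore Threads GPU, else 'cuda', ..., or 'cpu'
print(peak_memory_mb())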
17 changes: 17 additions & 0 deletions mmengine/dist/dist.py
@@ -415,12 +415,16 @@ def _broadcast_object_list(object_list: List[Any],
current_device = torch.device('cpu')
is_hccl_backend = group_backend == 'hccl'
is_cncl_backend = group_backend == 'cncl'
is_mccl_backend = group_backend == 'mccl'
if is_hccl_backend:
current_device = torch.device('npu', torch.npu.current_device())
object_sizes_tensor = object_sizes_tensor.to(current_device)
elif is_cncl_backend:
current_device = torch.device('mlu', torch.mlu.current_device())
object_sizes_tensor = object_sizes_tensor.to(current_device)
elif is_mccl_backend:
current_device = torch.device('musa', torch.musa.current_device())
object_sizes_tensor = object_sizes_tensor.to(current_device)
elif is_nccl_backend:
# See note about using torch.cuda.current_device() here in
# docstring. We cannot simply use my_rank since rank == device is
@@ -624,13 +628,21 @@ def _all_gather_object(object_list: List[Any],
group_backend = get_backend(group)
current_device = torch.device('cpu')
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
is_mccl_backend = group_backend == 'mccl'
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device('cuda', torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
elif is_mccl_backend:
# See note about using torch.musa.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device('musa', torch.musa.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and
# index until the correct size when deserializing the tensors.
group_size = get_world_size(group=group)
@@ -776,10 +788,15 @@ def _gather_object(obj: Any,
group_backend = get_backend(group)
current_device = torch.device('cpu')
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
is_mccl_backend = group_backend == 'mccl'
if is_nccl_backend:
current_device = torch.device('cuda', torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
elif is_mccl_backend:
current_device = torch.device('musa', torch.musa.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and
# index until the correct size when deserializing the tensors.
group_size = get_world_size(group=group)
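
The three object collectives above only gain an mccl branch so that the serialized payloads are staged on the musa device; the public API does not change. A hedged sketch, assuming a process group that was already initialized with the mccl backend:

from mmengine.dist import all_gather_object, broadcast_object_list, get_rank

# Each rank contributes one picklable object; under mccl the size and payload
# tensors travel via the 'musa' device, just as they do via 'cuda' under NCCL.
gathered = all_gather_object({'rank': get_rank()})

# Rank 0 broadcasts an arbitrary picklable object to the other ranks in place.
cfg = [dict(lr=0.01)] if get_rank() == 0 else [None]
broadcast_object_list(cfg, src=0)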
14 changes: 13 additions & 1 deletion mmengine/dist/utils.py
@@ -11,7 +11,8 @@
from torch import Tensor
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup
from mmengine.device import is_mlu_available, is_npu_available
from mmengine.device import (is_mlu_available, is_npu_available,
is_musa_available)

from collections.abc import Iterable, Mapping

@@ -117,6 +118,14 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
elif is_musa_available():
import torch_musa # noqa: F401
torch.musa.set_device(rank)
torch_dist.init_process_group(
backend='mccl',
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
else:
torch.cuda.set_device(local_rank)

@@ -527,6 +536,9 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
return torch.device('mlu', torch.mlu.current_device())
elif backend == 'smddp':
return torch.device('cuda', torch.cuda.current_device())
elif backend == 'mccl':
import torch_musa
return torch.device('musa', torch_musa.current_device())
else:
# GLOO and MPI backends use cpu device by default
return torch.device('cpu')
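
A hedged launch sketch (not part of the diff), assuming torch_musa is installed and the job is started with the stock PyTorch launcher, e.g. torchrun --nproc_per_node=8 train.py --launcher pytorch:

from mmengine.dist import get_comm_device, init_dist

# init_dist() reaches _init_dist_pytorch above, which now calls
# torch.musa.set_device(rank) and init_process_group(backend='mccl', ...).
init_dist(launcher='pytorch')
print(get_comm_device())  # torch.device('musa', <current device index>)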
16 changes: 13 additions & 3 deletions mmengine/hooks/empty_cache_hook.py
@@ -4,6 +4,7 @@
import torch

from mmengine.registry import HOOKS
from ..device import is_cuda_available, is_musa_available
from .hook import Hook

DATA_BATCH = Optional[Union[dict, tuple, list]]
@@ -49,7 +50,10 @@ def _after_iter(self,
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._do_after_iter:
torch.cuda.empty_cache()
if is_cuda_available():
torch.cuda.empty_cache()
elif is_musa_available():
torch.musa.empty_cache()

def _before_epoch(self, runner, mode: str = 'train') -> None:
"""Empty cache before an epoch.
@@ -59,7 +63,10 @@ def _before_epoch(self, runner, mode: str = 'train') -> None:
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._do_before_epoch:
torch.cuda.empty_cache()
if is_cuda_available():
torch.cuda.empty_cache()
elif is_musa_available():
torch.musa.empty_cache()

def _after_epoch(self, runner, mode: str = 'train') -> None:
"""Empty cache after an epoch.
@@ -69,4 +76,7 @@ def _after_epoch(self, runner, mode: str = 'train') -> None:
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._do_after_epoch:
torch.cuda.empty_cache()
if is_cuda_available():
torch.cuda.empty_cache()
elif is_musa_available():
torch.musa.empty_cache()
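
The hook now dispatches to the cache-emptying call of whichever backend is active. A config-style usage sketch, identical to the CUDA case:

# Register the hook in a runner config; on a MUSA host it now calls
# torch.musa.empty_cache() where it previously always used torch.cuda.
custom_hooks = [
    dict(type='EmptyCacheHook', before_epoch=False, after_epoch=True,
         after_iter=False)
]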
46 changes: 31 additions & 15 deletions mmengine/logging/logger.py
@@ -398,22 +398,38 @@ def _get_device_id():
     except ImportError:
         return 0
     else:
-        local_rank = int(os.getenv('LOCAL_RANK', '0'))
-        # TODO: return device id of npu and mlu.
-        if not torch.cuda.is_available():
-            return local_rank
-        cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
-        if cuda_visible_devices is None:
-            num_device = torch.cuda.device_count()
-            cuda_visible_devices = list(range(num_device))
-        else:
-            cuda_visible_devices = cuda_visible_devices.split(',')
+        MUSA_AVAILABLE = False
         try:
-            return int(cuda_visible_devices[local_rank])
-        except ValueError:
-            # handle case for Multi-Instance GPUs
-            # see #1148 for details
-            return cuda_visible_devices[local_rank]
+            import torch_musa
+            MUSA_AVAILABLE = True
+        except ImportError:
+            pass
+        if MUSA_AVAILABLE:
+            local_rank = int(os.getenv('LOCAL_RANK', '0'))
+            musa_visible_devices = os.getenv('MUSA_VISIBLE_DEVICES', None)
+            if musa_visible_devices is None:
+                num_device = torch_musa.device_count()
+                musa_visible_devices = list(range(num_device))
+            else:
+                musa_visible_devices = musa_visible_devices.split(',')
+            return int(musa_visible_devices[local_rank])
+        else:
+            local_rank = int(os.getenv('LOCAL_RANK', '0'))
+            # TODO: return device id of npu and mlu.
+            if not torch.cuda.is_available():
+                return local_rank
+            cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
+            if cuda_visible_devices is None:
+                num_device = torch.cuda.device_count()
+                cuda_visible_devices = list(range(num_device))
+            else:
+                cuda_visible_devices = cuda_visible_devices.split(',')
+            try:
+                return int(cuda_visible_devices[local_rank])
+            except ValueError:
+                # handle case for Multi-Instance GPUs
+                # see #1148 for details
+                return cuda_visible_devices[local_rank]


def _get_host_info() -> str:
15 changes: 15 additions & 0 deletions mmengine/model/base_model/base_model.py
@@ -222,6 +222,21 @@ def cuda(
self._set_device(torch.device(device))
return super().cuda(device)

def musa(
self,
device: Optional[Union[int, str, torch.device]] = None,
) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.musa`
additionally.
Returns:
nn.Module: The model itself.
"""
if device is None or isinstance(device, int):
device = torch.device('musa', index=device)
self._set_device(torch.device(device))
return super().musa(device)

def mlu(
self,
device: Union[int, str, torch.device, None] = None,
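
A minimal sketch (not part of the diff) of the new musa() override. ToyModel is hypothetical, and the call assumes torch_musa patches nn.Module with a musa() method, which is what super().musa(device) above relies on:

import torch

from mmengine.model import BaseModel


class ToyModel(BaseModel):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 1)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        return self.linear(inputs)


# Moves both the parameters and the attached data preprocessor to 'musa'.
model = ToyModel().musa()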
9 changes: 9 additions & 0 deletions mmengine/model/base_model/data_preprocessor.py
@@ -113,6 +113,15 @@ def cuda(self, *args, **kwargs) -> nn.Module:
self._device = torch.device(torch.cuda.current_device())
return super().cuda()

def musa(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`
Returns:
nn.Module: The model itself.
"""
self._device = torch.device(torch.musa.current_device())
return super().musa()

def npu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`
7 changes: 4 additions & 3 deletions mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -6,7 +6,7 @@
import torch.nn as nn

from mmengine.device import (is_cuda_available, is_mlu_available,
is_npu_available)
is_musa_available, is_npu_available)
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
@@ -74,8 +74,9 @@ def __init__(self,
assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
'`torch.cuda.amp` is only available when pytorch version >= 1.6')
assert is_cuda_available() or is_npu_available() or is_mlu_available(
), ('``AmpOptimizerWrapper`` is only available training '
'on gpu, npu or mlu')
) or is_musa_available(), (
'``AmpOptimizerWrapper`` is only available training '
'on gpu, npu, mlu or musa')
super().__init__(**kwargs)
self._scale_update_param = None

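
With the relaxed assertion, mixed-precision training is configured on MUSA exactly as on CUDA. A hedged config sketch:

# Standard mmengine AMP optimizer wrapper config; nothing MUSA-specific is
# needed in user code once is_musa_available() satisfies the assertion above.
optim_wrapper = dict(
    type='AmpOptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9))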
8 changes: 7 additions & 1 deletion mmengine/runner/amp.py
@@ -135,7 +135,13 @@ def autocast(device_type: Optional[str] = None,

elif device_type == 'npu':
pass

elif device_type == 'musa':
if dtype is None:
dtype = torch.get_autocast_gpu_dtype()
with torch.musa.amp.autocast(
enabled=enabled, dtype=dtype, cache_enabled=cache_enabled):
yield
return
else:
# Device like MPS does not support fp16 training or testing.
# If an inappropriate device is set and fp16 is enabled, an error
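
A minimal sketch of the new branch, assuming a host with a MUSA device and torch_musa installed:

import torch

from mmengine.runner import autocast

with autocast(device_type='musa'):
    # dtype falls back to torch.get_autocast_gpu_dtype() (float16 by default).
    a = torch.randn(8, 8, device='musa')
    b = torch.randn(8, 8, device='musa')
    out = torch.matmul(a, b)  # computed in float16 under autocast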
12 changes: 9 additions & 3 deletions mmengine/runner/log_processor.py
@@ -9,7 +9,8 @@
import numpy as np
import torch

from mmengine.device import get_max_cuda_memory, is_cuda_available
from mmengine.device import (get_max_cuda_memory, get_max_musa_memory,
is_cuda_available, is_musa_available)
from mmengine.registry import LOG_PROCESSORS


@@ -226,11 +227,13 @@ def get_log_after_iter(self, runner, batch_idx: int,
log_tag.pop('time')
log_tag.pop('data_time')

# If cuda is available, the max memory occupied should be calculated.
if is_cuda_available():
# If cuda/musa is available,
# the max memory occupied should be calculated.
if is_cuda_available() or is_musa_available():
max_memory = self._get_max_memory(runner)
log_str += f'memory: {max_memory} '
tag['memory'] = max_memory

# Loop left keys to fill `log_str`.
if mode in ('train', 'val'):
log_items = []
@@ -498,6 +501,9 @@ def _get_max_memory(self, runner) -> int:
"""

device = getattr(runner.model, 'output_device', None)

if is_musa_available():
return get_max_musa_memory(device)
return get_max_cuda_memory(device)

def _get_iter(self, runner, batch_idx: int) -> int:
6 changes: 5 additions & 1 deletion mmengine/runner/utils.py
@@ -7,6 +7,7 @@
import torch
from torch.utils.data import DataLoader

from mmengine.device import is_cuda_available, is_musa_available
from mmengine.dist import get_rank, sync_random_seed
from mmengine.logging import print_log
from mmengine.utils import digit_version, is_list_of
@@ -69,7 +70,10 @@ def set_random_seed(seed: Optional[int] = None,
np.random.seed(seed)
torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if is_cuda_available():
torch.cuda.manual_seed_all(seed)
elif is_musa_available():
torch.musa.manual_seed_all(seed)
# os.environ['PYTHONHASHSEED'] = str(seed)
if deterministic:
if torch.backends.cudnn.benchmark:
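
Seeding is otherwise unchanged. A hedged sketch, importing set_random_seed from the module touched above:

from mmengine.runner.utils import set_random_seed

# Seeds random, numpy and torch; on a MUSA host this now calls
# torch.musa.manual_seed_all(seed) instead of the CUDA variant.
set_random_seed(seed=2024, deterministic=False)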