diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py
index 9a92cdebfe..56ec343fb6 100644
--- a/mmengine/hooks/empty_cache_hook.py
+++ b/mmengine/hooks/empty_cache_hook.py
@@ -4,7 +4,7 @@
 import torch
 
 from mmengine.registry import HOOKS
-from ..device import is_cuda_available, is_musa_available
+from ..device import is_musa_available
 from .hook import Hook
 
 DATA_BATCH = Optional[Union[dict, tuple, list]]
@@ -50,9 +50,8 @@ def _after_iter(self,
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_after_iter:
-            if is_cuda_available():
-                torch.cuda.empty_cache()
-            elif is_musa_available():
+            torch.cuda.empty_cache()
+            if is_musa_available():
                 torch.musa.empty_cache()
 
     def _before_epoch(self, runner, mode: str = 'train') -> None:
@@ -63,9 +62,8 @@ def _before_epoch(self, runner, mode: str = 'train') -> None:
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_before_epoch:
-            if is_cuda_available():
-                torch.cuda.empty_cache()
-            elif is_musa_available():
+            torch.cuda.empty_cache()
+            if is_musa_available():
                 torch.musa.empty_cache()
 
     def _after_epoch(self, runner, mode: str = 'train') -> None:
@@ -76,7 +74,6 @@ def _after_epoch(self, runner, mode: str = 'train') -> None:
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_after_epoch:
-            if is_cuda_available():
-                torch.cuda.empty_cache()
-            elif is_musa_available():
+            torch.cuda.empty_cache()
+            if is_musa_available():
                 torch.musa.empty_cache()
diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py
index ddcd782c0d..b44b5451a6 100644
--- a/mmengine/logging/logger.py
+++ b/mmengine/logging/logger.py
@@ -14,7 +14,6 @@
 from mmengine.utils import ManagerMixin
 from mmengine.utils.manager import _accquire_lock, _release_lock
-from ..device import is_cuda_available, is_musa_available
 
 
 class FilterDuplicateWarning(logging.Filter):
@@ -399,24 +398,17 @@ def _get_device_id():
     except ImportError:
         return 0
     else:
-        local_rank = int(os.getenv('LOCAL_RANK', '0'))
-        if is_cuda_available():
-            cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
-            if cuda_visible_devices is None:
-                num_device = torch.cuda.device_count()
-                cuda_visible_devices = list(range(num_device))
-            else:
-                cuda_visible_devices = cuda_visible_devices.split(',')
-            try:
-                return int(cuda_visible_devices[local_rank])
-            except ValueError:
-                # handle case for Multi-Instance GPUs
-                # see #1148 for details
-                return cuda_visible_devices[local_rank]
-        elif is_musa_available():
+        MUSA_AVAILABLE = False
+        try:
+            import torch_musa
+            MUSA_AVAILABLE = True
+        except ImportError:
+            pass
+        if MUSA_AVAILABLE:
+            local_rank = int(os.getenv('LOCAL_RANK', '0'))
             musa_visible_devices = os.getenv('MUSA_VISIBLE_DEVICES', None)
             if musa_visible_devices is None:
-                num_device = torch.musa.device_count()
+                num_device = torch_musa.device_count()
                 musa_visible_devices = list(range(num_device))
             else:
                 musa_visible_devices = musa_visible_devices.split(',')
@@ -427,8 +419,22 @@ def _get_device_id():
                 # see #1148 for details
                 return musa_visible_devices[local_rank]
         else:
+            local_rank = int(os.getenv('LOCAL_RANK', '0'))
             # TODO: return device id of npu and mlu.
-            return local_rank
+            if not torch.cuda.is_available():
+                return local_rank
+            cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
+            if cuda_visible_devices is None:
+                num_device = torch.cuda.device_count()
+                cuda_visible_devices = list(range(num_device))
+            else:
+                cuda_visible_devices = cuda_visible_devices.split(',')
+            try:
+                return int(cuda_visible_devices[local_rank])
+            except ValueError:
+                # handle case for Multi-Instance GPUs
+                # see #1148 for details
+                return cuda_visible_devices[local_rank]
 
 
 def _get_host_info() -> str:
diff --git a/tests/test_hooks/test_empty_cache_hook.py b/tests/test_hooks/test_empty_cache_hook.py
index 4539ffa7e7..4a9ea99752 100644
--- a/tests/test_hooks/test_empty_cache_hook.py
+++ b/tests/test_hooks/test_empty_cache_hook.py
@@ -1,15 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import unittest
 from unittest.mock import patch
 
-from mmengine.device import is_musa_available
 from mmengine.testing import RunnerTestCase
 
 
-# TODO:haowen.han@mthreads.com
-@unittest.skipIf(
-    is_musa_available(),
-    'torch_musa do not support torch.musa.reset_peak_memory_stats() yet')
 class TestEmptyCacheHook(RunnerTestCase):
 
     def test_with_runner(self):
diff --git a/tests/test_runner/test_log_processor.py b/tests/test_runner/test_log_processor.py
index c2b48e18a5..30f7d872b4 100644
--- a/tests/test_runner/test_log_processor.py
+++ b/tests/test_runner/test_log_processor.py
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
-import unittest
 from unittest.mock import MagicMock, patch
 
 import numpy as np
@@ -251,10 +250,7 @@ def test_collect_non_scalars(self):
         assert tag['metric1'] is metric1
         assert tag['metric2'] is metric2
 
-    # TODO:haowen.han@mtheads.com
-    @unittest.skipIf(
-        is_musa_available(),
-        'musa backend do not support torch.cuda.reset_peak_memory_stats')
+    # TODO:haowen.han@mtheads.com MUSA does not support it yet!
     @patch('torch.cuda.max_memory_allocated', MagicMock())
     @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
     def test_get_max_memory(self):