revise logging/logger.py for ci
hanhaowen-mt committed Jan 9, 2024
1 parent 243d093 commit f7543e2
Showing 4 changed files with 32 additions and 39 deletions.
mmengine/hooks/empty_cache_hook.py (17 changes: 7 additions & 10 deletions)
@@ -4,7 +4,7 @@
 import torch
 
 from mmengine.registry import HOOKS
-from ..device import is_cuda_available, is_musa_available
+from ..device import is_musa_available
 from .hook import Hook
 
 DATA_BATCH = Optional[Union[dict, tuple, list]]
@@ -50,9 +50,8 @@ def _after_iter(self,
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_after_iter:
-            if is_cuda_available():
-                torch.cuda.empty_cache()
-            elif is_musa_available():
+            torch.cuda.empty_cache()
+            if is_musa_available():
                 torch.musa.empty_cache()
 
     def _before_epoch(self, runner, mode: str = 'train') -> None:
@@ -63,9 +62,8 @@ def _before_epoch(self, runner, mode: str = 'train') -> None:
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_before_epoch:
-            if is_cuda_available():
-                torch.cuda.empty_cache()
-            elif is_musa_available():
+            torch.cuda.empty_cache()
+            if is_musa_available():
                 torch.musa.empty_cache()
 
     def _after_epoch(self, runner, mode: str = 'train') -> None:
@@ -76,7 +74,6 @@ def _after_epoch(self, runner, mode: str = 'train') -> None:
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_after_epoch:
-            if is_cuda_available():
-                torch.cuda.empty_cache()
-            elif is_musa_available():
+            torch.cuda.empty_cache()
+            if is_musa_available():
                 torch.musa.empty_cache()
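Note (not part of the commit): the revised hook logic above reads as "always clear the CUDA cache, additionally clear the MUSA cache when that backend is present". A minimal standalone sketch of that behaviour, assuming only that torch is installed; torch.cuda.empty_cache() is a no-op when no CUDA context has been initialized, so the unconditional call is safe on CPU-only CI machines:

    import torch

    def empty_device_caches() -> None:
        # Mirrors the hook's new behaviour: the CUDA call is harmless without a GPU,
        # while the MUSA cache is only touched when the torch_musa extension
        # (which provides the torch.musa namespace) can be imported.
        torch.cuda.empty_cache()
        try:
            import torch_musa  # noqa: F401  -- importable only on MUSA machines
        except ImportError:
            return
        torch.musa.empty_cache()

    empty_device_caches()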
mmengine/logging/logger.py (42 changes: 24 additions & 18 deletions)
@@ -14,7 +14,6 @@
 
 from mmengine.utils import ManagerMixin
 from mmengine.utils.manager import _accquire_lock, _release_lock
-from ..device import is_cuda_available, is_musa_available
 
 
 class FilterDuplicateWarning(logging.Filter):
@@ -399,24 +398,17 @@ def _get_device_id():
     except ImportError:
         return 0
     else:
-        local_rank = int(os.getenv('LOCAL_RANK', '0'))
-        if is_cuda_available():
-            cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
-            if cuda_visible_devices is None:
-                num_device = torch.cuda.device_count()
-                cuda_visible_devices = list(range(num_device))
-            else:
-                cuda_visible_devices = cuda_visible_devices.split(',')
-            try:
-                return int(cuda_visible_devices[local_rank])
-            except ValueError:
-                # handle case for Multi-Instance GPUs
-                # see #1148 for details
-                return cuda_visible_devices[local_rank]
-        elif is_musa_available():
+        MUSA_AVAILABLE = False
+        try:
+            import torch_musa
+            MUSA_AVAILABLE = True
+        except ImportError:
+            pass
+        if MUSA_AVAILABLE:
+            local_rank = int(os.getenv('LOCAL_RANK', '0'))
             musa_visible_devices = os.getenv('MUSA_VISIBLE_DEVICES', None)
             if musa_visible_devices is None:
-                num_device = torch.musa.device_count()
+                num_device = torch_musa.device_count()
                 musa_visible_devices = list(range(num_device))
             else:
                 musa_visible_devices = musa_visible_devices.split(',')
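The rewritten branch above drops the is_musa_available() helper in favour of a guarded local import, so logger.py no longer pulls in mmengine.device at import time. For illustration only, the pattern in isolation (torch_musa is importable only where the MUSA backend is installed):

    # Optional-dependency probe: never raises, only records a capability flag.
    MUSA_AVAILABLE = False
    try:
        import torch_musa
        MUSA_AVAILABLE = True
    except ImportError:
        pass

    if MUSA_AVAILABLE:
        print('MUSA devices:', torch_musa.device_count())
    else:
        print('no MUSA backend, falling back to the CUDA/CPU branch')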
@@ -427,8 +419,22 @@ def _get_device_id():
                 # see #1148 for details
                 return musa_visible_devices[local_rank]
         else:
             local_rank = int(os.getenv('LOCAL_RANK', '0'))
             # TODO: return device id of npu and mlu.
-            return local_rank
+            if not torch.cuda.is_available():
+                return local_rank
+            cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
+            if cuda_visible_devices is None:
+                num_device = torch.cuda.device_count()
+                cuda_visible_devices = list(range(num_device))
+            else:
+                cuda_visible_devices = cuda_visible_devices.split(',')
+            try:
+                return int(cuda_visible_devices[local_rank])
+            except ValueError:
+                # handle case for Multi-Instance GPUs
+                # see #1148 for details
+                return cuda_visible_devices[local_rank]
 
 
 def _get_host_info() -> str:
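Taken together, _get_device_id() now resolves a physical device index by reading LOCAL_RANK and indexing into MUSA_VISIBLE_DEVICES or CUDA_VISIBLE_DEVICES (falling back to range(device_count()) when the variable is unset); non-integer entries, such as Multi-Instance GPU identifiers (see #1148), are returned as strings. A hedged sketch of the CUDA branch's parsing with made-up environment values:

    import os

    # Hypothetical environment for rank 1 of a job pinned to GPUs 3 and 5.
    os.environ['LOCAL_RANK'] = '1'
    os.environ['CUDA_VISIBLE_DEVICES'] = '3,5'

    local_rank = int(os.getenv('LOCAL_RANK', '0'))
    visible = os.getenv('CUDA_VISIBLE_DEVICES', '').split(',')
    try:
        device_id = int(visible[local_rank])
    except ValueError:
        # MIG entries such as 'MIG-<uuid>' are not integers and stay as strings.
        device_id = visible[local_rank]
    print(device_id)  # prints 5 for this example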
tests/test_hooks/test_empty_cache_hook.py (6 changes: 0 additions & 6 deletions)
@@ -1,15 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import unittest
 from unittest.mock import patch
 
-from mmengine.device import is_musa_available
 from mmengine.testing import RunnerTestCase
 
 
-# TODO:haowen.han@mthreads.com
-@unittest.skipIf(
-    is_musa_available(),
-    'torch_musa do not support torch.musa.reset_peak_memory_stats() yet')
 class TestEmptyCacheHook(RunnerTestCase):
 
     def test_with_runner(self):
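With the musa skip removed, TestEmptyCacheHook is expected to run on CI machines without any accelerator, since cache clearing is trivial to stub. A sketch of that idea, not the repository's actual test (the constructor argument and the private _after_epoch call are illustrative assumptions):

    from unittest.mock import patch

    import torch

    from mmengine.hooks import EmptyCacheHook

    hook = EmptyCacheHook(after_epoch=True)
    with patch.object(torch.cuda, 'empty_cache') as mocked:
        hook._after_epoch(runner=None)  # runner is not used on this code path
        mocked.assert_called_once()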
tests/test_runner/test_log_processor.py (6 changes: 1 addition & 5 deletions)
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
-import unittest
 from unittest.mock import MagicMock, patch
 
 import numpy as np
@@ -251,10 +250,7 @@ def test_collect_non_scalars(self):
         assert tag['metric1'] is metric1
         assert tag['metric2'] is metric2
 
-    # TODO:haowen.han@mtheads.com
-    @unittest.skipIf(
-        is_musa_available(),
-        'musa backend do not support torch.cuda.reset_peak_memory_stats')
+    # TODO:haowen.han@mtheads.com MUSA does not support it yet!
     @patch('torch.cuda.max_memory_allocated', MagicMock())
     @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
     def test_get_max_memory(self):
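The surviving @patch decorators are what keep test_get_max_memory runnable without a GPU: torch.cuda.max_memory_allocated and torch.cuda.reset_peak_memory_stats are swapped for MagicMock objects for the duration of the test. A small self-contained sketch of the same mechanism (the helper name and the return value are invented for illustration):

    from unittest.mock import MagicMock, patch

    import torch

    @patch('torch.cuda.max_memory_allocated', MagicMock(return_value=1536 * 1024**2))
    @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
    def report_peak_memory_mb() -> int:
        # Both CUDA calls hit the mocks, so no CUDA (or MUSA) device is required.
        peak_mb = torch.cuda.max_memory_allocated() // (1024 * 1024)
        torch.cuda.reset_peak_memory_stats()
        return peak_mb

    print(report_peak_memory_mb())  # prints 1536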
