Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
  • Loading branch information
hanhaowen-mt and zhouzaida authored Jan 10, 2024
1 parent 658a2f2 commit cee8d80
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 13 deletions.
4 changes: 2 additions & 2 deletions mmengine/dist/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,11 +790,11 @@ def _gather_object(obj: Any,
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
is_mccl_backend = group_backend == 'mccl'
if is_nccl_backend:
current_device = torch.device('', torch.cuda.current_device())
current_device = torch.device('cuda', torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
elif is_mccl_backend:
current_device = torch.device('', torch.musa.current_device())
current_device = torch.device('musa', torch.musa.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and
Expand Down
3 changes: 1 addition & 2 deletions mmengine/dist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from torch import Tensor
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup
from mmengine.device import is_mlu_available, is_npu_available
from mmengine.device import is_musa_available
from mmengine.device import (is_mlu_available, is_npu_available, is_musa_available)

from collections.abc import Iterable, Mapping

Expand Down
2 changes: 1 addition & 1 deletion mmengine/structures/base_data_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def cuda(self) -> 'BaseDataElement':

# Tensor-like methods
def musa(self) -> 'BaseDataElement':
"""Convert all tensors to GPU in data."""
"""Convert all tensors to musa in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
Expand Down
4 changes: 2 additions & 2 deletions mmengine/utils/dl_utils/collect_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch

import mmengine
from mmengine.device import is_musa_available
from mmengine.device import is_cuda_available, is_musa_available
from .parrots_wrapper import TORCH_VERSION, get_build_config, is_rocm_pytorch


Expand Down Expand Up @@ -57,7 +57,7 @@ def collect_env():
env_info['sys.platform'] = sys.platform
env_info['Python'] = sys.version.replace('\n', '')

cuda_available = torch.cuda.is_available()
cuda_available = is_cuda_available()
musa_available = is_musa_available()
env_info['CUDA available'] = cuda_available
env_info['MUSA available'] = musa_available
Expand Down
6 changes: 3 additions & 3 deletions mmengine/utils/dl_utils/time_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from mmengine.device import is_musa_available
from mmengine.device import is_cuda_available, is_musa_available
from mmengine.dist.utils import master_only
from mmengine.logging import MMLogger, print_log

Expand Down Expand Up @@ -86,7 +86,7 @@ def wrapper(*args, **kwargs):
self.__count += 1

if self.with_sync:
if torch.cuda.is_available():
if is_cuda_available():
torch.cuda.synchronize()
elif is_musa_available():
torch.musa.synchronize()
Expand All @@ -95,7 +95,7 @@ def wrapper(*args, **kwargs):
result = fn(*args, **kwargs)

if self.with_sync:
if torch.cuda.is_available():
if is_cuda_available():
torch.cuda.synchronize()
elif is_musa_available():
torch.musa.synchronize()
Expand Down
6 changes: 3 additions & 3 deletions tests/test_runner/test_log_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from parameterized import parameterized

from mmengine.device import is_musa_available
from mmengine.device import is_cuda_available, is_musa_available
from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
from mmengine.runner import LogProcessor
from mmengine.testing import RunnerTestCase
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
f"time: {train_logs['time']:.4f} "
f"data_time: {train_logs['data_time']:.4f} ")

if torch.cuda.is_available() or is_musa_available():
if is_cuda_available() or is_musa_available():
log_str += 'memory: 100 '
if mode == 'train':
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
Expand Down Expand Up @@ -142,7 +142,7 @@ def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
f"time: {train_logs['time']:.4f} "
f"data_time: {train_logs['data_time']:.4f} ")

if torch.cuda.is_available() or is_musa_available():
if is_cuda_available() or is_musa_available():
log_str += 'memory: 100 '

if mode == 'train':
Expand Down

0 comments on commit cee8d80

Please sign in to comment.