
Commit

new yapf lint
dianyo committed Feb 21, 2025
1 parent ca52af6 commit 3594b5e
Showing 12 changed files with 50 additions and 43 deletions.
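
All hunks below appear to be mechanical reformatting from a newer yapf release, with no behavioral changes. For reference, a minimal sketch of reproducing such a pass through yapf's Python API, using the built-in 'pep8' style (mmengine's actual style knobs live in its setup.cfg, so the style choice here is an assumption):

    # A minimal sketch, assuming `pip install yapf`.
    from yapf.yapflib.yapf_api import FormatCode

    source = ("assert (value is None or\n"
              "        check(value)), ('message')\n")

    # FormatCode returns the reformatted source plus a changed flag.
    formatted, changed = FormatCode(source, style_config='pep8')
    print(changed)
    print(formatted)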
8 changes: 4 additions & 4 deletions mmengine/_strategy/deepspeed.py
@@ -310,10 +310,10 @@ def __init__(
         self.config.setdefault('gradient_accumulation_steps', 1)
         self.config['steps_per_print'] = steps_per_print
         self._inputs_to_half = inputs_to_half
-        assert (exclude_frozen_parameters is None or
-                digit_version(deepspeed.__version__) >= digit_version('0.13.2')
-                ), ('DeepSpeed >= 0.13.2 is required to enable '
-                    'exclude_frozen_parameters')
+        assert (exclude_frozen_parameters is None or digit_version(
+            deepspeed.__version__) >= digit_version('0.13.2')), (
+                'DeepSpeed >= 0.13.2 is required to enable '
+                'exclude_frozen_parameters')
         self.exclude_frozen_parameters = exclude_frozen_parameters

         register_deepspeed_optimizers()
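
The assert above gates `exclude_frozen_parameters` on DeepSpeed >= 0.13.2. A standalone sketch of the same version-gating pattern, with a simplified stand-in for mmengine's `digit_version`:

    # Simplified stand-in for mmengine.utils.digit_version: turn a version
    # string into a tuple of ints so comparisons work numerically.
    def digit_version(version: str):
        return tuple(int(p) for p in version.split('.')[:3] if p.isdigit())

    installed = '0.14.0'  # e.g. deepspeed.__version__
    exclude_frozen_parameters = True

    assert (exclude_frozen_parameters is None
            or digit_version(installed) >= digit_version('0.13.2')), (
                'DeepSpeed >= 0.13.2 is required to enable '
                'exclude_frozen_parameters')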
21 changes: 12 additions & 9 deletions mmengine/config/config.py
@@ -48,9 +48,10 @@
 def _lazy2string(cfg_dict, dict_type=None):
     if isinstance(cfg_dict, dict):
         dict_type = dict_type or type(cfg_dict)
-        return dict_type(
-            {k: _lazy2string(v, dict_type)
-             for k, v in dict.items(cfg_dict)})
+        return dict_type({
+            k: _lazy2string(v, dict_type)
+            for k, v in dict.items(cfg_dict)
+        })
     elif isinstance(cfg_dict, (tuple, list)):
         return type(cfg_dict)(_lazy2string(v, dict_type) for v in cfg_dict)
     elif isinstance(cfg_dict, (LazyAttr, LazyObject)):
@@ -273,13 +274,15 @@ def __reduce_ex__(self, proto):
         # called by CPython interpreter during pickling. See more details in
         # https://github.com/python/cpython/blob/8d61a71f9c81619e34d4a30b625922ebc83c561b/Objects/typeobject.c#L6196 # noqa: E501
         if digit_version(platform.python_version()) < digit_version('3.8'):
-            return (self.__class__, ({k: v
-                                      for k, v in super().items()}, ), None,
-                    None, None)
+            return (self.__class__, ({
+                k: v
+                for k, v in super().items()
+            }, ), None, None, None)
         else:
-            return (self.__class__, ({k: v
-                                      for k, v in super().items()}, ), None,
-                    None, None, None)
+            return (self.__class__, ({
+                k: v
+                for k, v in super().items()
+            }, ), None, None, None, None)

     def __eq__(self, other):
         if isinstance(other, ConfigDict):
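
One detail that survives both reformats: `dict.items(cfg_dict)` invokes the built-in `dict` implementation even when the subclass overrides `items`, which is why the comprehension is safe on lazy config objects. A standalone illustration of that trick:

    # `dict.items(obj)` bypasses an overridden `items` on the subclass and
    # uses the built-in dict implementation directly.
    class LazyDict(dict):

        def items(self):  # pretend this resolves lazy objects expensively
            raise RuntimeError('should not be called during conversion')

    d = LazyDict(a=1, b=2)
    print({k: v for k, v in dict.items(d)})  # {'a': 1, 'b': 2}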
3 changes: 2 additions & 1 deletion mmengine/dataset/utils.py
@@ -158,7 +158,8 @@ def default_collate(data_batch: Sequence) -> Any:
         return [default_collate(samples) for samples in transposed]
     elif isinstance(data_item, Mapping):
         return data_item_type({
-            key: default_collate([d[key] for d in data_batch])
+            key:
+            default_collate([d[key] for d in data_batch])
             for key in data_item
         })
     else:
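
This branch of `default_collate` rebuilds a mapping by collating each key across the batch. A simplified, self-contained sketch of the same idea (recursion and subtype details elided):

    # Minimal sketch: collate a batch of dicts by collecting each key
    # across the batch, mirroring the Mapping branch above.
    def collate_mappings(data_batch):
        data_item = data_batch[0]
        return type(data_item)({
            key: [d[key] for d in data_batch]
            for key in data_item
        })

    batch = [{'img': 1, 'label': 0}, {'img': 2, 'label': 1}]
    print(collate_mappings(batch))  # {'img': [1, 2], 'label': [0, 1]}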
4 changes: 2 additions & 2 deletions mmengine/fileio/backends/local_backend.py
@@ -156,8 +156,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool:
         """
         return osp.isfile(filepath)

-    def join_path(self, filepath: Union[str, Path],
-                  *filepaths: Union[str, Path]) -> str:
+    def join_path(self, filepath: Union[str, Path], *filepaths:
+                  Union[str, Path]) -> str:
         r"""Concatenate all file paths.

         Join one or more filepath components intelligently. The return value
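
Only the signature wrapping changed here; the same hunk appears in file_client.py below. Per the docstring, the local backend's `join_path` amounts to `os.path.join` over the given components; a hedged, standalone sketch:

    import os.path as osp
    from pathlib import Path
    from typing import Union

    # Sketch of what the local backend's join_path amounts to; the real
    # method may normalize paths differently.
    def join_path(filepath: Union[str, Path],
                  *filepaths: Union[str, Path]) -> str:
        return osp.join(filepath, *filepaths)

    print(join_path('data', 'images', '0001.jpg'))  # data/images/0001.jpg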
4 changes: 2 additions & 2 deletions mmengine/fileio/file_client.py
@@ -385,8 +385,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool:
         """
         return self.client.isfile(filepath)

-    def join_path(self, filepath: Union[str, Path],
-                  *filepaths: Union[str, Path]) -> str:
+    def join_path(self, filepath: Union[str, Path], *filepaths:
+                  Union[str, Path]) -> str:
         r"""Concatenate all file paths.

         Join one or more filepath components intelligently. The return value
8 changes: 4 additions & 4 deletions mmengine/hooks/checkpoint_hook.py
@@ -196,10 +196,10 @@ def __init__(self,
         self.save_best = save_best

         # rule logic
-        assert (isinstance(rule, str) or is_list_of(rule, str)
-                or (rule is None)), (
-                    '"rule" should be a str or list of str or None, '
-                    f'but got {type(rule)}')
+        assert (isinstance(rule, str) or is_list_of(rule, str) or
+                (rule
+                 is None)), ('"rule" should be a str or list of str or None, '
+                             f'but got {type(rule)}')
         if isinstance(rule, list):
             # check the length of rule list
             assert len(rule) in [
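
The assert accepts `rule` as a str, a list of str, or None. The same check in isolation, with a minimal stand-in for mmengine's `is_list_of` helper:

    # Minimal stand-in for mmengine.utils.is_list_of.
    def is_list_of(seq, expected_type):
        return isinstance(seq, list) and all(
            isinstance(item, expected_type) for item in seq)

    for rule in ('greater', ['greater', 'less'], None):
        assert (isinstance(rule, str) or is_list_of(rule, str)
                or rule is None), (
                    '"rule" should be a str or list of str or None, '
                    f'but got {type(rule)}')
    print('all rule values accepted')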
7 changes: 4 additions & 3 deletions mmengine/model/test_time_aug.py
@@ -124,9 +124,10 @@ def test_step(self, data):
         data_list: Union[List[dict], List[list]]
         if isinstance(data, dict):
             num_augs = len(data[next(iter(data))])
-            data_list = [{key: value[idx]
-                          for key, value in data.items()}
-                         for idx in range(num_augs)]
+            data_list = [{
+                key: value[idx]
+                for key, value in data.items()
+            } for idx in range(num_augs)]
         elif isinstance(data, (tuple, list)):
             num_augs = len(data[0])
             data_list = [[_data[idx] for _data in data]
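
The reformatted comprehension transposes a dict of per-key augmentation lists into one dict per augmentation; a self-contained illustration:

    # Transpose {key: [aug0, aug1, ...]} into [{key: aug0}, {key: aug1}, ...].
    data = {'inputs': ['img_aug0', 'img_aug1'], 'meta': ['m0', 'm1']}
    num_augs = len(data[next(iter(data))])

    data_list = [{
        key: value[idx]
        for key, value in data.items()
    } for idx in range(num_augs)]

    print(data_list)
    # [{'inputs': 'img_aug0', 'meta': 'm0'}, {'inputs': 'img_aug1', 'meta': 'm1'}]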
6 changes: 3 additions & 3 deletions mmengine/utils/dl_utils/torch_ops.py
@@ -4,9 +4,9 @@
 from ..version_utils import digit_version
 from .parrots_wrapper import TORCH_VERSION

-_torch_version_meshgrid_indexing = (
-    'parrots' not in TORCH_VERSION
-    and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))
+_torch_version_meshgrid_indexing = ('parrots' not in TORCH_VERSION
+                                    and digit_version(TORCH_VERSION)
+                                    >= digit_version('1.10.0a0'))


 def torch_meshgrid(*tensors):
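
The flag records whether `torch.meshgrid` supports the `indexing` keyword introduced in PyTorch 1.10. The body of `torch_meshgrid` is outside this hunk, so the consumer below is an assumption about how such a flag is typically used, not the verbatim implementation:

    import torch

    # Assumed consumer of the flag: pass indexing='ij' only when supported.
    _torch_version_meshgrid_indexing = True  # as computed above

    def torch_meshgrid(*tensors):
        if _torch_version_meshgrid_indexing:
            return torch.meshgrid(*tensors, indexing='ij')
        return torch.meshgrid(*tensors)  # pre-1.10 always used 'ij' order

    xx, yy = torch_meshgrid(torch.arange(2), torch.arange(3))
    print(xx.shape, yy.shape)  # torch.Size([2, 3]) torch.Size([2, 3])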
5 changes: 3 additions & 2 deletions mmengine/visualization/visualizer.py
@@ -754,8 +754,9 @@ def draw_bboxes(
         assert bboxes.shape[-1] == 4, (
             f'The shape of `bboxes` should be (N, 4), but got {bboxes.shape}')

-        assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] <=
-                                                         bboxes[:, 3]).all()
+        assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1]
+                                                         <= bboxes[:,
+                                                                   3]).all()
         if not self._is_posion_valid(bboxes.reshape((-1, 2, 2))):
             warnings.warn(
                 'Warning: The bbox is out of bounds,'
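
The assert requires x1 <= x2 and y1 <= y2 for every box; a quick numpy demonstration of the same elementwise check:

    import numpy as np

    bboxes = np.array([[0, 0, 10, 10], [2, 3, 8, 9]])  # (N, 4): x1, y1, x2, y2

    # Both conditions must hold for every row, not just one box.
    assert (bboxes[:, 0] <= bboxes[:, 2]).all() and \
           (bboxes[:, 1] <= bboxes[:, 3]).all()
    print('all boxes well-formed')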
7 changes: 4 additions & 3 deletions tests/test_analysis/test_jit_analysis.py
@@ -634,9 +634,10 @@ def dummy_ops_handle(inputs: List[Any],

         dummy_flops = {}
         for name, counts in model.flops.items():
-            dummy_flops[name] = Counter(
-                {op: flop
-                 for op, flop in counts.items() if op != self.lin_op})
+            dummy_flops[name] = Counter({
+                op: flop
+                for op, flop in counts.items() if op != self.lin_op
+            })
         dummy_flops[''][dummy_name] = 2 * dummy_out
         dummy_flops['fc'][dummy_name] = dummy_out
         dummy_flops['submod'][dummy_name] = dummy_out
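
The comprehension drops one op from each per-module counter before comparison; the same filtering pattern on a plain `collections.Counter`:

    from collections import Counter

    counts = Counter({'addmm': 4096, 'conv': 1024, 'bmm': 512})
    lin_op = 'addmm'

    filtered = Counter({
        op: flop
        for op, flop in counts.items() if op != lin_op
    })
    print(filtered)  # Counter({'conv': 1024, 'bmm': 512})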
8 changes: 4 additions & 4 deletions tests/test_dataset/test_base_dataset.py
@@ -733,13 +733,13 @@ def test_length(self):
     def test_getitem(self):
         assert (
             self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
-        assert (self.cat_datasets[0]['imgs'] !=
-                self.dataset_b[0]['imgs']).all()
+        assert (self.cat_datasets[0]['imgs']
+                != self.dataset_b[0]['imgs']).all()

         assert (
             self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all()
-        assert (self.cat_datasets[-1]['imgs'] !=
-                self.dataset_a[-1]['imgs']).all()
+        assert (self.cat_datasets[-1]['imgs']
+                != self.dataset_a[-1]['imgs']).all()

     def test_get_data_info(self):
         assert self.cat_datasets.get_data_info(
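
Note that `(a != b).all()` asserts every element differs, which is stronger than the arrays merely being unequal; a quick numpy illustration:

    import numpy as np

    a = np.array([1, 2, 3])
    b = np.array([4, 5, 6])
    c = np.array([1, 5, 6])

    print((a != b).all())  # True: every element differs
    print((a != c).all())  # False: first elements match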
12 changes: 6 additions & 6 deletions tests/test_optim/test_optimizer/test_optimizer_wrapper.py
@@ -455,8 +455,8 @@ def test_init(self):
         not torch.cuda.is_available(),
         reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_step(self, dtype):
-        if dtype is not None and (digit_version(TORCH_VERSION) <
-                                  digit_version('1.10.0')):
+        if dtype is not None and (digit_version(TORCH_VERSION)
+                                  < digit_version('1.10.0')):
             raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
                                     'support `dtype` argument in autocast')
         if dtype == 'bfloat16' and not bf16_supported():
@@ -478,8 +478,8 @@ def test_step(self, dtype):
         not torch.cuda.is_available(),
         reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_backward(self, dtype):
-        if dtype is not None and (digit_version(TORCH_VERSION) <
-                                  digit_version('1.10.0')):
+        if dtype is not None and (digit_version(TORCH_VERSION)
+                                  < digit_version('1.10.0')):
             raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
                                     'support `dtype` argument in autocast')
         if dtype == 'bfloat16' and not bf16_supported():
@@ -539,8 +539,8 @@ def test_load_state_dict(self):
         not torch.cuda.is_available(),
         reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_optim_context(self, dtype, target_dtype):
-        if dtype is not None and (digit_version(TORCH_VERSION) <
-                                  digit_version('1.10.0')):
+        if dtype is not None and (digit_version(TORCH_VERSION)
+                                  < digit_version('1.10.0')):
             raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
                                     'support `dtype` argument in autocast')
         if dtype == 'bfloat16' and not bf16_supported():
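
All three guards skip when `torch.autocast` predates the `dtype` argument added in PyTorch 1.10. The same gate outside the test suite, with a simplified `digit_version` stand-in:

    import torch

    # Simplified stand-in for mmengine's digit_version over torch.__version__,
    # tolerating local suffixes like '2.1.0+cu118'.
    def digit_version(v):
        return tuple(
            int(p) for p in v.split('+')[0].split('.')[:3] if p.isdigit())

    if digit_version(torch.__version__) >= digit_version('1.10.0'):
        # autocast accepts an explicit dtype from 1.10 onward.
        with torch.autocast('cpu', dtype=torch.bfloat16):
            x = torch.ones(2, 2) @ torch.ones(2, 2)
        print(x.dtype)  # torch.bfloat16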
