Commit

fix update of issue #25
Lupin1998 committed Oct 14, 2022
1 parent 5211cc2 commit 8966870
Showing 57 changed files with 3,021 additions and 472 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -41,7 +41,7 @@ The main branch works with **PyTorch 1.8** (required by some self-supervised met

## What's New

[2022-10-12] Update new features and documents of `OpenMixup` v0.2.6 (issue [#24](https://github.com/Westlake-AI/openmixup/issues/24) and issue [#25](https://github.com/Westlake-AI/openmixup/issues/25)).
[2022-10-15] Update new features and documents of `OpenMixup` v0.2.6 (issue [#24](https://github.com/Westlake-AI/openmixup/issues/24) and issue [#25](https://github.com/Westlake-AI/openmixup/issues/25)).

[2022-09-14] `OpenMixup` v0.2.6 is released (issue [#20](https://github.com/Westlake-AI/openmixup/issues/20)).

@@ -113,6 +113,7 @@ Please refer to [Model Zoos](docs/en/model_zoos) for various backbones, mixup me
- [x] [MViTV2](https://arxiv.org/abs/2112.01526) (CVPR'2022) [[config](https://github.com/Westlake-AI/openmixup/tree/main/configs/classification/imagenet/mvit/)]
- [x] [RepMLP](https://arxiv.org/abs/2105.01883) (CVPR'2022) [[config](https://github.com/Westlake-AI/openmixup/tree/main/configs/classification/imagenet/repmlp/)]
- [x] [VAN](https://arxiv.org/abs/2202.09741) (ArXiv'2022) [[config](https://github.com/Westlake-AI/openmixup/tree/main/configs/classification/imagenet/van/)]
- [x] [DeiT-3](https://arxiv.org/abs/2204.07118) (ECCV'2022) [[config](https://github.com/Westlake-AI/openmixup/tree/main/configs/classification/imagenet/deit3/)]
- [x] [LITv2](https://arxiv.org/abs/2205.13213) (NIPS'2022) [[config](https://github.com/Westlake-AI/openmixup/tree/main/configs/classification/imagenet/lit_v2/)]
- [x] [HorNet](https://arxiv.org/abs/2207.14284) (NIPS'2022) [[config](https://github.com/Westlake-AI/openmixup/tree/main/configs/classification/imagenet/hornet/)]
- [x] [EdgeNeXt](https://arxiv.org/abs/2206.10589) (ECCVW'2022) [[config](https://github.com/Westlake-AI/openmixup/tree/main/configs/classification/imagenet/edgenext/)]
@@ -0,0 +1,64 @@
# dataset settings
data_source_cfg = dict(type='ImageNet')
# ImageNet dataset
data_train_list = 'data/meta/ImageNet/train_labeled_full.txt'
data_train_root = 'data/ImageNet/train'
data_test_list = 'data/meta/ImageNet/val_labeled.txt'
data_test_root = 'data/ImageNet/val/'

dataset_type = 'ClassificationDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
dict(type='RandomResizedCrop', size=224, interpolation=3), # bicubic
dict(type='RandomHorizontalFlip'),
]
test_pipeline_1 = [
    dict(type='Resize', size=256, interpolation=3),  # bicubic; crop ratio = 224/256 = 0.875
dict(type='RandomHorizontalFlip', p=0.5),
dict(type='CenterCrop', size=224),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
]
test_pipeline_2 = [
    dict(type='Resize', size=256, interpolation=3),  # bicubic; crop ratio = 224/256 = 0.875
dict(type='RandomVerticalFlip', p=0.5),
dict(type='PlaceCrop', size=224, start=[0, 5, 10, 15,]),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
]

# prefetch
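# When prefetch is enabled, ToTensor/Normalize are left out of the CPU pipeline;
# they are presumably applied later by the GPU prefetcher instead.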
prefetch = True
if not prefetch:
train_pipeline.extend([dict(type='ToTensor'), dict(type='Normalize', **img_norm_cfg)])

data = dict(
imgs_per_gpu=64,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_source=dict(
list_file=data_train_list, root=data_train_root,
**data_source_cfg),
pipeline=train_pipeline,
prefetch=prefetch,
),
val=dict(
type="MultiViewDataset", # use multi-view for test time augmentations
data_source=dict(
list_file=data_test_list, root=data_test_root, **data_source_cfg),
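        # num_views[i] presumably sets how many views pipelines[i] produces:
        # 2 views from test_pipeline_1 and 4 from test_pipeline_2 (one per PlaceCrop offset).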
num_views=[2, 4],
pipelines=[test_pipeline_1, test_pipeline_2],
prefetch=False,
))

# validation hook
evaluation = dict(
initial=False,
interval=1,
imgs_per_gpu=128,
workers_per_gpu=4,
eval_param=dict(topk=(1, 5)))

# checkpoint
checkpoint_config = dict(interval=1, max_keep_ckpts=1)
@@ -0,0 +1,83 @@
# Refers to `_RAND_INCREASING_TRANSFORMS` in pytorch-image-models
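# ("increasing" policies scale each op's strength with the magnitude level, so ranges such
#  as Posterize (4, 0) and Solarize (256, 0) run from weakest to strongest effect).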
rand_increasing_policies = [
dict(type='AutoContrast'),
dict(type='Equalize'),
dict(type='Invert'),
dict(type='Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
dict(type='Posterize', magnitude_key='bits', magnitude_range=(4, 0)),
dict(type='Solarize', magnitude_key='thr', magnitude_range=(256, 0)),
dict(type='SolarizeAdd', magnitude_key='magnitude', magnitude_range=(0, 110)),
dict(type='ColorTransform', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
dict(type='Contrast', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
dict(type='Brightness', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
dict(type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
dict(type='Shear',
magnitude_key='magnitude', magnitude_range=(0, 0.3), direction='horizontal'),
dict(type='Shear',
magnitude_key='magnitude', magnitude_range=(0, 0.3), direction='vertical'),
dict(type='Translate',
magnitude_key='magnitude', magnitude_range=(0, 0.45), direction='horizontal'),
dict(type='Translate',
magnitude_key='magnitude', magnitude_range=(0, 0.45), direction='vertical'),
]

# dataset settings
data_source_cfg = dict(type='ImageNet')
# ImageNet dataset
data_train_list = 'data/meta/ImageNet/train_labeled_full.txt'
data_train_root = 'data/ImageNet/train'
data_test_list = 'data/meta/ImageNet/val_labeled.txt'
data_test_root = 'data/ImageNet/val/'

dataset_type = 'ClassificationDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
dict(type='RandomResizedCrop', size=224, interpolation=3), # bicubic
dict(type='RandomHorizontalFlip'),
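    # RandAugment: 2 ops per image at magnitude level 9 of 10, with std 0.5 jitter on the level;
    # pad_val is roughly the per-channel ImageNet mean on the 0-255 scale.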
dict(type='RandAugment',
policies=rand_increasing_policies,
num_policies=2, total_level=10,
magnitude_level=9, magnitude_std=0.5,
hparams=dict(
pad_val=[104, 116, 124], interpolation='bicubic')),
]
test_pipeline = [
dict(type='Resize', size=224, interpolation=3), # crop-ratio = 1.0
dict(type='CenterCrop', size=224),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
]
# prefetch
prefetch = True
if not prefetch:
train_pipeline.extend([dict(type='ToTensor'), dict(type='Normalize', **img_norm_cfg)])

data = dict(
imgs_per_gpu=128,
workers_per_gpu=10,
train=dict(
type=dataset_type,
data_source=dict(
list_file=data_train_list, root=data_train_root,
**data_source_cfg),
pipeline=train_pipeline,
prefetch=prefetch,
),
val=dict(
type=dataset_type,
data_source=dict(
list_file=data_test_list, root=data_test_root, **data_source_cfg),
pipeline=test_pipeline,
prefetch=False,
))

# validation hook
evaluation = dict(
initial=False,
interval=1,
imgs_per_gpu=128,
workers_per_gpu=4,
eval_param=dict(topk=(1, 5)))

# checkpoint
checkpoint_config = dict(interval=1, max_keep_ckpts=1)
@@ -0,0 +1,64 @@
# dataset settings
data_source_cfg = dict(type='ImageNet')
# ImageNet dataset
data_train_list = 'data/meta/ImageNet/train_labeled_full.txt'
data_train_root = 'data/ImageNet/train'
data_test_list = 'data/meta/ImageNet/val_labeled.txt'
data_test_root = 'data/ImageNet/val/'

dataset_type = 'ClassificationDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
dict(type='RandomResizedCrop', size=160, interpolation=3), # bicubic
dict(type='RandomHorizontalFlip'),
dict(type='RandomAppliedTrans', # 3-Augment in DeiT III
transforms=[
dict(type='RandomGrayscale', p=1.),
dict(type='Solarization', p=1.),
dict(type='GaussianBlur', sigma_min=0.1, sigma_max=2.0, p=1.),
],
p=1.0),
dict(type='ColorJitter',
brightness=0.3, contrast=0.3, saturation=0.3),
]
test_pipeline = [
dict(type='Resize', size=160, interpolation=3), # crop-ratio = 1.0
dict(type='CenterCrop', size=160),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
]

# prefetch
prefetch = True
if not prefetch:
train_pipeline.extend([dict(type='ToTensor'), dict(type='Normalize', **img_norm_cfg)])

data = dict(
imgs_per_gpu=128,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_source=dict(
list_file=data_train_list, root=data_train_root,
**data_source_cfg),
pipeline=train_pipeline,
prefetch=prefetch,
),
val=dict(
type=dataset_type,
data_source=dict(
list_file=data_test_list, root=data_test_root, **data_source_cfg),
pipeline=test_pipeline,
prefetch=False,
))

# validation hook
evaluation = dict(
initial=False,
interval=1,
imgs_per_gpu=128,
workers_per_gpu=4,
eval_param=dict(topk=(1, 5)))

# checkpoint
checkpoint_config = dict(interval=1, max_keep_ckpts=1)
@@ -0,0 +1,64 @@
# dataset settings
data_source_cfg = dict(type='ImageNet')
# ImageNet dataset
data_train_list = 'data/meta/ImageNet/train_labeled_full.txt'
data_train_root = 'data/ImageNet/train'
data_test_list = 'data/meta/ImageNet/val_labeled.txt'
data_test_root = 'data/ImageNet/val/'

dataset_type = 'ClassificationDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
dict(type='RandomResizedCrop', size=192, interpolation=3), # bicubic
dict(type='RandomHorizontalFlip'),
dict(type='RandomAppliedTrans', # 3-Augment in DeiT III
transforms=[
dict(type='RandomGrayscale', p=1.),
dict(type='Solarization', p=1.),
dict(type='GaussianBlur', sigma_min=0.1, sigma_max=2.0, p=1.),
],
p=1.0),
dict(type='ColorJitter',
brightness=0.3, contrast=0.3, saturation=0.3),
]
test_pipeline = [
dict(type='Resize', size=192, interpolation=3), # crop-ratio = 1.0
dict(type='CenterCrop', size=192),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
]

# prefetch
prefetch = True
if not prefetch:
train_pipeline.extend([dict(type='ToTensor'), dict(type='Normalize', **img_norm_cfg)])

data = dict(
imgs_per_gpu=128,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_source=dict(
list_file=data_train_list, root=data_train_root,
**data_source_cfg),
pipeline=train_pipeline,
prefetch=prefetch,
),
val=dict(
type=dataset_type,
data_source=dict(
list_file=data_test_list, root=data_test_root, **data_source_cfg),
pipeline=test_pipeline,
prefetch=False,
))

# validation hook
evaluation = dict(
initial=False,
interval=1,
imgs_per_gpu=128,
workers_per_gpu=4,
eval_param=dict(topk=(1, 5)))

# checkpoint
checkpoint_config = dict(interval=1, max_keep_ckpts=1)
@@ -0,0 +1,64 @@
# dataset settings
data_source_cfg = dict(type='ImageNet')
# ImageNet dataset
data_train_list = 'data/meta/ImageNet/train_labeled_full.txt'
data_train_root = 'data/ImageNet/train'
data_test_list = 'data/meta/ImageNet/val_labeled.txt'
data_test_root = 'data/ImageNet/val/'

dataset_type = 'ClassificationDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
dict(type='RandomResizedCrop', size=224, interpolation=3), # bicubic
dict(type='RandomHorizontalFlip'),
dict(type='RandomAppliedTrans', # 3-Augment in DeiT III
transforms=[
dict(type='RandomGrayscale', p=1.),
dict(type='Solarization', p=1.),
dict(type='GaussianBlur', sigma_min=0.1, sigma_max=2.0, p=1.),
],
p=1.0),
dict(type='ColorJitter',
brightness=0.3, contrast=0.3, saturation=0.3),
]
test_pipeline = [
dict(type='Resize', size=224, interpolation=3), # crop-ratio = 1.0
dict(type='CenterCrop', size=224),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
]

# prefetch
prefetch = True
if not prefetch:
train_pipeline.extend([dict(type='ToTensor'), dict(type='Normalize', **img_norm_cfg)])

data = dict(
imgs_per_gpu=128,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_source=dict(
list_file=data_train_list, root=data_train_root,
**data_source_cfg),
pipeline=train_pipeline,
prefetch=prefetch,
),
val=dict(
type=dataset_type,
data_source=dict(
list_file=data_test_list, root=data_test_root, **data_source_cfg),
pipeline=test_pipeline,
prefetch=False,
))

# validation hook
evaluation = dict(
initial=False,
interval=1,
imgs_per_gpu=128,
workers_per_gpu=4,
eval_param=dict(topk=(1, 5)))

# checkpoint
checkpoint_config = dict(interval=1, max_keep_ckpts=1)
24 changes: 24 additions & 0 deletions configs/classification/_base_/models/deit3/deit3_base_p16_sz192.py
@@ -0,0 +1,24 @@
# model settings
model = dict(
type='MixUpClassification',
pretrained=None,
alpha=[0.8, 1.0,],
mix_mode=["mixup", "cutmix",],
mix_args=dict(),
backbone=dict(
type='DeiT3',
arch='base',
img_size=192,
patch_size=16,
drop_path_rate=0.2),
head=dict(
type='VisionTransformerClsHead',
loss=dict(type='CrossEntropyLoss', # mixup BCE loss (one-hot encoding)
use_soft=False, use_sigmoid=True, loss_weight=1.0),
multi_label=True,
in_channels=768, num_classes=1000),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
],
)
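
These `_base_` files are composed into full configs through mmcv-style `_base_` inheritance (the README points runnable DeiT-3 configs at `configs/classification/imagenet/deit3/`). Below is a minimal sketch of such a top-level config: only the model base path comes from this commit; the dataset and runtime paths and all optimizer/schedule values are illustrative assumptions.

```python
# Hypothetical top-level DeiT-3 config sketch; paths and values marked as placeholders
# are assumptions for illustration, not files or settings from this commit.
_base_ = [
    '../../_base_/models/deit3/deit3_base_p16_sz192.py',    # model base added in this commit
    '../../_base_/datasets/imagenet/deit3_sz192_bs128.py',  # placeholder name for a dataset base
    '../../_base_/default_runtime.py',                       # placeholder runtime base
]

# overrides a full config would typically add (example values, not the repository's recipe)
optimizer = dict(type='AdamW', lr=3e-3, weight_decay=0.05)
lr_config = dict(policy='CosineAnnealing', min_lr=1e-5,
                 warmup='linear', warmup_iters=5, warmup_by_epoch=True)
runner = dict(type='EpochBasedRunner', max_epochs=300)
```

Training would then be launched with the usual mmcv-style entry point, e.g. `bash tools/dist_train.sh <config> <num_gpus>`, assuming the repository follows the standard OpenMMLab launcher layout.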