Skip to content

Commit

Permalink
change back the classification preprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
Mehrdad committed Dec 22, 2023
1 parent 4939863 commit 27f1ef2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
32 changes: 19 additions & 13 deletions ultralytics/data/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,7 @@ def classify_transforms(size=224, rect=False, mean=(0.0, 0.0, 0.0), std=(1.0, 1.
"""Transforms to apply if albumentations not installed."""
if not isinstance(size, int):
raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
transforms = [ClassifyLetterBox(size), ToTensor()]
transforms = [ClassifyLetterBox(size, auto=True) if rect else CenterCrop(size), ToTensor()]
if any(mean) or any(std):
transforms.append(T.Normalize(mean, std, inplace=True))
return T.Compose(transforms)
Expand All @@ -998,6 +998,9 @@ def hsv2colorjitter(h, s, v):


def classify_albumentations(
augment=True,
size=224,
scale=(0.08, 1.0),
hflip=0.5,
vflip=0.0,
hsv_h=0.015, # image HSV-Hue augmentation (fraction)
Expand All @@ -1014,17 +1017,20 @@ def classify_albumentations(
from albumentations.pytorch import ToTensorV2

check_version(A.__version__, '1.0.3', hard=True) # version requirement
T = []
if auto_aug:
# TODO: implement AugMix, AutoAug & RandAug in albumentations
LOGGER.info(f'{prefix}auto augmentations are currently not supported')
else:
if hflip > 0:
T += [A.HorizontalFlip(p=hflip)]
if vflip > 0:
T += [A.VerticalFlip(p=vflip)]
if any((hsv_h, hsv_s, hsv_v)):
T += [A.ColorJitter(*hsv2colorjitter(hsv_h, hsv_s, hsv_v))] # brightness, contrast, saturation, hue
if augment: # Resize and crop
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
if auto_aug:
# TODO: implement AugMix, AutoAug & RandAug in albumentations
LOGGER.info(f'{prefix}auto augmentations are currently not supported')
else:
if hflip > 0:
T += [A.HorizontalFlip(p=hflip)]
if vflip > 0:
T += [A.VerticalFlip(p=vflip)]
if any((hsv_h, hsv_s, hsv_v)):
T += [A.ColorJitter(*hsv2colorjitter(hsv_h, hsv_s, hsv_v))] # brightness, contrast, saturation, hue
else: # Use fixed crop for eval set (reproducibility)
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
return A.Compose(T)
Expand Down Expand Up @@ -1133,4 +1139,4 @@ def __call__(self, im):
im = torch.from_numpy(im) # to torch
im = im.half() if self.half else im.float() # uint8 to fp16/32
im /= 255.0 # 0-255 to 0.0-1.0
return im
return im
10 changes: 5 additions & 5 deletions ultralytics/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
import numpy as np
import torch
import torchvision
import torchvision.transforms as T

from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable

from .augment import Compose, Format, Instances, LetterBox, ClassifyLetterBox, classify_albumentations, classify_transforms, v8_transforms
from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label

Expand Down Expand Up @@ -227,8 +226,10 @@ def __init__(self, root, args, augment=False, cache=False, prefix=''):
self.samples = self.verify_images() # filter out bad images
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
self.torch_transforms = classify_transforms(args.imgsz, rect=args.rect, mean=args.mean, std=args.std)
self.pre_album_transforms = T.Compose([ClassifyLetterBox(size=args.imgsz)])
self.album_transforms = classify_albumentations(
augment=augment,
size=args.imgsz,
scale=(1.0 - args.scale, 1.0), # (0.08, 1.0)
hflip=args.fliplr,
vflip=args.flipud,
hsv_h=args.hsv_h, # HSV-Hue augmentation (fraction)
Expand All @@ -250,7 +251,6 @@ def __getitem__(self, i):
else: # read image
im = cv2.imread(f) # BGR
if self.album_transforms:
im = self.pre_album_transforms(im) # Letterbox
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
else:
sample = self.torch_transforms(im)
Expand Down Expand Up @@ -337,4 +337,4 @@ class SemanticDataset(BaseDataset):

def __init__(self):
"""Initialize a SemanticDataset object."""
super().__init__()
super().__init__()

0 comments on commit 27f1ef2

Please sign in to comment.