
[feature] Augmentation function over a collection of images stored in a Numpy array. #190

gerdm opened this issue Feb 10, 2022 · 2 comments
gerdm commented Feb 10, 2022

🚀 Feature

I would like to apply a single augmentation function over a collection of images stored in a Numpy array. I tried passing a (K, L, L) array of K black-and-white images of size LxL to aug_np_wrapper, but it doesn't seem to work.

Motivation

I'm working on a project that requires transforming all of the images in a single dataset, e.g., Fashion MNIST.

Pitch

I would like to be able to apply a transformation over a given axis for a collection of images.

Alternatives

The only alternative I can consider at the moment is using a for-loop, but this would be quite inefficient for what I would like to achieve.

gerdm changed the title to "[feature] Augmentation function over a collection of images stored in a Numpy array." Feb 10, 2022
zpapakipos added the enhancement label Feb 15, 2022
zpapakipos (Contributor) commented

Hi @gerdm, thank you for this suggestion! Indeed, our aug_np_wrapper (and all of our image augmentation functions in general) only expects single images as input for now. We could add support to aug_np_wrapper for taking in a batch of images. However, please note that our image augmentations are not implemented in numpy under the hood (aug_np_wrapper is just a wrapper that converts the numpy array to a PIL image and then calls the augmentation). Thus we would have to either use a for loop, which as you said is not very efficient, or get some speed-ups, e.g. by using multiprocessing.

Let me know what you think, or feel free to try multiprocessing on your side to see if this unblocks you. I will add this to our backlog of tasks and will link the PR here when I get to it :)
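For anyone who lands here before batch support exists, a naive for-loop version is just a few lines. This is a sketch under two assumptions: the augmentation preserves the image size (so the outputs can be stacked), and aug_np_batch is a hypothetical name, not part of augly.

import numpy as np
from augly import image

def aug_np_batch(image_batch, aug_function, **kwargs):
    # Hypothetical helper: apply one augly augmentation to every image
    # in a (K, L, L) batch by looping over aug_np_wrapper, which does
    # the numpy -> PIL -> numpy round trip for each image.
    return np.stack(
        [image.aug_np_wrapper(img, aug_function, **kwargs) for img in image_batch],
        axis=0,
    )

# e.g. aug_np_batch(batch, image.blur, radius=2.0)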

gerdm (Author) commented Mar 5, 2022

Hi @zpapakipos!

I went with the multiprocessing option you outlined. I'll paste the code here in case someone wants to try this in the future. The code below makes use of image.blur, but it's pretty easy to generalise to other methods.

First, we create a class that defines a callable to pass to Python's multiprocessing:

import numpy as np
from multiprocessing import Pool
from augly import image

class BlurRad:
    def __init__(self, rad):
        self.rad = rad
        
    def __call__(self, img_batch):
        return self.blur_batch(img_batch)

    def blur(self, X):
        """
        Blur a single image using the augly library.

        Parameters
        ----------
        X: np.array
            A single NxM-dimensional array

        The blur radius is taken from `self.rad`.
        """
        return image.aug_np_wrapper(X, image.blur, radius=self.rad)

    def blur_batch(self, X_batch):
        images_out = []
        for X in X_batch:
            img_blur = self.blur(X)
            images_out.append(img_blur)
        images_out = np.stack(images_out, axis=0)
        return images_out
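As a quick sanity check (shapes and values here are illustrative, and assume image.blur preserves image size):

batch = np.random.randint(0, 256, size=(8, 28, 28)).astype(np.uint8)
blurred = BlurRad(rad=2.0)(batch)  # __call__ dispatches to blur_batch
print(blurred.shape)  # (8, 28, 28)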

We can then use Python multiprocessing to blur a collection of images using a single radius:

def proc_dataset(img_dataset, radius, n_processes):
    """
    Blur all images of a dataset stored in a numpy array.
    
    Parameters
    ----------
    img_dataset: array(N, L, K)
        N images of size LxK
    radius: float
        Intensity of blurriness
    n_processes: int
        Number of processes to blur over
    """
    with Pool(processes=n_processes) as pool:
        dataset_proc = np.array_split(img_dataset, n_processes)
        dataset_proc = pool.map(BlurRad(radius), dataset_proc)
    
    dataset_proc = np.concatenate(dataset_proc, axis=0)
    # Flatten each image into a vector, giving an (N, L*K) array
    n_obs = len(img_dataset)
    dataset_proc = dataset_proc.reshape(n_obs, -1)
    
    return dataset_proc
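One caveat (my addition, not from the original snippet): on platforms that spawn rather than fork worker processes (e.g. macOS and Windows), the Pool call should live under a __main__ guard. A hypothetical run with random data standing in for Fashion MNIST:

if __name__ == "__main__":
    # 64 grayscale 28x28 images standing in for Fashion MNIST
    img_dataset = np.random.randint(0, 256, size=(64, 28, 28)).astype(np.uint8)
    blurred = proc_dataset(img_dataset, radius=2.0, n_processes=4)
    print(blurred.shape)  # (64, 784): each image is flattened to a vector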

If we want to blur different images with different radii, we define the following functions:

def blur_multiple(radii, img_dataset):
    """
    Blur every element of `img_dataset` given an element
    of `radii`.
    """
    imgs_out = []
    for radius, img in zip(radii, img_dataset):
        img_proc = BlurRad(radius).blur(img)
        imgs_out.append(img_proc)
    imgs_out = np.stack(imgs_out, axis=0)
    
    return imgs_out

def proc_dataset_multiple(img_dataset, radii, n_processes):
    """
    Blur all images of a dataset stored in a numpy array with variable
    radius.
    
    Parameters
    ----------
    radii: array(N,) or float
        Intensity of blurriness, one per image. If a
        float, the same value is used for all images.
    img_dataset: array(N, L, K)
        N images of size LxK
    n_processes: int
        Number of processes to blur over
    """

    if isinstance(radii, (float, np.floating)):
        radii = radii * np.ones(len(img_dataset))
    
    with Pool(processes=n_processes) as pool:
        dataset_proc = np.array_split(img_dataset, n_processes)
        radii_split = np.array_split(radii, n_processes)
        
        elements = zip(radii_split, dataset_proc)
        dataset_proc = pool.starmap(blur_multiple, elements)
        dataset_proc = np.concatenate(dataset_proc, axis=0)

    return dataset_proc
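Usage mirrors proc_dataset, except each image gets its own radius (again a hypothetical example with random data):

if __name__ == "__main__":
    img_dataset = np.random.randint(0, 256, size=(64, 28, 28)).astype(np.uint8)
    radii = np.linspace(0.5, 3.0, num=len(img_dataset))  # one radius per image
    blurred = proc_dataset_multiple(img_dataset, radii, n_processes=4)
    print(blurred.shape)  # (64, 28, 28): this variant keeps the image shape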
