From ad2a9fdea27aeb99664e2179fc8cc5decab1919f Mon Sep 17 00:00:00 2001 From: Vahid Zehtab <33608325+vahidzee@users.noreply.github.com> Date: Mon, 8 May 2023 17:36:32 -0400 Subject: [PATCH] v0.0.25: add TupleDataTransform --- docs/data_module.md | 21 ++++++------- lightning_toolbox/data/__init__.py | 2 +- lightning_toolbox/data/transforms.py | 46 ++++++++++++---------------- setup.py | 2 +- 4 files changed, 31 insertions(+), 40 deletions(-) diff --git a/docs/data_module.md b/docs/data_module.md index ef488be..ae93671 100644 --- a/docs/data_module.md +++ b/docs/data_module.md @@ -165,17 +165,16 @@ The following are possible ways of specifying a transformation object: #### Utility transformations It might be the case that you want to apply the same kind of transformations to a part of your data (like when dealing with (input, output) pairs). `lightning_toolbox` provides utility transformations that can be used to handle such cases. These transformations are: -* `lightning_toolbox.data.transforms.PairedDataTransform`: - This transformation class takes on two sets of transformations, one for the input and one for the target. Everything is similar to the way you would specify transformations for the datamodule, except that you have to specify the transformations for the input and output separately. The input transformations are applied to the first element of the tuple, and the target transformations are applied to the second element of the tuple. For instance, the following is a valid way of specifying a `PairedDataTransform` for a vision task for predicting the parity of the input digit image: +* `lightning_toolbox.data.transforms.TupleDataTransform`: + This transformation class takes on sets of transformations, for each element in the tuple data. Everything is similar to the way you would specify transformations for the datamodule, except that you have to specify the transformations for elements seperately. For instance, the following is a valid way of specifying a `TupleDataTransform` for a vision task for predicting the parity of the input digit image: ```python from lightning_toolbox import DataModule - from lightning_toolbox.data.transforms import PairedDataTransform + from lightning_toolbox.data.transforms import TupleDataTransform DataModule( dataset='torchvision.datasets.MNIST', - transforms=PairedDataTransform( - input_transforms='torchvision.transforms.ToTensor()', - target_transforms='lambda x: x % 2' + transforms=TupleDataTransform( + transforms=['torchvision.transforms.ToTensor()', 'lambda x: x % 2'] ) ) @@ -183,11 +182,11 @@ It might be the case that you want to apply the same kind of transformations to DataModule( dataset='torchvision.datasets.MNIST', transforms=dict( - class_path='lightning_toolbox.data.transforms.PairedDataTransform', - init_args=dict( - input_transforms='torchvision.transforms.ToTensor()', - target_transforms='lambda x: x % 2' - ) + class_path='lightning_toolbox.data.transforms.TupleDataTransform', + init_args={ + 0:'torchvision.transforms.ToTensor()', + 1:'lambda x: x % 2' + } ) ) ``` diff --git a/lightning_toolbox/data/__init__.py b/lightning_toolbox/data/__init__.py index 8542dbb..874de5d 100644 --- a/lightning_toolbox/data/__init__.py +++ b/lightning_toolbox/data/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. from .module import DataModule -from .transforms import transform_dataset, TransformsDescriptor, PairedDataTransform +from .transforms import transform_dataset, TransformsDescriptor, TupleDataTransform diff --git a/lightning_toolbox/data/transforms.py b/lightning_toolbox/data/transforms.py index 253a49f..e45f592 100644 --- a/lightning_toolbox/data/transforms.py +++ b/lightning_toolbox/data/transforms.py @@ -153,10 +153,10 @@ def transform_dataset( return WrapDataset(dataset) -class PairedDataTransform: +class TupleDataTransform: """ - Utility class to wrap a pair dataset (any dataset that returns a tuple of (input, target)) with - a list of transforms for input and target separately. + Utility class to wrap a tuple dataset (any dataset that returns a tuple e.g. (input, target)) with + a list of transforms for each element of the tuple. This class specially useful for dealing with datasets from torchvision, and coupling them with torchvision.transforms. @@ -164,37 +164,30 @@ class PairedDataTransform: Example: >>> from torchvision.datasets import MNIST >>> from torchvision import transforms - >>> from lightning_toolbox.data import VisionTransform + >>> from lightning_toolbox.data import TupleDataTransform >>> dataset = MNIST(root=".", download=True, transform=None, target_transform=None) - >>> transform = PairedTransform(input_transforms=transforms.ToTensor(), target_transforms=transforms.ToTensor()) - >>> transform_dataset(dataset, transform) + >>> transform = TupleDataTransform(transforms.ToTensor(), transforms.ToTensor()) + >>> transform(dataset, transform) Attributes: - input_transforms: Either None or a torchvision.transforms.Compose object that contains a list of transforms - to apply to the input. - target_transforms: Either None or a torchvision.transforms.Compose object that contains a list of transforms - to apply to the target. + transforms: A list of either Nones or transformations to be applied to associated elements of the tuple. """ def __init__( self, - input_transforms: th.Optional[TransformsDescriptor] = None, - target_transforms: th.Optional[TransformsDescriptor] = None, + transforms: th.Union[ + th.List[th.Optional[TransformDescriptor]], th.Dict[int, th.Optional[TransformDescriptor]] + ], ): """ Initialize the transform. Args: - input_transforms: a (list of) transform(s) to apply to the input - target_transforms: a (list of) transform(s) to apply to the target + transforms: A list of either Nones or transformations to be applied to associated elements of the tuple. """ - input_transforms = ( - initialize_transforms(input_transforms, force_list=True) if input_transforms is not None else None - ) - target_transforms = ( - initialize_transforms(target_transforms, force_list=True) if target_transforms is not None else None - ) - self.input_transforms, self.target_transforms = input_transforms, target_transforms + if isinstance(transforms, dict): + transforms = [transforms.get(i, None) for i in range(len(transforms))] + self.transforms = [initialize_transforms(t, force_list=True) if t is not None else None for t in transforms] def __call__(self, datam) -> th.Tuple[th.Any, th.Any]: """ @@ -206,9 +199,8 @@ def __call__(self, datam) -> th.Tuple[th.Any, th.Any]: Returns: a tuple of (input, target) """ - x, y = datam - for transform in self.input_transforms or []: - x = transform(x) - for transform in self.target_transforms or []: - y = transform(y) - return x, y + results = list(datam) + for i, transforms in enumerate(self.transforms): + for transform in transforms or []: + results[i] = transform(results[i]) + return tuple(results) diff --git a/setup.py b/setup.py index dd0db15..4f32c6c 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ setup( name="lightning_toolbox", packages=find_packages(include=["lightning_toolbox", "lightning_toolbox.*"]), - version="0.0.24", + version="0.0.25", license="MIT", description="A collection of utilities for PyTorch Lightning.", long_description=long_description,