Skip to content

Commit

Permalink
v0.0.25: add TupleDataTransform
Browse files Browse the repository at this point in the history
  • Loading branch information
vahidzee committed May 8, 2023
1 parent 84aa3b9 commit ad2a9fd
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 40 deletions.
21 changes: 10 additions & 11 deletions docs/data_module.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,29 +165,28 @@ The following are possible ways of specifying a transformation object:
#### Utility transformations
It might be the case that you want to apply the same kind of transformations to a part of your data (like when dealing with (input, output) pairs). `lightning_toolbox` provides utility transformations that can be used to handle such cases. These transformations are:

* `lightning_toolbox.data.transforms.PairedDataTransform`:
This transformation class takes on two sets of transformations, one for the input and one for the target. Everything is similar to the way you would specify transformations for the datamodule, except that you have to specify the transformations for the input and output separately. The input transformations are applied to the first element of the tuple, and the target transformations are applied to the second element of the tuple. For instance, the following is a valid way of specifying a `PairedDataTransform` for a vision task for predicting the parity of the input digit image:
* `lightning_toolbox.data.transforms.TupleDataTransform`:
This transformation class takes a collection of transformations, one entry per element of the tuple data. Everything is similar to the way you would specify transformations for the datamodule, except that you have to specify the transformations for each element separately. For instance, the following is a valid way of specifying a `TupleDataTransform` for a vision task that predicts the parity of the input digit image:
```python
from lightning_toolbox import DataModule
from lightning_toolbox.data.transforms import PairedDataTransform
from lightning_toolbox.data.transforms import TupleDataTransform

DataModule(
dataset='torchvision.datasets.MNIST',
transforms=PairedDataTransform(
input_transforms='torchvision.transforms.ToTensor()',
target_transforms='lambda x: x % 2'
transforms=TupleDataTransform(
transforms=['torchvision.transforms.ToTensor()', 'lambda x: x % 2']
)
)

# or equivalently (using class_path and init_args)
DataModule(
dataset='torchvision.datasets.MNIST',
transforms=dict(
class_path='lightning_toolbox.data.transforms.PairedDataTransform',
init_args=dict(
input_transforms='torchvision.transforms.ToTensor()',
target_transforms='lambda x: x % 2'
)
class_path='lightning_toolbox.data.transforms.TupleDataTransform',
init_args={
0:'torchvision.transforms.ToTensor()',
1:'lambda x: x % 2'
}
)
)
```
Expand Down
2 changes: 1 addition & 1 deletion lightning_toolbox/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .module import DataModule
from .transforms import transform_dataset, TransformsDescriptor, PairedDataTransform
from .transforms import transform_dataset, TransformsDescriptor, TupleDataTransform
46 changes: 19 additions & 27 deletions lightning_toolbox/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,48 +153,41 @@ def transform_dataset(
return WrapDataset(dataset)


class PairedDataTransform:
class TupleDataTransform:
"""
Utility class to wrap a pair dataset (any dataset that returns a tuple of (input, target)) with
a list of transforms for input and target separately.
Utility class to wrap a tuple dataset (any dataset that returns a tuple e.g. (input, target)) with
a list of transforms for each element of the tuple.
This class is especially useful for dealing with datasets from torchvision, and coupling them with
torchvision.transforms.
Example:
>>> from torchvision.datasets import MNIST
>>> from torchvision import transforms
>>> from lightning_toolbox.data import VisionTransform
>>> from lightning_toolbox.data import TupleDataTransform
>>> dataset = MNIST(root=".", download=True, transform=None, target_transform=None)
>>> transform = PairedTransform(input_transforms=transforms.ToTensor(), target_transforms=transforms.ToTensor())
>>> transform_dataset(dataset, transform)
>>> transform = TupleDataTransform(transforms.ToTensor(), transforms.ToTensor())
>>> transform(dataset, transform)
Attributes:
input_transforms: Either None or a torchvision.transforms.Compose object that contains a list of transforms
to apply to the input.
target_transforms: Either None or a torchvision.transforms.Compose object that contains a list of transforms
to apply to the target.
transforms: A list of either Nones or transformations to be applied to associated elements of the tuple.
"""

def __init__(
self,
input_transforms: th.Optional[TransformsDescriptor] = None,
target_transforms: th.Optional[TransformsDescriptor] = None,
transforms: th.Union[
th.List[th.Optional[TransformDescriptor]], th.Dict[int, th.Optional[TransformDescriptor]]
],
):
"""
Initialize the transform.
Args:
input_transforms: a (list of) transform(s) to apply to the input
target_transforms: a (list of) transform(s) to apply to the target
transforms: A list of either Nones or transformations to be applied to associated elements of the tuple.
"""
input_transforms = (
initialize_transforms(input_transforms, force_list=True) if input_transforms is not None else None
)
target_transforms = (
initialize_transforms(target_transforms, force_list=True) if target_transforms is not None else None
)
self.input_transforms, self.target_transforms = input_transforms, target_transforms
if isinstance(transforms, dict):
transforms = [transforms.get(i, None) for i in range(len(transforms))]
self.transforms = [initialize_transforms(t, force_list=True) if t is not None else None for t in transforms]

def __call__(self, datam) -> th.Tuple[th.Any, th.Any]:
"""
Expand All @@ -206,9 +199,8 @@ def __call__(self, datam) -> th.Tuple[th.Any, th.Any]:
Returns:
a tuple of (input, target)
"""
x, y = datam
for transform in self.input_transforms or []:
x = transform(x)
for transform in self.target_transforms or []:
y = transform(y)
return x, y
results = list(datam)
for i, transforms in enumerate(self.transforms):
for transform in transforms or []:
results[i] = transform(results[i])
return tuple(results)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
setup(
name="lightning_toolbox",
packages=find_packages(include=["lightning_toolbox", "lightning_toolbox.*"]),
version="0.0.24",
version="0.0.25",
license="MIT",
description="A collection of utilities for PyTorch Lightning.",
long_description=long_description,
Expand Down

0 comments on commit ad2a9fd

Please sign in to comment.