# `lightning_toolbox.DataModule`
`lightning_toolbox.DataModule` aims to capture the generic task of setting up DataModules in PyTorch Lightning (which are a way of decoupling data from the model implementation, [see the docs](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html)). It can be very useful if your dataset satisfies the following assumption:

> Creating a `torch.utils.data.DataLoader` from your dataset is as simple as calling `torch.utils.data.DataLoader(dataset, batch_size=...)`, which is the case for most datasets in deep learning.

In this case, you can use `lightning_toolbox.DataModule` to create a `LightningDataModule` from your dataset(s): either use it directly to turn your dataset(s) into a datamodule, or create your own datamodule by inheriting from it.

## Table of Contents
- [Specifying Datasets](#specifying-datasets)
- [Validation Split](#validation-split)
- [Configuring the Dataloaders](#configuring-the-dataloaders)
- [Transformations](#transformations)
- [Utility Transformations](#utility-transformations)
- [Extending Core functionality through Inheritance](#extending-core-functionality-through-inheritance)


## Usage
### Configuring the Dataloaders
```python
dm = DataModule(
    ...,  # dataset descriptions
    val_num_workers=8,  # This will override `num_workers` for the validation split
)
```
### Transformations
Transformations are functions applied to your datapoints before they are fed to the model. You can specify transformations for each split by passing the `train_transforms`, `val_transforms`, and `test_transforms` arguments, each of which is a single transformation or a list of transformations to apply to that split. You can also pass a `transforms` argument, which is used for all splits. If you pass both `transforms` and, for example, `train_transforms`, then `train_transforms` is used for the training split and `transforms` is used for the validation and test splits.
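
For instance, a minimal sketch of this override behavior (the transformation values here use the descriptor forms described below):

```python
from lightning_toolbox import DataModule

dm = DataModule(
    ...,  # dataset descriptions
    # used for the validation and test splits (no split-specific override)
    transforms='torchvision.transforms.ToTensor()',
    # overrides `transforms` for the training split only
    train_transforms=[
        'torchvision.transforms.RandomHorizontalFlip()',
        'torchvision.transforms.ToTensor()',
    ],
)
```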

The following are possible ways of specifying a transformation object:

* A callable function, or a function descriptor: a path to a function, a piece of code that evaluates to a callable value (e.g. `torchvision.transforms.functional_tensor.to_tensor` or `torchvision.transforms.ToTensor()`), or a function definition provided directly, whether as a string or as Python code.

To learn more about the possible ways of specifying a custom function, check out dypy's [function specification](https://github.com/vahidzee/dypy) documentation.

For instance, the following are all valid ways of specifying a transformation:
```python
from lightning_toolbox import DataModule

# each `DataModule(...)` call below shows one valid value for the
# `transforms` argument (`...` stands in for the dataset descriptions)

# a callable value
DataModule(..., transforms=lambda x: x)
# ----
# a string definition of an anonymous lambda function
DataModule(..., transforms='lambda x: x')
# ----
# a path to a callable value
DataModule(..., transforms='torchvision.transforms.functional_tensor.to_tensor')
# ----
# a piece of code that evaluates to a callable value
DataModule(..., transforms='torchvision.transforms.ToTensor()')
# ----
# an actual piece of code that evaluates to a module with
# callable functions (from which the function of interest is taken)
DataModule(
    ...,
    transforms=dict(
        code="""
import torchvision.transforms as T

def factor(x):
    return 1

def transform(x):
    return T.ToTensor()(x) * factor(x)
""",
        function_of_interest='transform',
    ),
)
```
* A transformation class (and its arguments), which could be a string with the path to a class or the class variable itself (e.g. `"torchvision.transforms.ToTensor"` or `torchvision.transforms.RandomCrop`), a class definition provided directly as a string or Python code, or a dictionary specifying the class path and its initialization arguments (as shown below).

For instance, the following are all valid ways of specifying a transformation class:
```python
from lightning_toolbox import DataModule
import torchvision.transforms

# each `DataModule(...)` call below shows one valid value for the `transforms` argument

# a class variable
DataModule(..., transforms=torchvision.transforms.ToTensor)
# ----
# a string with the path to a class
# (when a class is provided as a string, it is assumed that the class is in
# the dypy context and that it needs no arguments to be instantiated)
DataModule(..., transforms='torchvision.transforms.ToTensor')
# ----
# a class path and its arguments
DataModule(
    ...,
    transforms=dict(
        class_path='torchvision.transforms.RandomCrop',
        init_args=dict(size=32),
    ),
)
```
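Since the `*_transforms` arguments accept lists of transformations, the descriptor forms above can also be combined in one list, with each element following any of the forms; a small sketch:

```python
from lightning_toolbox import DataModule

DataModule(
    ...,  # dataset descriptions
    transforms=[
        'torchvision.transforms.ToTensor()',  # code that evaluates to a callable
        dict(                                 # a class path and its arguments
            class_path='torchvision.transforms.RandomCrop',
            init_args=dict(size=32),
        ),
        lambda x: x,                          # a plain callable value
    ],
)
```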
#### Utility Transformations
It might be the case that you want to apply transformations to only a part of your datapoints, or different transformations to different parts (as when dealing with (input, target) pairs). `lightning_toolbox` provides utility transformations to handle such cases. These transformations are:

* `lightning_toolbox.data.transforms.PairedDataTransform`:
This transformation class takes two sets of transformations, one for the input and one for the target. Everything is specified just as you would specify transformations for the datamodule, except that the input and target transformations are given separately: the input transformations are applied to the first element of each tuple, and the target transformations to the second. For instance, the following is a valid way of specifying a `PairedDataTransform` for a vision task of predicting the parity of a digit image:
```python
from lightning_toolbox import DataModule
from lightning_toolbox.data.transforms import PairedDataTransform

DataModule(
    dataset='torchvision.datasets.MNIST',
    transforms=PairedDataTransform(
        input_transforms='torchvision.transforms.ToTensor()',
        target_transforms='lambda x: x % 2',
    ),
)

# or equivalently (using class_path and init_args)
DataModule(
    dataset='torchvision.datasets.MNIST',
    transforms=dict(
        class_path='lightning_toolbox.data.transforms.PairedDataTransform',
        init_args=dict(
            input_transforms='torchvision.transforms.ToTensor()',
            target_transforms='lambda x: x % 2',
        ),
    ),
)
```


## Extending Core functionality through Inheritance
You may wish to extend the base functionality of `lightning_toolbox.DataModule`, perhaps to only use the dataloader configurations, or to provide your own `prepare_data` method. The only thing to keep in mind is that if you want to use the dataloader functionality, your datamodule should have `train_data`, `val_data`, and `test_data` set by the time Lightning calls the dataloader methods.
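
For instance, a minimal sketch of such a subclass; the MNIST setup, split sizes, and data paths here are illustrative assumptions rather than part of `lightning_toolbox`'s API, and the key point is setting the three data attributes before the dataloaders are requested:

```python
from lightning_toolbox import DataModule
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

class MNISTDataModule(DataModule):
    def prepare_data(self):
        # download once (Lightning runs this on a single process)
        MNIST('data', train=True, download=True)
        MNIST('data', train=False, download=True)

    def setup(self, stage=None):
        # `train_data`, `val_data`, and `test_data` must be set here,
        # before Lightning calls the dataloader methods
        full = MNIST('data', train=True, transform=ToTensor())
        self.train_data, self.val_data = random_split(full, [55000, 5000])
        self.test_data = MNIST('data', train=False, transform=ToTensor())
```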
