refactor objective_function, fix construction
now you can directly pass callable functions as function terms
vahidzee committed Mar 10, 2023
1 parent fa13497 commit 958a54a
Showing 9 changed files with 130 additions and 67 deletions.
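In a nutshell, the rename from `Criterion` to `Objective` comes with support for passing plain callables as terms. A minimal sketch of the new construction, assuming the import path from the new `objective_function` package introduced below; the lambda's argument names are illustrative, since terms simply receive whatever is passed on to the objective:

```python
import torch
from lightning_toolbox.objective_function import Objective

# pass a callable directly as a named term, the capability this commit adds
objective = Objective(
    mse=lambda original, reconstruction, **kwargs: torch.nn.functional.mse_loss(
        reconstruction, original
    ),
)
```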
4 changes: 4 additions & 0 deletions .gitignore
@@ -127,3 +127,7 @@ dmypy.json

# Pyre type checker
.pyre/
+
+# idea project
+.idea/*
+.idea
6 changes: 3 additions & 3 deletions README.md
@@ -1,7 +1,7 @@
# LightningToolbox: A PyTorch Lightning Facilitator
<p align="center">
<a href="#installation">Installation</a> •
-  <a href="./docs/README.md">Docs</a> •
+  <a href="https://github.com/vahidzee/lightning-toolbox/docs/README.md">Docs</a> •
<a href="#license">License</a>
</p>

@@ -12,7 +12,7 @@

Welcome to lightning-toolbox, a Python package that offers a set of automation tools built on top of PyTorch Lightning. As a deep learning researcher, I found that PyTorch Lightning offloads a significant portion of redundant work. However, I still found myself spending a considerable amount of time writing boilerplate code for datamodules, training/validation steps, logging, and especially for not-so-complicated custom training loops. This is why I created lightning-toolbox: to make it easier to focus on writing experiment-specific code rather than dealing with tedious setup tasks.

-By passing your PyTorch model onto a generic `lightning.LitModule` (`lightning_toolbox.TrainingModule`), lightning-toolbox automatically populates the objective function, optimizer step, and more. In addition, lightning-toolbox's generic lightning.DataModule (`lightning_toolbox.DataModule`) can turn any PyTorch dataset, into a experiment-ready lightning data module, completing the cycle for writing lightning deep learning experiments.
+By passing your PyTorch model to a generic `lightning.LightningModule` (`lightning_toolbox.TrainingModule`), lightning-toolbox automatically populates the objective function, optimizer step, and more. In addition, lightning-toolbox's generic `lightning.LightningDataModule` (`lightning_toolbox.DataModule`) can turn any PyTorch dataset into an experiment-ready lightning data module, completing the cycle for writing lightning deep learning experiments.
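A hypothetical sketch of that cycle; the constructor arguments shown (`dataset`, `batch_size`, `model`, `objective`) are illustrative assumptions, not the package's confirmed API:

```python
import torch
import lightning
from torch.utils.data import TensorDataset
from lightning_toolbox import DataModule, TrainingModule

# any torch dataset becomes an experiment-ready datamodule (argument names assumed)
dataset = TensorDataset(torch.randn(100, 8), torch.randn(100, 1))
datamodule = DataModule(dataset=dataset, batch_size=16)

# a plain torch model gets its objective/optimizer wiring populated for you
module = TrainingModule(model=torch.nn.Linear(8, 1), objective="torch.nn.functional.mse_loss")

lightning.Trainer(max_epochs=1).fit(module, datamodule)
```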

Most of the functionality provided in this package is based on [dypy](https://github.com/vahidzee/dypy), which enables lazy evaluation of variables and runtime code injections. Although lightning-toolbox is currently in its early stages and mainly serves as a facilitator for my personal research projects, I believe it can be helpful for many others who deal with similar deep learning experiments. Therefore, I decided to open-source this project and continue to add on to it as I move further in my research.

@@ -27,4 +27,4 @@ Lightning toolbox is tested on `lightning==1.9.0`, although there's no version r

## License

-This project is licensed under the terms of the Apache 2.0 license. See [LICENSE](LICENSE) for more details.
+This project is licensed under the terms of the Apache 2.0 license. See [LICENSE](https://github.com/vahidzee/lightning-toolbox/LICENSE) for more details.
2 changes: 0 additions & 2 deletions lightning_toolbox/criterion/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion lightning_toolbox/data/__init__.py
@@ -1,4 +1,4 @@
-# Copyright Vahid Zehtab 2021
+# Copyright Vahid Zehtab (vahid@zehtab.me) 2021
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
2 changes: 1 addition & 1 deletion lightning_toolbox/data/module.py
@@ -1,4 +1,4 @@
-# Copyright Vahid Zehtab 2021
+# Copyright Vahid Zehtab (vahid@zehtab.me) 2021
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
15 changes: 15 additions & 0 deletions lightning_toolbox/objective_function/__init__.py
@@ -0,0 +1,15 @@
+# Copyright Vahid Zehtab (vahid@zehtab.me) 2021
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .objective import Objective, ResultsDict, FactorsDict, TermDescriptor
+from .terms import ObjectiveTerm
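Hence, after this commit the public surface of the new subpackage is importable as:

```python
from lightning_toolbox.objective_function import (
    Objective,       # the renamed Criterion
    ObjectiveTerm,   # the renamed CriterionTerm
    ResultsDict,
    FactorsDict,
    TermDescriptor,
)
```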
lightning_toolbox/objective_function/objective.py
@@ -1,7 +1,20 @@
+# Copyright Vahid Zehtab (vahid@zehtab.me) 2021
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import typing as th
import torch
import dypy as dy
-from .terms import CriterionTerm
+from .terms import ObjectiveTerm

# register torch with dypy
dy.register_context(torch)
@@ -13,27 +26,32 @@
TermDescriptor = th.Union[str, th.Dict[str, th.Any], dy.FunctionDescriptor]
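Given the `TermDescriptor` union above, a term may be described as a string, a dict, or a `dypy` function descriptor/callable. A hedged sketch of the three forms; the dict keys shown are assumptions for illustration, not a confirmed schema:

```python
# three ways to describe a term, per the TermDescriptor union above
as_string = "torch.nn.functional.mse_loss"  # dotted path, resolved lazily by dypy
as_callable = lambda *args, **kwargs: 0.0   # freeform signature, as the docstring below notes
as_dict = {"name": "mse", "term_function": as_string}  # hypothetical keys
```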


-class Criterion:
-    """Generic training objective abstraction for PyTorch Lightning.
+class Objective:
+    """Generic training objective function abstraction.

-    Criterion is a collection of terms that are combined together to form the overall loss/objective function.
+    Objective is a collection of terms that are combined to form the overall loss/objective function.
    The overall loss/objective value is computed as follows:
    - terms are processed and their values are stored in the results dict
    - results of the terms are combined together to form the overall loss/objective value

-    The abstraction for terms are CriterionTerm objects. Each term is a function that takes the following arguments:
-    - *args, **kwargs: the arguments passed to the criterion
-    - batch: the batch of data
-    - training_module: the training module
-    The term abstraction also allows for the computation of factors that are applied to the term values.
+    The abstraction for terms is the ObjectiveTerm object. Each term is a function that by default has a freeform
+    signature `*args, **kwargs`, receiving all the arguments passed to the objective function. The term abstraction
+    allows specifying which arguments are passed to the term function, the term's name, and the computation of
+    factors that are applied to the term values.

    The combination of the terms and the factors is controlled by the term abstraction itself, but the overall
-    reduction of the terms is controlled by the criterion. The criterion can reduce the terms by summing them or multiplying them.
+    reduction of the terms is controlled by the objective, which by default is the sum of the term values.

    Attributes:
-        terms: List of terms to use in the training objective.
+        terms: List of terms to use in the objective function.
        latch: A dictionary that can be used to store values for later use. The values are cleared after each call
            to __call__.
        results_latch: A dictionary that stores the results of the terms. The values are cleared after each call
            to __call__.
        factors_latch: A dictionary that stores the factors of the terms. The values are cleared after each call
            to __call__. The factors are computed by the terms themselves, and are just stored here for convenience.
    """

    latch: th.Dict[th.Any, th.Any]
@@ -42,13 +60,13 @@ def __init__(
        self,
        *term_args: TermDescriptor,
        **term_kwargs: TermDescriptor,
-    ):
-        term_args = [CriterionTerm.from_description(term, criterion=self) for term in term_args]
+    ) -> None:
+        term_args = [ObjectiveTerm.from_description(term, objective=self) for term in term_args]
        term_kwargs = [
-            CriterionTerm.from_description(term, criterion=self, name=name) for name, term in term_kwargs.items()
+            ObjectiveTerm.from_description(term, objective=self, name=name) for name, term in term_kwargs.items()
        ]
-        self.terms = term_args + term_kwargs
-        self.__rename_terms(terms=self.terms)
+        self.terms: th.List[ObjectiveTerm] = term_args + term_kwargs
+        self.__rename_terms(terms=self.terms)  # to make sure all terms have unique names
        # initialize the latches
        self.latch, self.results_latch, self.factors_latch = {}, {}, {}

@@ -60,22 +78,30 @@ def results(self) -> ResultsDict:
    def factors(self) -> FactorsDict:
        return self.factors_latch

-    def remember(self, **kwargs):
+    def remember(self, **kwargs) -> None:
        """Keep some values for later use. Gets cleared after each call to __call__."""
        self.latch.update(kwargs)

-    def _forget(self):
+    def _forget(self) -> None:
        """Forget all remembered values and clear all latches."""
        self.latch.clear()
        self.results_latch.clear()
        self.factors_latch.clear()

+    def __getitem__(self, key: th.Union[str, int]) -> ObjectiveTerm:
+        if isinstance(key, int):
+            return self.terms[key]
+        for term in self.terms:  # TODO: use a dict for faster lookup
+            if term.name == key:
+                return term
+        raise KeyError(f"Term with name {key} not found.")

    @property
    def terms_names(self) -> th.List[str]:
        return [term.name for term in self.terms]

    @staticmethod
-    def __rename_terms(terms: th.List[CriterionTerm], prefix: str = "") -> None:
+    def __rename_terms(terms: th.List[ObjectiveTerm], prefix: str = "") -> None:
        names_count = {term.name: 0 for term in terms}
        for term in terms:
            names_count[term.name] += 1
@@ -94,16 +120,16 @@ def process_terms_results(self, *args, **kwargs) -> None:
"""
Call all the terms and store their results in the results latch.
Terms are called in the order they are provided to the criterion. When a term is called, everything in provided
to the criterion is also passed to it. The results of the term are either a single value or a dict of values.
Terms are called in the order they are provided to the objective. When a term is called, everything in provided
to the objective is also passed to it. The results of the term are either a single value or a dict of values.
If the results are a dict, and the term is to contribute to the overall loss, it has to contain a key "loss".
All of the returned keys are stored in the results latch, but only the terms that provide a "loss" key (or
have a single scalar value returned) are used to compute the overall loss. Other returned values can be used
for logging or other purposes.
Args:
*args, **kwargs: the arguments passed to the criterion are directly passed to the terms.
*args, **kwargs: the arguments passed to the objective are directly passed to the terms.
"""
for term in self.terms:
term_results = term(*args, **kwargs)
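As the docstring above spells out, a term may return a dict as long as the value contributing to the objective sits under a "loss" key. A sketch of such a term; the KL computation and the argument names are illustrative:

```python
import torch

def kl_term(*args, mu: torch.Tensor, logvar: torch.Tensor, **kwargs) -> dict:
    """A dict-returning term: only the "loss" key feeds the overall objective."""
    kl = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1))
    # extra keys land in the results latch for logging, per the docstring above
    return {"loss": kl, "kl_detached": kl.detach()}
```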
@@ -138,7 +164,7 @@ def __call__(
        self._forget()  # clear the latches
        self.process_terms_results(*args, **kwargs)  # process the terms
        self.factors_latch = {
-            term.name: term.factor_value(*args, **kwargs) for term in self.terms if term.name in self.results_latch
+            term.name: term.factor(*args, **kwargs) for term in self.terms if term.name in self.results_latch
        }  # compute the factor values for the terms that contribute to the loss
        self.results_latch["loss"] = self.reduce()  # reduce the term results with the factors
        return self.results_latch if not return_factors else (self.results_latch, self.factors_latch)
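Tying the refactor together, a hedged end-to-end sketch, assuming a scalar term result is latched under the term's name and that the default factor is 1:

```python
import torch
from lightning_toolbox.objective_function import Objective

objective = Objective(
    recon=lambda original, reconstruction, **kw: torch.nn.functional.mse_loss(reconstruction, original),
    reg=lambda reconstruction, **kw: reconstruction.abs().mean(),
)

x = torch.randn(8, 3)
results, factors = objective(original=x, reconstruction=0.9 * x, return_factors=True)
print(results["loss"])        # factor-weighted reduction (sum) of the term values
print(objective["recon"])     # terms are addressable by name via the new __getitem__
print(objective.terms_names)  # unique names enforced by __rename_terms
```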