From 958a54a39fe4b709b7948e5f2c14b64927e1a1e7 Mon Sep 17 00:00:00 2001
From: Vahid Zehtab <33608325+vahidzee@users.noreply.github.com>
Date: Fri, 10 Mar 2023 18:00:25 -0500
Subject: [PATCH] refactor objective_function, fix construction

now you can directly pass callable functions as function terms
---
 .gitignore                              |  4 +
 README.md                               |  6 +-
 lightning_toolbox/criterion/__init__.py |  2 -
 lightning_toolbox/data/__init__.py      |  2 +-
 lightning_toolbox/data/module.py        |  2 +-
 .../objective_function/__init__.py      | 15 ++++
 .../objective.py}                       | 74 ++++++++++-----
 .../terms.py                            | 90 +++++++++++--------
 setup.py                                |  2 +-
 9 files changed, 130 insertions(+), 67 deletions(-)
 delete mode 100644 lightning_toolbox/criterion/__init__.py
 create mode 100644 lightning_toolbox/objective_function/__init__.py
 rename lightning_toolbox/{criterion/criterion.py => objective_function/objective.py} (59%)
 rename lightning_toolbox/{criterion => objective_function}/terms.py (76%)

diff --git a/.gitignore b/.gitignore
index b6e4761..8b3f7da 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,3 +127,7 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# idea project
+.idea/*
+.idea
\ No newline at end of file
diff --git a/README.md b/README.md
index d142320..b2e0bb8 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
 # LightningToolbox: A PyTorch Lightning Facilitator

   Installation •
-  Docs •
+  Docs • License

@@ -12,7 +12,7 @@
 Welcome to lightning-toolbox, a python package that offers a set of automation tools built on top of PyTorch Lightning. As a deep learning researcher, I found that PyTorch Lightning offloads a significant portion of redundant work. However, I still found myself spending a considerable amount of time on writing boilerplate code for datamodules, training/validation steps, logging, and specially for not so complicated costum training loops. This is why I created lightning-toolbox - to make it easier to focus on writing experiment-specific code rather than dealing with tedious setup tasks.
 
-By passing your PyTorch model onto a generic `lightning.LitModule` (`lightning_toolbox.TrainingModule`), lightning-toolbox automatically populates the objective function, optimizer step, and more. In addition, lightning-toolbox's generic lightning.DataModule (`lightning_toolbox.DataModule`) can turn any PyTorch dataset, into a experiment-ready lightning data module, completing the cycle for writing lightning deep learning experiments.
+By passing your PyTorch model onto a generic `lightning.LightningModule` (`lightning_toolbox.TrainingModule`), lightning-toolbox automatically populates the objective function, optimizer step, and more. In addition, lightning-toolbox's generic `lightning.LightningDataModule` (`lightning_toolbox.DataModule`) can turn any PyTorch dataset into an experiment-ready lightning data module, completing the cycle for writing lightning deep learning experiments.
 
 Most of the functionality provided in this package is based on [dypy](https://github.com/vahidzee/dypy), which enables lazy evaluation of variables and runtime code injections. Although lightning-toolbox is currently in its early stages and mainly serves as a facilitator for my personal research projects, I believe it can be helpful for many others who deal with similar deep learning experiments. Therefore, I decided to open-source this project and continue to add on to it as I move further in my research.
 
@@ -27,4 +27,4 @@ Lightning toolbox is tested on `lightning==1.9.0`, although there's no version restriction
 
 ## License
 
-This project is licensed under the terms of the Apache 2.0 license. See [LICENSE](LICENSE) for more details.
\ No newline at end of file
+This project is licensed under the terms of the Apache 2.0 license. See [LICENSE](https://github.com/vahidzee/lightning-toolbox/LICENSE) for more details.
\ No newline at end of file
diff --git a/lightning_toolbox/criterion/__init__.py b/lightning_toolbox/criterion/__init__.py
deleted file mode 100644
index 4db77db..0000000
--- a/lightning_toolbox/criterion/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .criterion import Criterion, ResultsDict, FactorsDict, TermDescriptor
-from .terms import CriterionTerm
diff --git a/lightning_toolbox/data/__init__.py b/lightning_toolbox/data/__init__.py
index 14e838d..9d004e8 100644
--- a/lightning_toolbox/data/__init__.py
+++ b/lightning_toolbox/data/__init__.py
@@ -1,4 +1,4 @@
-# Copyright Vahid Zehtab 2021
+# Copyright Vahid Zehtab (vahid@zehtab.me) 2021
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
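A minimal sketch of the flow the README above describes could look like the following. The exact constructor arguments of `TrainingModule` and `DataModule` are not shown anywhere in this patch, so every keyword name used here (`model`, `objective`, `dataset`, `batch_size`) is an assumption for illustration only, not the package's documented API.

```python
# Hypothetical usage sketch -- keyword names below are assumptions, not documented API.
import torch
from torch.utils.data import TensorDataset
import lightning
from lightning_toolbox import TrainingModule, DataModule

# any plain PyTorch model and dataset
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
dataset = TensorDataset(torch.randn(256, 16), torch.randn(256, 1))

# DataModule is assumed to wrap the dataset into an experiment-ready LightningDataModule.
datamodule = DataModule(dataset=dataset, batch_size=32)

# TrainingModule is assumed to take the model plus an objective description and to
# populate the training/validation steps and optimizer configuration automatically.
module = TrainingModule(
    model=model,
    objective={"mse": lambda batch, training_module: torch.nn.functional.mse_loss(
        training_module(batch[0]), batch[1])},
)

trainer = lightning.Trainer(max_epochs=1)
trainer.fit(module, datamodule)
```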
diff --git a/lightning_toolbox/data/module.py b/lightning_toolbox/data/module.py
index 9b98349..0faf2dc 100644
--- a/lightning_toolbox/data/module.py
+++ b/lightning_toolbox/data/module.py
@@ -1,4 +1,4 @@
-# Copyright Vahid Zehtab 2021
+# Copyright Vahid Zehtab (vahid@zehtab.me) 2021
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/lightning_toolbox/objective_function/__init__.py b/lightning_toolbox/objective_function/__init__.py
new file mode 100644
index 0000000..e537c70
--- /dev/null
+++ b/lightning_toolbox/objective_function/__init__.py
@@ -0,0 +1,15 @@
+# Copyright Vahid Zehtab (vahid@zehtab.me) 2021
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .objective import Objective, ResultsDict, FactorsDict, TermDescriptor
+from .terms import ObjectiveTerm
diff --git a/lightning_toolbox/criterion/criterion.py b/lightning_toolbox/objective_function/objective.py
similarity index 59%
rename from lightning_toolbox/criterion/criterion.py
rename to lightning_toolbox/objective_function/objective.py
index 1750f70..4193194 100644
--- a/lightning_toolbox/criterion/criterion.py
+++ b/lightning_toolbox/objective_function/objective.py
@@ -1,7 +1,20 @@
+# Copyright Vahid Zehtab (vahid@zehtab.me) 2021
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import typing as th
 import torch
 import dypy as dy
-from .terms import CriterionTerm
+from .terms import ObjectiveTerm
 
 # register torch with dypy
 dy.register_context(torch)
@@ -13,27 +26,32 @@ TermDescriptor = th.Union[str, th.Dict[str, th.Any], dy.FunctionDescriptor]
 
 
-class Criterion:
-    """Generic training objective abstraction for PyTorch Lightning.
+class Objective:
+    """Generic training objective function abstraction.
 
-    Criterion is a collection of terms that are combined together to form the overall loss/objective function.
+    An Objective is a collection of terms that are combined together to form the overall loss/objective function.
     The overall loss/objective value is computed as follows:
     - terms are processed and their values are stored in the results dict
     - results of the terms are combined together to form the
 
-    The abstraction for terms are CriterionTerm objects. Each term is a function that takes the following arguments:
-    - *args, **kwargs: the arguments passed to the criterion
-    - batch: the batch of data
-    - training_module: the training module
-
-    The term abstraction also allows for the computation of factors that are applied to the term values.
+    The abstraction for terms is the ObjectiveTerm object. Each term is a function that by default has a freeform
+    signature `*args, **kwargs`, receiving all the arguments passed to the objective function. The term abstraction
+    allows for the specification of the arguments that are passed to the term function, the name of the term, and
+    the computation of factors that are applied to the term values.
 
     The combination of the terms and the factors are controlled by the term abstraction itself, but the overall reduction
-    of the terms is controlled by the criterion. The criterion can reduce the terms by summing them or multiplying them.
+    of the terms is controlled by the objective, which by default is the sum of the term values.
 
     Attributes:
-        terms: List of terms to use in the training objective.
+        terms: List of terms to use in the objective function.
+        latch: A dictionary that can be used to store values for later use. The values are cleared after each call to
+            __call__.
+        results_latch: A dictionary that stores the results of the terms. The values are cleared after each call to
+            __call__.
+        factors_latch: A dictionary that stores the factors of the terms. The values are cleared after each call to
+            __call__. The factors are computed by the terms themselves, and just stored here for convenience.
     """
 
     latch: th.Dict[th.Any, th.Any]
@@ -42,13 +60,13 @@ def __init__(
         self,
         *term_args: TermDescriptor,
         **term_kwargs: TermDescriptor,
-    ):
-        term_args = [CriterionTerm.from_description(term, criterion=self) for term in term_args]
+    ) -> None:
+        term_args = [ObjectiveTerm.from_description(term, objective=self) for term in term_args]
         term_kwargs = [
-            CriterionTerm.from_description(term, criterion=self, name=name) for name, term in term_kwargs.items()
+            ObjectiveTerm.from_description(term, objective=self, name=name) for name, term in term_kwargs.items()
         ]
-        self.terms = term_args + term_kwargs
-        self.__rename_terms(terms=self.terms)
+        self.terms: th.List[ObjectiveTerm] = term_args + term_kwargs
+        self.__rename_terms(terms=self.terms)  # to make sure all terms have unique names
 
         # initialize the latches
         self.latch, self.results_latch, self.factors_latch = {}, {}, {}
@@ -60,22 +78,30 @@ def results(self) -> ResultsDict:
         return self.results_latch
 
     def factors(self) -> FactorsDict:
         return self.factors_latch
 
-    def remember(self, **kwargs):
+    def remember(self, **kwargs) -> None:
         """Keep some values for later use. Get's cleared after each call to __call__"""
         self.latch.update(kwargs)
 
-    def _forget(self):
+    def _forget(self) -> None:
         """Forget all remembered values and clear all latches."""
         self.latch.clear()
         self.results_latch.clear()
         self.factors_latch.clear()
 
+    def __getitem__(self, key: th.Union[str, int]) -> ObjectiveTerm:
+        if isinstance(key, int):
+            return self.terms[key]
+        for term in self.terms:  # TODO: use a dict for faster lookup
+            if term.name == key:
+                return term
+        raise KeyError(f"Term with name {key} not found.")
+
     @property
     def terms_names(self) -> th.List[str]:
         return [term.name for term in self.terms]
 
     @staticmethod
-    def __rename_terms(terms: th.List[CriterionTerm], prefix: str = "") -> None:
+    def __rename_terms(terms: th.List[ObjectiveTerm], prefix: str = "") -> None:
         names_count = {term.name: 0 for term in terms}
         for term in terms:
             names_count[term.name] += 1
@@ -94,8 +120,8 @@ def process_terms_results(self, *args, **kwargs) -> None:
         """
         Call all the terms and store their results in the results latch.
 
-        Terms are called in the order they are provided to the criterion. When a term is called, everything in provided
-        to the criterion is also passed to it. The results of the term are either a single value or a dict of values.
+        Terms are called in the order they are provided to the objective. When a term is called, everything provided
+        to the objective is also passed to it. The results of the term are either a single value or a dict of values.
 
         If the results are a dict, and the term is to contribute to the overall loss, it has to contain a key "loss".
         All of the returned keys are stored in the results latch, but only the terms that provide a "loss" key (or
@@ -103,7 +129,7 @@ def process_terms_results(self, *args, **kwargs) -> None:
         for logging or other purposes.
 
         Args:
-            *args, **kwargs: the arguments passed to the criterion are directly passed to the terms.
+            *args, **kwargs: the arguments passed to the objective are directly passed to the terms.
         """
         for term in self.terms:
             term_results = term(*args, **kwargs)
@@ -138,7 +164,7 @@ def __call__(
         self._forget()  # clear the latches
         self.process_terms_results(*args, **kwargs)  # process the terms
         self.factors_latch = {
-            term.name: term.factor_value(*args, **kwargs) for term in self.terms if term.name in self.results_latch
+            term.name: term.factor(*args, **kwargs) for term in self.terms if term.name in self.results_latch
         }  # compute the factor values for the terms that contribute to the loss
         self.results_latch["loss"] = self.reduce()  # reduce the term results with the factors
         return self.results_latch if not return_factors else (self.results_latch, self.factors_latch)
diff --git a/lightning_toolbox/criterion/terms.py b/lightning_toolbox/objective_function/terms.py
similarity index 76%
rename from lightning_toolbox/criterion/terms.py
rename to lightning_toolbox/objective_function/terms.py
index 33ad9dc..74ff630 100644
--- a/lightning_toolbox/criterion/terms.py
+++ b/lightning_toolbox/objective_function/terms.py
@@ -1,10 +1,23 @@
+# Copyright Vahid Zehtab (vahid@zehtab.me) 2021
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import torch
 import functools
 import dypy as dy
 import typing as th
 
 
-class CriterionTerm:
+class ObjectiveTerm:
     def __init__(
         self,
         name: str = None,
@@ -15,28 +28,29 @@ def __init__(
     ) -> None:
         """
         Term is a single value that is used to compute the objective function (e.g. MSE, KL, etc.).
-        Terms can be combined into a Criterion, which is a collection of terms that are summed together.
+        Terms can be combined into an Objective, which is a collection of terms that are (by default) summed together.
 
         You can use the Term class in two ways:
         1. You can pass a function that computes the term value, and optionally a factor that scales the term value.
-        2. You can inherit from the CriterionTerm class and implement the __call__ method and optionally other methods.
+        2. You can inherit from the ObjectiveTerm class and implement the __call__ method and optionally other methods.
 
         Args:
             name (str): name of the term.
                This is used to identify the term in the results dictionary. If not provided, the name of the
                term-class is used.
+
            factor (float or FunctionDescriptor): Factor that scales the term value.
                this can be a single float value, or a function that computes the factor value. (all the arguments
-                passed to the criterion are passed to the factor function as well), and through `self.criterion` one
-                can access the latches of the criterion (e.g. `self.criterion.results`, `self.criterion.latch`).
+                passed to the objective are passed to the factor function as well), and through `self.objective` one
+                can access the latches of the objective (e.g. `self.objective.results`, `self.objective.latch`).
 
                Example:
                1. factor annealed with the reciprocal of the number of epochs:
                    >>> factor = "lambda self, training_module: 1/training_module.trainer.current_epoch"
                2. factor with twice the value of the term "mse":
-                    >>> factor = "lambda self, training_module: 2*self.criterion.results['mse']"
+                    >>> factor = "lambda self, training_module: 2*self.objective.results['mse']"
 
-                Use this option only if you don't want to inherit from the CriterionTerm class. In general, it is
+                Use this option only if you don't want to inherit from the ObjectiveTerm class. In general, it is
                recommended to use this option since it provides most of the functionality one requires for applying
                a factor to a term.
@@ -62,19 +76,24 @@ def __init__(
         self._factor_description = factor if factor is not None else 1.0
         self._term_function_description = term_function or kwargs
         self._scale_factor = scale_factor
-        self.criterion = None
+        self.objective: "lightning_toolbox.Objective" = None
 
-    # link to criterion
+    # link to objective
     @property
     def remember(self):
-        return self.criterion.remember
+        return self.objective.remember
 
-    def _register_criterion(self, criterion):
-        """Register a link to the criterion that this term belongs to.
+    def _register_objective(self, objective):
+        """Register a link to the objective that this term belongs to.
 
-        This is used to access criterions' attributes such as latch/.
+        This is used to access the objective's attributes, such as its latches.
         """
-        self.criterion = criterion
+        self.objective = objective
+
+    @property
+    def factor(self):
+        """Returns the factor-value function"""
+        return self._compute_factor
 
     @functools.cached_property
     def _compiled_factor(self):
@@ -82,7 +101,7 @@ def _compiled_factor(self):
         return dy.dynamic_args_wrapper(compiled) if callable(compiled) else compiled
 
     # TODO: rewrite as @dy.method(signature="dynamic")
-    def factor_value(self, *args, **kwargs) -> torch.Tensor:
+    def _compute_factor(self, *args, **kwargs) -> torch.Tensor:
         """
         Computes the final factor value to be applied to the term value. By default this is a wrapper around
         the `factor` (function/float) that is passed to the term constructor.
@@ -96,8 +115,8 @@ def factor_value(self, *args, **kwargs) -> torch.Tensor:
         Args:
             results_dict (ResultsDict): Dictionary containing the results of other terms in the objective function.
-                if this dictionary is proccessed by the `Criterion` class it will contain `term/` and
-                `regularization/` entries for each term and regularization in the criterion.
+                if this dictionary is processed by the `Objective` class it will contain `term/` and
+                `regularization/` entries for each term and regularization in the objective.
             training_module (lightning.LightningModule): The training module that is being trained.
@@ -125,8 +144,8 @@ def scale_factor(self, factor_value) -> th.Union[torch.Tensor, int, float]:
         if self._scale_factor:
             return (
                 factor_value
-                * self.criterion.results_latch[self._scale_factor].data.clone()
-                / self.criterion.results_latch[self.name].data.clone()
+                * self.objective.results_latch[self._scale_factor].data.clone()
+                / self.objective.results_latch[self.name].data.clone()
             )
         else:
             return factor_value
@@ -159,12 +178,12 @@ def _compiled_term_function(self):
     # TODO: rewrite with a @dy.method(signature="dynamic") base term_function
     def __call__(self, *args, **kwargs) -> torch.Tensor:
         """
-        Computes the term value. This is the main method of the `Term` class. It is called by the `Criterion` class
+        Computes the term value. This is the main method of the `Term` class. It is called by the `Objective` class
         when computing the objective function value.
 
         If this method is not overridden, it will call the `term_function` that was provided to the constructor.
         The `term_function` should have the signature `term_function(*args, **kwargs)` where `*args` and `**kwargs`
-        are any arguments provided by the user when calling the criterion.
+        are any arguments provided by the user when calling the objective.
 
         Args:
             *args: arguments to be passed to the `term_function`.
@@ -188,37 +207,38 @@ def __call__(self, *args, **kwargs) -> torch.Tensor:
     @staticmethod
     def from_description(
-        description: th.Union["CriterionTerm", "TermDescriptor"],
+        description: th.Union["ObjectiveTerm", "TermDescriptor"],
         # overwrite attributes of the instance
         name: th.Optional[str] = None,
-        criterion: th.Optional["Criterion"] = None,
-    ) -> "CriterionTerm":
+        objective: th.Optional["Objective"] = None,
+    ) -> "ObjectiveTerm":
         """
-        Creates a `CriterionTerm` instance from a `TermDescriptor` object.
+        Creates an `ObjectiveTerm` instance from a `TermDescriptor` object.
 
         Args:
             description (TermDescriptor): The term descriptor.
             name (str): The name of the term. If not provided, the name from the description will be used.
-            criterion (Criterion): The criterion that this term belongs to.
+            objective (Objective): The objective that this term belongs to.
 
         Returns:
-            CriterionTerm: The criterion term.
+            ObjectiveTerm: The objective term.
""" - if isinstance(description, CriterionTerm): + if isinstance(description, ObjectiveTerm): term = description elif isinstance(description, str): - try: - term = dy.eval(description, dynamic_args=True)() - except: - term = CriterionTerm(term_function=description, name=name) + term = ObjectiveTerm.from_description(dy.eval(description, dynamic_args=True, strict=False)) + elif isinstance(description, type) and issubclass(description, ObjectiveTerm): + term = description() + elif callable(description): + term = ObjectiveTerm(term_function=description) # else the description is a dict # checking if the description provides a class_path to instantiate a previously defined term elif "class_path" in description: term = dy.eval(description["class_path"])(**description.get("init_args", dict())) # else the description is a dict with required fields to instantiate a new term else: - term = CriterionTerm(**description) + term = ObjectiveTerm(**description) if name is not None: term.name = name - if criterion is not None: - term._register_criterion(criterion) + if objective is not None: + term._register_objective(objective) return term diff --git a/setup.py b/setup.py index ee5f0cc..fe45078 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ long_description_content_type="text/markdown", author="Vahid Zehtab", author_email="vahid@zehtab.me", - url="https://github.com/vahidzee/lightning_toolbox", + url="https://github.com/vahidzee/lightning-toolbox", keywords=["artificial intelligence", "pytorch lightning", "objective functions", "regularization"], install_requires=["torch>=1.9", "lightning", "dypy"], classifiers=[