Skip to content

Commit

Permalink
v0.0.24: replace dict with th.Any for convinience
Browse files Browse the repository at this point in the history
  • Loading branch information
vahidzee committed May 4, 2023
1 parent 3e6e971 commit 84aa3b9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions lightning_toolbox/training/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,28 @@
import torch
from .datastructures import ArgsListDict

DICT_TYPE = th.Union[dict, th.Any]


class TrainingModule(lightning.LightningModule):
def __init__(
self,
# model
model: th.Optional[torch.nn.Module] = None, # model instance
model_cls: th.Optional[str] = None,
model_args: th.Optional[dict] = None,
model_args: th.Optional[DICT_TYPE] = None,
# objective
objective: th.Optional[Objective] = None, # objective instance
objective_cls: th.Union[type, str] = "lightning_toolbox.Objective",
objective_args: th.Optional[dict] = None,
objective_args: th.Optional[DICT_TYPE] = None,
# optimization configs
# optimizer name or class
optimizer: th.Union[str, type, th.List[th.Union[str, type]], None] = None,
optimizer_frequency: th.Union[int, th.List[th.Optional[int]], None] = None,
optimizer_is_active: th.Optional[th.Union[dy.FunctionDescriptor, th.List[dy.FunctionDescriptor]]] = None,
# optimizer parameters (self.<*>)
optimizer_parameters: th.Optional[th.Union[th.List[str], str]] = None,
optimizer_args: th.Union[dict, th.List[th.Optional[dict]], None] = None,
optimizer_args: th.Union[DICT_TYPE, th.List[th.Optional[DICT_TYPE]], None] = None,
# learning rate
lr: th.Union[th.List[float], float] = 1e-4,
# schedulers
Expand All @@ -49,7 +51,7 @@ def __init__(
scheduler_name: th.Optional[th.Union[str, th.List[str]]] = None,
# optimizer index
scheduler_optimizer: th.Optional[th.Union[int, th.List[int]]] = None,
scheduler_args: th.Optional[th.Union[dict, th.List[dict]]] = None,
scheduler_args: th.Optional[th.Union[DICT_TYPE, th.List[DICT_TYPE]]] = None,
scheduler_interval: th.Union[str, th.List[str]] = "epoch",
scheduler_frequency: th.Union[int, th.List[int]] = 1,
scheduler_monitor: th.Optional[th.Union[str, th.List[str]]] = None,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
setup(
name="lightning_toolbox",
packages=find_packages(include=["lightning_toolbox", "lightning_toolbox.*"]),
version="0.0.23",
version="0.0.24",
license="MIT",
description="A collection of utilities for PyTorch Lightning.",
long_description=long_description,
Expand Down

0 comments on commit 84aa3b9

Please sign in to comment.