Skip to content

Commit

Permalink
fixed ckpt
Browse files Browse the repository at this point in the history
  • Loading branch information
blondered committed Feb 6, 2025
1 parent f23903c commit 4e87e3f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
4 changes: 2 additions & 2 deletions rectools/models/nn/transformer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def _fit(

dataset_schema = self.data_preparator.train_dataset.get_schema()
item_external_ids = self.data_preparator.train_dataset.item_id_map.external_ids
model_config = self.get_config()
model_config = self.get_config(simple_types=True)
self._init_lightning_model(
torch_model=torch_model,
dataset_schema=dataset_schema,
Expand Down Expand Up @@ -461,7 +461,7 @@ def __getstate__(self) -> object:
checkpoint = Path(f.name).read_bytes()
state: tp.Dict[str, tp.Any] = {"fitted_checkpoint": checkpoint}
return state
state = {"model_config": self.get_config()}
state = {"model_config": self.get_config(simple_types=True)}
return state

def __setstate__(self, state: tp.Dict[str, tp.Any]) -> None:
Expand Down
15 changes: 3 additions & 12 deletions tests/models/nn/test_transformer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import pytest
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

from rectools import Columns
Expand All @@ -30,7 +29,7 @@
from rectools.models.nn.transformer_base import TransformerModelBase
from tests.models.utils import assert_save_load_do_not_change_model

from .utils import custom_trainer, leave_one_out_mask
from .utils import custom_trainer, custom_trainer_ckpt, leave_one_out_mask


class TestTransformerModelBase:
Expand Down Expand Up @@ -122,24 +121,16 @@ def test_save_load_for_fitted_model(
def test_load_from_checkpoint(
self,
model_cls: tp.Type[TransformerModelBase],
tmp_path: str,
dataset: Dataset,
) -> None:

model = model_cls.from_config(
{
"deterministic": True,
"item_net_block_types": (IdEmbeddingsItemNet,), # TODO: add CatFeaturesItemNet
"get_trainer_func": custom_trainer_ckpt,
}
)
model._trainer = Trainer( # pylint: disable=protected-access
default_root_dir=tmp_path,
max_epochs=2,
min_epochs=2,
deterministic=True,
accelerator="cpu",
devices=1,
callbacks=ModelCheckpoint(filename="last_epoch"),
)
model.fit(dataset)

assert model.fit_trainer is not None
Expand Down
13 changes: 13 additions & 0 deletions tests/models/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pandas as pd
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from rectools import Columns

Expand All @@ -36,3 +37,15 @@ def custom_trainer() -> Trainer:
enable_checkpointing=False,
devices=1,
)


def custom_trainer_ckpt() -> Trainer:
return Trainer(
# default_root_dir=tmp_path,
max_epochs=3,
min_epochs=3,
deterministic=True,
accelerator="cpu",
devices=1,
callbacks=ModelCheckpoint(filename="last_epoch"),
)

0 comments on commit 4e87e3f

Please sign in to comment.