Commit 87d0e7a

pr fixes
blondered committed Feb 3, 2025
1 parent 54555b2 commit 87d0e7a
Showing 4 changed files with 18 additions and 17 deletions.
rectools/dataset/dataset.py (17 changes: 11 additions & 6 deletions)
```diff
@@ -31,20 +31,25 @@
 from .identifiers import ExternalId, IdMap
 from .interactions import Interactions
 
-DenseOrSparseFeatureName = tp.Union[str, SparseFeatureName]
+AnyFeatureName = tp.Union[str, SparseFeatureName]
 
 
-def _serialize_feature_name(spec: DenseOrSparseFeatureName) -> Hashable:
+def _serialize_feature_name(spec: tp.Union[AnyFeatureName, tp.Any]) -> Hashable:
     if isinstance(spec, tuple):
         return tuple(_serialize_feature_name(item) for item in spec)
-    if isinstance(spec, (int, float, str)):
+    if isinstance(spec, (int, float, str, bool)):
         return spec
-    if np.issubdtype(spec, np.number):  # type:ignore[unreachable]
+    if not isinstance(spec, np.ndarray) and np.issubdtype(spec, np.number) or np.issubdtype(spec, np.bool_):
         return spec.item()
-    return "unsupported feature name"
+    raise ValueError(
+        f"""
+        Serialization for feature name {spec} is not supported.
+        Please convert your feature names and values to strings, numbers, booleans or their tuples or lists.
+        """
+    )
 
 
-FeatureName = tpe.Annotated[DenseOrSparseFeatureName, PlainSerializer(_serialize_feature_name, when_used="json")]
+FeatureName = tpe.Annotated[AnyFeatureName, PlainSerializer(_serialize_feature_name, when_used="json")]
 DatasetSchemaDict = tp.Dict[str, tp.Any]
 
 
```
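The user-visible effect of this change: feature names that cannot be serialized now raise a `ValueError` during dataset schema serialization instead of silently collapsing to a placeholder string, and booleans (both Python and numpy) are now supported. A minimal standalone sketch of the new logic, reimplemented here for illustration rather than imported from rectools:

```python
import typing as tp

import numpy as np


def serialize_feature_name(spec: tp.Any) -> tp.Hashable:
    """Minimal sketch of the patched _serialize_feature_name logic."""
    # Tuples are serialized element-wise.
    if isinstance(spec, tuple):
        return tuple(serialize_feature_name(item) for item in spec)
    # Plain Python scalars pass through; bool is now explicitly allowed.
    if isinstance(spec, (int, float, str, bool)):
        return spec
    # Numpy numeric and boolean scalars are unwrapped to Python scalars.
    if not isinstance(spec, np.ndarray) and np.issubdtype(spec, np.number) or np.issubdtype(spec, np.bool_):
        return spec.item()
    # Everything else is now rejected instead of becoming a placeholder string.
    raise ValueError(f"Serialization for feature name {spec} is not supported.")


print(serialize_feature_name(("genre", np.int64(3))))  # ('genre', 3)
print(serialize_feature_name(np.bool_(True)))          # True
```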
rectools/models/nn/bert4rec.py (6 changes: 2 additions & 4 deletions)
```diff
@@ -169,8 +169,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
     BERT4Rec model: transformer-based sequential model with bidirectional attention mechanism and
     "MLM" (masked item in user sequence) training objective.
     Our implementation covers multiple loss functions and a variable number of negatives for them.
 
-    Notes
+    References
     ----------
     Transformers tutorial: https://rectools.readthedocs.io/en/stable/examples/tutorials/transformers_tutorial.html
     Advanced training guide:
@@ -217,8 +217,6 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
         Verbosity level.
         Enables progress bar, model summary and logging in default lightning trainer when set to a
         positive integer.
-        Enables automatic lightning checkpointing when set to 100 or higher. This will save the most
-        the most recent model to a single checkpoint after each epoch.
         Will be omitted if `get_trainer_func` is specified.
     dataloader_num_workers : int, default 0
         Number of loader worker processes.
```
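Since the default trainer no longer enables checkpointing at any verbosity (see the `enable_checkpointing=False` change in `transformer_base.py` below), checkpointing now has to come from a user-supplied trainer. A hedged sketch of that route, assuming `get_trainer_func` takes a zero-argument callable returning a lightning `Trainer` and that `BERT4RecModel` is importable from `rectools.models` (check the signatures in your rectools version):

```python
from pytorch_lightning import Trainer

from rectools.models import BERT4RecModel


def get_trainer() -> Trainer:
    # Re-enable per-epoch checkpointing explicitly; the default trainer
    # now sets enable_checkpointing=False regardless of `verbose`.
    return Trainer(max_epochs=3, devices=1, enable_checkpointing=True)


# `verbose` still controls the progress bar, model summary and logging in the
# default trainer, but it is omitted entirely once get_trainer_func is given.
model = BERT4RecModel(get_trainer_func=get_trainer)
```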
rectools/models/nn/sasrec.py (4 changes: 1 addition & 3 deletions)
```diff
@@ -201,7 +201,7 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
     "Shifted Sequence" training objective.
     Our implementation covers multiple loss functions and a variable number of negatives for them.
 
-    Notes
+    References
     ----------
     Transformers tutorial: https://rectools.readthedocs.io/en/stable/examples/tutorials/transformers_tutorial.html
     Advanced training guide:
@@ -246,8 +246,6 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
         Verbosity level.
         Enables progress bar, model summary and logging in default lightning trainer when set to a
         positive integer.
-        Enables automatic lightning checkpointing when set to 100 or higher. This will save the most
-        the most recent model to a single checkpoint after each epoch.
         Will be omitted if `get_trainer_func` is specified.
     dataloader_num_workers : int, default 0
         Number of loader worker processes.
```
rectools/models/nn/transformer_base.py (8 changes: 4 additions & 4 deletions)
Expand Up @@ -730,7 +730,7 @@ def _init_trainer(self) -> None:
enable_progress_bar=self.verbose > 0,
enable_model_summary=self.verbose > 0,
logger=self.verbose > 0,
enable_checkpointing=self.verbose > 99,
enable_checkpointing=False,
devices=1,
)
else:
Expand Down Expand Up @@ -931,12 +931,12 @@ def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:

def __getstate__(self) -> object:
if self.is_fitted:
if self.fit_trainer is None:
raise RuntimeError("Model that was loaded from checkpoint cannot be saved without being fitted again")
with NamedTemporaryFile() as f:
if self.fit_trainer is None:
raise TypeError("Model that was loaded from checkpoint cannot be saved without being fitted again")
self.fit_trainer.save_checkpoint(f.name)
checkpoint = Path(f.name).read_bytes()
state: tp.Dict[str, tp.Any] = {"fitted_checkpoint": checkpoint}
state: tp.Dict[str, tp.Any] = {"fitted_checkpoint": checkpoint}
return state
state = {"model_config": self.get_config()}
return state
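The reordered guard in `__getstate__` changes pickling behavior in two small ways: a model restored from a checkpoint and never refitted now fails before any temporary file is created, and it raises `RuntimeError` instead of `TypeError`. A self-contained sketch of the new control flow (a toy class, not the actual rectools base model):

```python
import pickle
import typing as tp
from pathlib import Path
from tempfile import NamedTemporaryFile


class _StateSketch:
    """Toy stand-in mirroring the patched __getstate__ control flow."""

    def __init__(self, is_fitted: bool, fit_trainer: tp.Any = None) -> None:
        self.is_fitted = is_fitted
        self.fit_trainer = fit_trainer

    def __getstate__(self) -> object:
        if self.is_fitted:
            # Guard moved above the temp file: fail fast with RuntimeError
            # (previously TypeError, raised only after the file was opened).
            if self.fit_trainer is None:
                raise RuntimeError("Model that was loaded from checkpoint cannot be saved without being fitted again")
            with NamedTemporaryFile() as f:
                self.fit_trainer.save_checkpoint(f.name)
                checkpoint = Path(f.name).read_bytes()
            # The state dict is now built after the `with` block has closed.
            state: tp.Dict[str, tp.Any] = {"fitted_checkpoint": checkpoint}
            return state
        return {"model_config": "..."}


try:
    pickle.dumps(_StateSketch(is_fitted=True, fit_trainer=None))
except RuntimeError as err:
    print(err)  # fail-fast path: no temporary file was created
```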
