Skip to content

Commit

Permalink
Merge branch 'main' into feature/interactions_to_raw
Browse files Browse the repository at this point in the history
  • Loading branch information
feldlime committed Dec 13, 2023
2 parents 1a520a3 + 7b5cf2d commit c90bc87
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 12 deletions.
7 changes: 6 additions & 1 deletion rectools/dataset/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Mapping between external and internal ids."""

import typing as tp
import warnings

import attr
import numpy as np
Expand Down Expand Up @@ -80,7 +81,11 @@ def from_dict(cls, mapping: tp.Dict[ExternalId, InternalId]) -> "IdMap":
order = np.argsort(internal_ids)
internal_ids_sorted = internal_ids[order]

internals_incorrect = internal_ids_sorted != np.arange(internal_ids_sorted.size)
with warnings.catch_warnings():
# When comparing numeric vs. non-numeric array returns scalar, will change in the future
warnings.simplefilter("ignore", FutureWarning)
internals_incorrect = internal_ids_sorted != np.arange(internal_ids_sorted.size)

if internals_incorrect is True or internals_incorrect.any():
raise ValueError("Internal ids must be integers from 0 to n_objects-1")

Expand Down
9 changes: 7 additions & 2 deletions rectools/models/dssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@
from __future__ import annotations

import typing as tp
import warnings
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import Logger

with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import Logger

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ doctest_optionflags = DONT_ACCEPT_TRUE_FOR_1 NORMALIZE_WHITESPACE
filterwarnings =
ignore:LightFM was compiled without OpenMP support
ignore:distutils Version classes are deprecated
ignore:Converting sparse features to dense array may cause MemoryError
ignore:OpenBLAS is configured to use

[coverage:run]
# the name of the data file to use for storing or reporting coverage.
Expand Down
22 changes: 13 additions & 9 deletions tests/model_selection/test_cross_validate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=attribute-defined-outside-init

import typing as tp
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -251,12 +252,15 @@ def test_happy_path_with_features(self) -> None:
def test_fail_with_cold_users(self) -> None:
splitter = LastNSplitter(n=1, n_splits=2, filter_cold_users=False)

with pytest.raises(KeyError):
cross_validate(
dataset=self.dataset,
splitter=splitter,
metrics=self.metrics,
models=self.models,
k=2,
filter_viewed=False,
)
with warnings.catch_warnings(record=True) as w:
with pytest.raises(KeyError):
cross_validate(
dataset=self.dataset,
splitter=splitter,
metrics=self.metrics,
models=self.models,
k=2,
filter_viewed=False,
)
assert len(w) == 1
assert "Currently models do not support recommendations for cold users" in str(w[-1].message)

0 comments on commit c90bc87

Please sign in to comment.