Skip to content

Commit

Permalink
Fixed a bug handling candidate_item_ids
Browse files Browse the repository at this point in the history
  • Loading branch information
myui committed Jan 31, 2025
1 parent 0a5b312 commit e9d490b
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions rtrec/models/lightfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def bulk_fit(self, parallel: bool=False, progress_bar: bool=True) -> None:
def _recommend(self, user_id: int, candidate_item_ids: Optional[List[int]] = None, user_tags: Optional[List[str]] = None, top_k: int = 10, filter_interacted: bool = True) -> List[int]:
users_tags = [user_tags] if user_tags is not None else None
user_features = self._create_user_features(user_ids=[user_id], users_tags=users_tags, slice=True)
item_features = self._create_item_features(item_ids=candidate_item_ids, slice=True)
item_features = self._create_item_features(item_ids=candidate_item_ids, slice=False)

user_biases, user_embeddings = self.model.get_user_representations(user_features)
item_biases, item_embeddings = self.model.get_item_representations(item_features)
Expand All @@ -106,6 +106,8 @@ def _recommend(self, user_id: int, candidate_item_ids: Optional[List[int]] = Non
# the largest possible negative finite value in float32, which is approximately -3.4028235e+38.
min_score = -np.finfo(np.float32).max
# remove ids less than or equal to min_score
if candidate_item_ids and len(ids) > len(candidate_item_ids):
ids = ids[:len(candidate_item_ids)]
for i in range(len(ids)):
if scores[i] <= min_score:
ids = ids[:i]
Expand All @@ -116,7 +118,7 @@ def _recommend(self, user_id: int, candidate_item_ids: Optional[List[int]] = Non
@override
def _recommend_batch(self, user_ids: List[int], candidate_item_ids: Optional[List[int]] = None, users_tags: Optional[List[List[str]]] = None, top_k: int = 10, filter_interacted: bool = True) -> List[List[int]]:
user_features = self._create_user_features(user_ids=user_ids, users_tags=users_tags, slice=True)
item_features = self._create_item_features(item_ids=candidate_item_ids, slice=True)
item_features = self._create_item_features(item_ids=candidate_item_ids, slice=False)

user_biases, user_embeddings = self.model.get_user_representations(user_features)
item_biases, item_embeddings = self.model.get_item_representations(item_features)
Expand Down Expand Up @@ -144,6 +146,8 @@ def _recommend_batch(self, user_ids: List[int], candidate_item_ids: Optional[Lis
# the largest possible negative finite value in float32, which is approximately -3.4028235e+38.
min_score = -np.finfo(np.float32).max
for ids, scores in zip(ids_array, scores_array):
if candidate_item_ids and len(ids) > len(candidate_item_ids):
ids = ids[:len(candidate_item_ids)]
# remove ids less than or equal to min_score
for i in range(len(ids)):
if scores[i] <= min_score:
Expand Down

0 comments on commit e9d490b

Please sign in to comment.