Skip to content

Commit 6462392

Browse files
committedDec 11, 2024·
Refactored to use set difference
1 parent fdc9ad3 commit 6462392

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed
 

‎rtrec/utils/interactions.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def __init__(self, min_value: int = -5, max_value: int = 10, decay_in_days: Opti
1717
"""
1818
# Store interactions as a dictionary of dictionaries in shape {user_id: {item_id: (value, timestamp)}}
1919
self.interactions: defaultdict[int, dict[int, tuple[float, float]]] = defaultdict(dict)
20-
self.empty = {}
2120
self.all_item_ids = set()
2221
assert max_value > min_value, f"max_value should be greater than min_value {max_value} > {min_value}"
2322
self.min_value = min_value
@@ -119,12 +118,19 @@ def get_user_items(self, user_id: int, n_recent: Optional[int] = None) -> List[i
119118
Returns:
120119
List[int]: List of item IDs that the user has interacted with.
121120
"""
122-
# use top-k recent items for the user
121+
user_interactions = self.interactions.get(user_id)
122+
if user_interactions is None:
123+
return []
124+
125+
# Use top-k recent items for the user
123126
if n_recent is not None and len(self.interactions) > n_recent:
124-
# sort by timestamp in descending order
125-
return [item_id for item_id, _ in sorted(self.interactions.get(user_id, self.empty).items(), key=lambda x: x[1][1], reverse=True)[:n_recent]]
127+
# Sort by timestamp in descending order
128+
sorted_items = sorted(
129+
user_interactions.items(), key=lambda x: x[1][1], reverse=True
130+
)
131+
return [item_id for item_id, _ in sorted_items[:n_recent]]
126132
else:
127-
return list(self.interactions.get(user_id, self.empty).keys())
133+
return list(user_interactions.keys())
128134

129135
def get_all_item_ids(self) -> List[int]:
130136
"""
@@ -154,8 +160,10 @@ def get_all_non_interacted_items(self, user_id: int) -> List[int]:
154160
Returns:
155161
List[int]: List of item IDs the user has not interacted with.
156162
"""
157-
interacted_items = set(self.get_user_items(user_id))
158-
return [item_id for item_id in self.all_item_ids if item_id not in interacted_items]
163+
interacted_items = self.get_user_items(user_id)
164+
if len(interacted_items) == 0:
165+
return list(self.all_item_ids)
166+
return list(self.all_item_ids.difference(interacted_items))
159167

160168
def get_all_non_negative_items(self, user_id: int) -> List[int]:
161169
"""

0 commit comments

Comments
 (0)
Please sign in to comment.