@@ -17,7 +17,6 @@ def __init__(self, min_value: int = -5, max_value: int = 10, decay_in_days: Opti
17
17
"""
18
18
# Store interactions as a dictionary of dictionaries in shape {user_id: {item_id: (value, timestamp)}}
19
19
self .interactions : defaultdict [int , dict [int , tuple [float , float ]]] = defaultdict (dict )
20
- self .empty = {}
21
20
self .all_item_ids = set ()
22
21
assert max_value > min_value , f"max_value should be greater than min_value { max_value } > { min_value } "
23
22
self .min_value = min_value
@@ -119,12 +118,19 @@ def get_user_items(self, user_id: int, n_recent: Optional[int] = None) -> List[i
119
118
Returns:
120
119
List[int]: List of item IDs that the user has interacted with.
121
120
"""
122
- # use top-k recent items for the user
121
+ user_interactions = self .interactions .get (user_id )
122
+ if user_interactions is None :
123
+ return []
124
+
125
+ # Use top-k recent items for the user
123
126
if n_recent is not None and len (self .interactions ) > n_recent :
124
- # sort by timestamp in descending order
125
- return [item_id for item_id , _ in sorted (self .interactions .get (user_id , self .empty ).items (), key = lambda x : x [1 ][1 ], reverse = True )[:n_recent ]]
127
+ # Sort by timestamp in descending order
128
+ sorted_items = sorted (
129
+ user_interactions .items (), key = lambda x : x [1 ][1 ], reverse = True
130
+ )
131
+ return [item_id for item_id , _ in sorted_items [:n_recent ]]
126
132
else :
127
- return list (self . interactions . get ( user_id , self . empty ) .keys ())
133
+ return list (user_interactions .keys ())
128
134
129
135
def get_all_item_ids (self ) -> List [int ]:
130
136
"""
@@ -154,8 +160,10 @@ def get_all_non_interacted_items(self, user_id: int) -> List[int]:
154
160
Returns:
155
161
List[int]: List of item IDs the user has not interacted with.
156
162
"""
157
- interacted_items = set (self .get_user_items (user_id ))
158
- return [item_id for item_id in self .all_item_ids if item_id not in interacted_items ]
163
+ interacted_items = self .get_user_items (user_id )
164
+ if len (interacted_items ) == 0 :
165
+ return list (self .all_item_ids )
166
+ return list (self .all_item_ids .difference (interacted_items ))
159
167
160
168
def get_all_non_negative_items (self , user_id : int ) -> List [int ]:
161
169
"""
0 commit comments