Skip to content

Commit

Permalink
Added to_csr and to_csc
Browse files Browse the repository at this point in the history
  • Loading branch information
myui committed Dec 10, 2024
1 parent 15c551e commit 5e1be2d
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions rtrec/utils/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Optional, Any
import time, math

from scipy.sparse import csr_matrix
from scipy.sparse import csr_matrix, csc_matrix

class UserItemInteractions:
def __init__(self, min_value: int = -5, max_value: int = 10, decay_in_days: Optional[int] = None, **kwargs: Any) -> None:
Expand All @@ -15,6 +15,7 @@ def __init__(self, min_value: int = -5, max_value: int = 10, decay_in_days: Opti
decay_rate (Optional[float]): Rate at which interactions decay over time.
If None, no decay is applied.
"""
# Store interactions as a dictionary of dictionaries in shape {user_id: {item_id: (value, timestamp)}}
self.interactions: defaultdict[int, dict[int, tuple[float, float]]] = defaultdict(dict)
self.empty = {}
self.all_item_ids = set()
Expand Down Expand Up @@ -174,13 +175,28 @@ def to_csr(self) -> csr_matrix:
rows, cols, data = [], [], []
max_row, max_col = 0, 0

for row, inner_dict in self.interactions.items():
for col, (value, tstamp) in inner_dict.items():
rows.append(row)
cols.append(col)
data.append(self._apply_decay(value, tstamp))
max_row = max(max_row, row)
max_col = max(max_col, col)
for user, inner_dict in self.interactions.items():
for item, (rating, tstamp) in inner_dict.items():
rows.append(user)
cols.append(item)
data.append(self._apply_decay(rating, tstamp))
max_row = max(max_row, user)
max_col = max(max_col, item)

# Create the csr_matrix
return csr_matrix((data, (rows, cols)), shape=(max_row, max_col))

def to_csc(self) -> csc_matrix:
rows, cols, data = [], [], []
max_row, max_col = 0, 0

for user, inner_dict in self.interactions.items():
for item, (rating, tstamp) in inner_dict.items():
rows.append(user)
cols.append(item)
data.append(self._apply_decay(rating, tstamp))
max_row = max(max_row, user)
max_col = max(max_col, item)

# Create the csc_matrix
return csc_matrix((data, (rows, cols)), shape=(max_row, max_col))

0 comments on commit 5e1be2d

Please sign in to comment.