Skip to content

Commit

Permalink
Added Features class for user/item features
Browse files Browse the repository at this point in the history
  • Loading branch information
myui committed Dec 6, 2024
1 parent 2af38ef commit 1f8c440
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 0 deletions.
100 changes: 100 additions & 0 deletions rtrec/utils/features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import List, Set
from scipy.sparse import csr_matrix
import numpy as np
from .collections import SortedSet

class Features:

def __init__(self):
self.user_features: SortedSet[str] = SortedSet()
self.item_features: SortedSet[str] = SortedSet()
self.user_feature_map: dict[int, List[int]] = {}
self.item_feature_map: dict[int, List[int]] = {}

def put_user_feature(self, user_id: int, user_tags: List[str]) -> None:
"""
Add a list of user features to the user features set.
Replace the existing user features if the user ID already exists.
"""
user_feature_ids = []
for tag in user_tags:
tag_id = self.user_features.add(tag)
if tag_id not in user_feature_ids:
user_feature_ids.append(tag_id)
self.user_feature_map[user_id] = user_feature_ids

def put_item_feature(self, item_id: int, item_tags: List[str]) -> None:
"""
Add a list of item features to the item features set.
Replace the existing item features if the item ID already exists.
"""
item_feature_ids = []
for tag in item_tags:
tag_id = self.item_features.add(tag)
if tag_id not in item_feature_ids:
item_feature_ids.append(tag_id)
self.item_feature_map[item_id] = item_feature_ids

def get_user_feature_repr(self, user_tags: List[str]) -> csr_matrix:
"""
Returns:
csr_matrix: User feature representation of shape (1, n_features)
"""
user_feature_ids = []
for tag in user_tags:
tag_id = self.user_features.index(tag)
if tag_id >= 0:
user_feature_ids.append(tag_id)

cols = np.array(user_feature_ids)
rows = np.zeros(len(user_feature_ids))
data = np.ones(len(user_feature_ids))
return csr_matrix((data, (rows, cols)), shape=(1, len(self.user_features)))

def get_item_feature_repr(self, item_tags: List[str]) -> csr_matrix:
"""
Returns:
csr_matrix: Item feature representation of shape (1, n_features)
"""
item_feature_ids = []
for tag in item_tags:
tag_id = self.item_features.index(tag)
if tag_id >= 0:
item_feature_ids.append(tag_id)

cols = np.array(item_feature_ids)
rows = np.zeros(len(item_feature_ids))
data = np.ones(len(item_feature_ids))
return csr_matrix((data, (rows, cols)), shape=(1, len(self.item_features)))

def build_user_features_matrix(self) -> csr_matrix:
"""
Returns:
csr_matrix: User features matrix of shape (n_users, n_features)
"""

rows, cols, data = [], [], []

for user_id, feature_ids in self.user_feature_map.items():
for feature_id in feature_ids:
rows.append(user_id)
cols.append(feature_id)
data.append(1)

return csr_matrix((data, (rows, cols)), shape=(len(self.user_feature_map), len(self.user_features)))

def build_item_features_matrix(self) -> csr_matrix:
"""
Returns:
csr_matrix: Item features matrix of shape (n_items, n_features)
"""

rows, cols, data = [], [], []

for item_id, feature_ids in self.item_feature_map.items():
for feature_id in feature_ids:
rows.append(item_id)
cols.append(feature_id)
data.append(1)

return csr_matrix((data, (rows, cols)), shape=(len(self.item_feature_map), len(self.item_features)))
84 changes: 84 additions & 0 deletions tests/utils/test_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import numpy as np
import pytest
from scipy.sparse import csr_matrix
from rtrec.utils.features import Features

# Test adding user features
def test_put_user_feature():
features = Features()
features.put_user_feature(1, ["tag1", "tag2"])
assert 1 in features.user_feature_map
assert set(features.user_feature_map[1]) == {0, 1} # Assuming tag IDs start at 0

def test_put_user_feature_replacement():
features = Features()
features.put_user_feature(1, ["tag1", "tag2"])
features.put_user_feature(1, ["tag3", "tag4"])
assert 1 in features.user_feature_map
assert set(features.user_feature_map[1]) == {2, 3} # Updated IDs for new tags

# Test adding item features
def test_put_item_feature():
features = Features()
features.put_item_feature(1, ["item_tag1", "item_tag2"])
assert 1 in features.item_feature_map
assert set(features.item_feature_map[1]) == {0, 1}

def test_put_item_feature_replacement():
features = Features()
features.put_item_feature(1, ["item_tag1", "item_tag2"])
features.put_item_feature(1, ["item_tag3", "item_tag4"])
assert 1 in features.item_feature_map
assert set(features.item_feature_map[1]) == {2, 3}

# Test getting user feature representation
def test_get_user_feature_repr():
features = Features()
features.put_user_feature(1, ["tag1", "tag2"])
user_repr = features.get_user_feature_repr(["tag1", "tag2"])
expected_matrix = csr_matrix(([1, 1], ([0, 0], [0, 1])), shape=(1, 2))
assert (user_repr != expected_matrix).nnz == 0 # Check equality

def test_get_user_feature_repr_non_existent_tag():
features = Features()
features.put_user_feature(1, ["tag1", "tag2"])
user_repr = features.get_user_feature_repr(["tag3"])
expected_matrix = csr_matrix(([], ([], [])), shape=(1, 2))
assert (user_repr != expected_matrix).nnz == 0

# Test getting item feature representation
def test_get_item_feature_repr():
features = Features()
features.put_item_feature(1, ["item_tag1", "item_tag2"])
item_repr = features.get_item_feature_repr(["item_tag1", "item_tag2"])
expected_matrix = csr_matrix(([1, 1], ([0, 0], [0, 1])), shape=(1, 2))
assert (item_repr != expected_matrix).nnz == 0

def test_get_item_feature_repr_non_existent_tag():
features = Features()
features.put_item_feature(1, ["item_tag1", "item_tag2"])
item_repr = features.get_item_feature_repr(["item_tag3"])
expected_matrix = csr_matrix(([], ([], [])), shape=(1, 2))
assert (item_repr != expected_matrix).nnz == 0

# Test building user features matrix
def test_build_user_features_matrix():
features = Features()
features.put_user_feature(0, ["tag1", "tag2"])
features.put_user_feature(1, ["tag2", "tag3"])
user_matrix = features.build_user_features_matrix()
expected_matrix = csr_matrix(np.matrix([[1, 1, 0], [0, 1, 1]]))
assert (user_matrix != expected_matrix).nnz == 0

# Test building item features matrix
def test_build_item_features_matrix():
features = Features()
features.put_item_feature(0, ["item_tag1", "item_tag2"])
features.put_item_feature(1, ["item_tag2", "item_tag3"])
item_matrix = features.build_item_features_matrix()
expected_matrix = csr_matrix(np.matrix([[1, 1, 0], [0, 1, 1]]))
assert (item_matrix != expected_matrix).nnz == 0

# Run tests using pytest if this file is executed directly
if __name__ == "__main__":
pytest.main()

0 comments on commit 1f8c440

Please sign in to comment.