-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added Features class for user/item features
- Loading branch information
Showing
2 changed files
with
184 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from typing import List, Set | ||
from scipy.sparse import csr_matrix | ||
import numpy as np | ||
from .collections import SortedSet | ||
|
||
class Features:
    """Registry of user/item tag features with sparse-matrix encoders.

    Tags are stored in two ``SortedSet`` vocabularies (one for users, one for
    items).  The integer returned by ``SortedSet.add`` / ``SortedSet.index``
    is used as the feature's column index in every matrix produced here.
    """

    def __init__(self):
        # Tag vocabularies; their length defines the number of matrix columns.
        self.user_features: SortedSet[str] = SortedSet()
        self.item_features: SortedSet[str] = SortedSet()
        # Entity ID -> de-duplicated feature IDs in first-seen order.
        self.user_feature_map: dict[int, List[int]] = {}
        self.item_feature_map: dict[int, List[int]] = {}

    @staticmethod
    def _register_tags(vocabulary: "SortedSet[str]", tags: List[str]) -> List[int]:
        """Add every tag to ``vocabulary`` and return the unique feature IDs."""
        # dict.fromkeys de-duplicates in O(n) while keeping first-seen order.
        return list(dict.fromkeys(vocabulary.add(tag) for tag in tags))

    @staticmethod
    def _lookup_tags(vocabulary: "SortedSet[str]", tags: List[str]) -> List[int]:
        """Return the unique feature IDs of the tags known to ``vocabulary``."""
        # A negative result from ``index`` is treated as "tag not registered".
        found = (vocabulary.index(tag) for tag in tags)
        return list(dict.fromkeys(tag_id for tag_id in found if tag_id >= 0))

    @staticmethod
    def _single_row(feature_ids: List[int], n_features: int) -> csr_matrix:
        """Encode ``feature_ids`` as a binary ``(1, n_features)`` row."""
        # Explicit integer dtype: an empty list would otherwise become float64.
        cols = np.asarray(feature_ids, dtype=np.intp)
        rows = np.zeros(cols.size, dtype=np.intp)
        data = np.ones(cols.size)
        return csr_matrix((data, (rows, cols)), shape=(1, n_features))

    @staticmethod
    def _stack_rows(feature_map: "dict[int, List[int]]", n_features: int) -> csr_matrix:
        """Encode ``feature_map`` as a binary matrix indexed by entity ID."""
        rows: List[int] = []
        cols: List[int] = []
        for entity_id, feature_ids in feature_map.items():
            rows.extend([entity_id] * len(feature_ids))
            cols.extend(feature_ids)

        # Entity IDs are used directly as row indices, so the row count must
        # cover the largest ID.  Sizing by ``len(feature_map)`` only works when
        # the IDs are exactly 0..n-1 and raises for any other ID set.
        n_rows = max(feature_map, default=-1) + 1
        data = np.ones(len(rows), dtype=np.int_)
        indices = (np.asarray(rows, dtype=np.intp), np.asarray(cols, dtype=np.intp))
        return csr_matrix((data, indices), shape=(n_rows, n_features))

    def put_user_feature(self, user_id: int, user_tags: List[str]) -> None:
        """
        Add a list of user features to the user features set.
        Replace the existing user features if the user ID already exists.

        Args:
            user_id: Identifier of the user; also the row index used by
                ``build_user_features_matrix``.
            user_tags: Tags describing the user; duplicates are ignored.
        """
        self.user_feature_map[user_id] = self._register_tags(self.user_features, user_tags)

    def put_item_feature(self, item_id: int, item_tags: List[str]) -> None:
        """
        Add a list of item features to the item features set.
        Replace the existing item features if the item ID already exists.

        Args:
            item_id: Identifier of the item; also the row index used by
                ``build_item_features_matrix``.
            item_tags: Tags describing the item; duplicates are ignored.
        """
        self.item_feature_map[item_id] = self._register_tags(self.item_features, item_tags)

    def get_user_feature_repr(self, user_tags: List[str]) -> csr_matrix:
        """
        Encode tags as a binary row over the current user vocabulary.

        Tags that are not registered are skipped, and repeated tags are
        counted once (matching ``put_user_feature``).

        Returns:
            csr_matrix: User feature representation of shape (1, n_features)
        """
        feature_ids = self._lookup_tags(self.user_features, user_tags)
        return self._single_row(feature_ids, len(self.user_features))

    def get_item_feature_repr(self, item_tags: List[str]) -> csr_matrix:
        """
        Encode tags as a binary row over the current item vocabulary.

        Tags that are not registered are skipped, and repeated tags are
        counted once (matching ``put_item_feature``).

        Returns:
            csr_matrix: Item feature representation of shape (1, n_features)
        """
        feature_ids = self._lookup_tags(self.item_features, item_tags)
        return self._single_row(feature_ids, len(self.item_features))

    def build_user_features_matrix(self) -> csr_matrix:
        """
        Build the binary user-by-feature matrix from all registered users.

        Row ``i`` describes the user whose ID is ``i``; IDs that were never
        registered produce all-zero rows.

        Returns:
            csr_matrix: User features matrix of shape (n_users, n_features),
            where n_users is the largest registered user ID plus one.
        """
        return self._stack_rows(self.user_feature_map, len(self.user_features))

    def build_item_features_matrix(self) -> csr_matrix:
        """
        Build the binary item-by-feature matrix from all registered items.

        Row ``i`` describes the item whose ID is ``i``; IDs that were never
        registered produce all-zero rows.

        Returns:
            csr_matrix: Item features matrix of shape (n_items, n_features),
            where n_items is the largest registered item ID plus one.
        """
        return self._stack_rows(self.item_feature_map, len(self.item_features))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import numpy as np | ||
import pytest | ||
from scipy.sparse import csr_matrix | ||
from rtrec.utils.features import Features | ||
|
||
# Test adding user features | ||
def test_put_user_feature():
    """Registering tags for a new user stores their feature IDs under that user."""
    registry = Features()
    registry.put_user_feature(1, ["tag1", "tag2"])

    stored = registry.user_feature_map
    assert 1 in stored
    # The first two tags of a fresh vocabulary are expected to get IDs 0 and 1.
    assert set(stored[1]) == {0, 1}
|
||
def test_put_user_feature_replacement():
    """A second put for the same user replaces the earlier feature list."""
    registry = Features()
    for tags in (["tag1", "tag2"], ["tag3", "tag4"]):
        registry.put_user_feature(1, tags)

    stored = registry.user_feature_map
    assert 1 in stored
    # Only the IDs of the newly added tags remain for this user.
    assert set(stored[1]) == {2, 3}
|
||
# Test adding item features | ||
def test_put_item_feature():
    """Registering tags for a new item stores their feature IDs under that item."""
    registry = Features()
    registry.put_item_feature(1, ["item_tag1", "item_tag2"])

    stored = registry.item_feature_map
    assert 1 in stored
    assert set(stored[1]) == {0, 1}
|
||
def test_put_item_feature_replacement():
    """A second put for the same item replaces the earlier feature list."""
    registry = Features()
    for tags in (["item_tag1", "item_tag2"], ["item_tag3", "item_tag4"]):
        registry.put_item_feature(1, tags)

    stored = registry.item_feature_map
    assert 1 in stored
    assert set(stored[1]) == {2, 3}
|
||
# Test getting user feature representation | ||
def test_get_user_feature_repr():
    """Known user tags are encoded as ones in a (1, n_features) sparse row."""
    registry = Features()
    registry.put_user_feature(1, ["tag1", "tag2"])

    actual = registry.get_user_feature_repr(["tag1", "tag2"])

    values, row_idx, col_idx = [1, 1], [0, 0], [0, 1]
    expected = csr_matrix((values, (row_idx, col_idx)), shape=(1, 2))
    # Element-wise inequality yields a sparse matrix; no stored entries means equal.
    mismatches = actual != expected
    assert mismatches.nnz == 0
|
||
def test_get_user_feature_repr_non_existent_tag():
    """An unregistered user tag yields an all-zero row of the vocabulary width."""
    registry = Features()
    registry.put_user_feature(1, ["tag1", "tag2"])

    actual = registry.get_user_feature_repr(["tag3"])

    values, row_idx, col_idx = [], [], []
    expected = csr_matrix((values, (row_idx, col_idx)), shape=(1, 2))
    mismatches = actual != expected
    assert mismatches.nnz == 0
|
||
# Test getting item feature representation | ||
def test_get_item_feature_repr():
    """Known item tags are encoded as ones in a (1, n_features) sparse row."""
    registry = Features()
    registry.put_item_feature(1, ["item_tag1", "item_tag2"])

    actual = registry.get_item_feature_repr(["item_tag1", "item_tag2"])

    values, row_idx, col_idx = [1, 1], [0, 0], [0, 1]
    expected = csr_matrix((values, (row_idx, col_idx)), shape=(1, 2))
    mismatches = actual != expected
    assert mismatches.nnz == 0
|
||
def test_get_item_feature_repr_non_existent_tag():
    """An unregistered item tag yields an all-zero row of the vocabulary width."""
    registry = Features()
    registry.put_item_feature(1, ["item_tag1", "item_tag2"])

    actual = registry.get_item_feature_repr(["item_tag3"])

    values, row_idx, col_idx = [], [], []
    expected = csr_matrix((values, (row_idx, col_idx)), shape=(1, 2))
    mismatches = actual != expected
    assert mismatches.nnz == 0
|
||
# Test building user features matrix | ||
def test_build_user_features_matrix():
    """Users 0 and 1 sharing one tag produce the expected 2x3 binary matrix."""
    features = Features()
    features.put_user_feature(0, ["tag1", "tag2"])
    features.put_user_feature(1, ["tag2", "tag3"])
    user_matrix = features.build_user_features_matrix()
    # Use a plain ndarray: np.matrix is no longer recommended by NumPy.
    expected_matrix = csr_matrix(np.array([[1, 1, 0], [0, 1, 1]]))
    # Element-wise inequality yields a sparse matrix; no stored entries means equal.
    assert (user_matrix != expected_matrix).nnz == 0
|
||
# Test building item features matrix | ||
def test_build_item_features_matrix():
    """Items 0 and 1 sharing one tag produce the expected 2x3 binary matrix."""
    features = Features()
    features.put_item_feature(0, ["item_tag1", "item_tag2"])
    features.put_item_feature(1, ["item_tag2", "item_tag3"])
    item_matrix = features.build_item_features_matrix()
    # Use a plain ndarray: np.matrix is no longer recommended by NumPy.
    expected_matrix = csr_matrix(np.array([[1, 1, 0], [0, 1, 1]]))
    # Element-wise inequality yields a sparse matrix; no stored entries means equal.
    assert (item_matrix != expected_matrix).nnz == 0
|
||
# Run tests using pytest if this file is executed directly | ||
if __name__ == "__main__":
    # Pass this file explicitly: with no arguments pytest collects from the
    # current working directory instead of running just this module.
    pytest.main([__file__])