Skip to content

Commit

Permalink
systemfields: index communities in records
Browse files Browse the repository at this point in the history
  • Loading branch information
slint authored and kpsherva committed Feb 15, 2024
1 parent 93099c4 commit 2accde6
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from invenio_records.systemfields import SystemField

from .....utils import filter_dict_keys
from .context import CommunitiesFieldContext
from .manager import CommunitiesRelationManager

Expand Down Expand Up @@ -57,3 +58,42 @@ def __get__(self, record, owner=None):
if record is None:
return self._context_cls(self, owner)
return self.obj(record)

def post_dump(self, record, data, dumper=None):
"""Dump the communities field."""
comms = getattr(record, self.attr_name)
res = comms.to_dict()
keep_fields = [
"uuid",
"created",
"updated",
"id",
"slug",
"theme",
"version_id",
"metadata.title",
"metadata.type",
"metadata.website",
"metadata.organizations",
"metadata.funding",
"parent.id",
"parent.slug",
"parent.theme",
"parent.metadata.title",
"parent.metadata.type",
"parent.metadata.website",
"parent.metadata.organizations",
"parent.metadata.funding",
]
if res:
res["entries"] = [
filter_dict_keys(comm.dumps(), keep_fields) for comm in comms
]
data[self.key] = res

def post_load(self, record, data, loader=None):
"""Laod the parent community using the OS data (preventing a DB query)."""
comms = data.get("communities")
if comms:
obj = self._manager_cls(self._m2m_model_cls, record.id, comms)
self._set_cache(record, obj)
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
associated with a record.
"""

import uuid

from invenio_db import db
from invenio_records.api import Record

Expand All @@ -36,8 +38,8 @@ def __init__(self, m2m_model_cls, record_id, data):
#
def _to_id(self, val):
"""Get the community id."""
if isinstance(val, str):
return val
if isinstance(val, (str, uuid.UUID)):
return str(val)
elif isinstance(val, Record):
return str(val.id)
return None
Expand Down Expand Up @@ -159,6 +161,11 @@ def default(self):
return self._lookup_community(self._default_id)
return None

@property
def entries(self):
"""Get community objects list."""
return list(self)

@default.setter
def default(self, community_or_id):
"""Set the default community.
Expand Down Expand Up @@ -194,4 +201,9 @@ def from_dict(self, data):
data = data or {}
self._default_id = data.get("default", None)
self._communities_ids = set(data.get("ids", []))
# Search results will have denormalized communities, so we can populate the
# cache from it.
entries = data.pop("entries", None)
if entries:
self._communities_cache = {c["id"]: Community.loads(c) for c in entries}
return self

0 comments on commit 2accde6

Please sign in to comment.