diff --git a/invenio_communities/records/records/systemfields/communities/field.py b/invenio_communities/records/records/systemfields/communities/field.py index 0d75207de..cf7a6a64b 100644 --- a/invenio_communities/records/records/systemfields/communities/field.py +++ b/invenio_communities/records/records/systemfields/communities/field.py @@ -10,6 +10,7 @@ from invenio_records.systemfields import SystemField +from .....utils import filter_dict_keys from .context import CommunitiesFieldContext from .manager import CommunitiesRelationManager @@ -57,3 +58,42 @@ def __get__(self, record, owner=None): if record is None: return self._context_cls(self, owner) return self.obj(record) + + def post_dump(self, record, data, dumper=None): + """Dump the communities field.""" + comms = getattr(record, self.attr_name) + res = comms.to_dict() + keep_fields = [ + "uuid", + "created", + "updated", + "id", + "slug", + "theme", + "version_id", + "metadata.title", + "metadata.type", + "metadata.website", + "metadata.organizations", + "metadata.funding", + "parent.id", + "parent.slug", + "parent.theme", + "parent.metadata.title", + "parent.metadata.type", + "parent.metadata.website", + "parent.metadata.organizations", + "parent.metadata.funding", + ] + if res: + res["entries"] = [ + filter_dict_keys(comm.dumps(), keep_fields) for comm in comms + ] + data[self.key] = res + + def post_load(self, record, data, loader=None): + """Laod the parent community using the OS data (preventing a DB query).""" + comms = data.get("communities") + if comms: + obj = self._manager_cls(self._m2m_model_cls, record.id, comms) + self._set_cache(record, obj) diff --git a/invenio_communities/records/records/systemfields/communities/manager.py b/invenio_communities/records/records/systemfields/communities/manager.py index 93c45dddd..c4330edff 100644 --- a/invenio_communities/records/records/systemfields/communities/manager.py +++ b/invenio_communities/records/records/systemfields/communities/manager.py @@ -12,6 +12,8 @@ associated with a record. """ +import uuid + from invenio_db import db from invenio_records.api import Record @@ -36,8 +38,8 @@ def __init__(self, m2m_model_cls, record_id, data): # def _to_id(self, val): """Get the community id.""" - if isinstance(val, str): - return val + if isinstance(val, (str, uuid.UUID)): + return str(val) elif isinstance(val, Record): return str(val.id) return None @@ -159,6 +161,11 @@ def default(self): return self._lookup_community(self._default_id) return None + @property + def entries(self): + """Get community objects list.""" + return list(self) + @default.setter def default(self, community_or_id): """Set the default community. @@ -194,4 +201,9 @@ def from_dict(self, data): data = data or {} self._default_id = data.get("default", None) self._communities_ids = set(data.get("ids", [])) + # Search results will have denormalized communities, so we can populate the + # cache from it. + entries = data.pop("entries", None) + if entries: + self._communities_cache = {c["id"]: Community.loads(c) for c in entries} return self