diff --git a/invenio_communities/communities/schema.py b/invenio_communities/communities/schema.py index c6c6ff2f5..6c8f60c00 100644 --- a/invenio_communities/communities/schema.py +++ b/invenio_communities/communities/schema.py @@ -9,6 +9,7 @@ """Community schema.""" import re +from copy import deepcopy from functools import partial from uuid import UUID @@ -252,16 +253,11 @@ def post_dump(self, data, many, **kwargs): class CommunityParentSchema(BaseCommunitySchema): """Community parent schema.""" - id = fields.String(required=True) - slug = SanitizedUnicode() - metadata = NestedAttribute(CommunityMetadataSchema) - access = NestedAttribute(CommunityAccessSchema) - class CommunitySchema(BaseCommunitySchema): """Community schema.""" - parent = NestedAttribute(CommunityParentSchema, allow_none=True) + parent = NestedAttribute(CommunityParentSchema, dump_only=True, allow_none=True) @post_dump def post_dump(self, data, many, **kwargs): @@ -271,11 +267,15 @@ def post_dump(self, data, many, **kwargs): data.pop("parent", None) return data - @post_load - def filter_parent_id(self, in_data, **kwargs): + @post_load(pass_original=True) + def filter_parent_id(self, in_data, original_data, **kwargs): """Simply keep the parent id.""" - if in_data.get("parent"): - in_data["parent"] = dict(id=in_data.get("parent", {}).get("id")) + if "parent" in original_data: + in_data["parent"] = ( + dict(id=original_data["parent"]["id"]) + if original_data["parent"] + else None + ) return in_data @pre_load