diff --git a/docs/source/conf.py b/docs/source/conf.py index 31d9a3bee..42a7ee173 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,7 +21,7 @@ rig, session, subject, - quality_control + quality_control, ) dummy_object = [ @@ -34,7 +34,7 @@ rig, session, subject, - quality_control + quality_control, ] # A temporary workaround to bypass "Imported but unused" error INSTITUTE_NAME = "Allen Institute for Neural Dynamics" diff --git a/docs/source/quality_control.md b/docs/source/quality_control.md index acf3da9ea..6fbf43af2 100644 --- a/docs/source/quality_control.md +++ b/docs/source/quality_control.md @@ -102,6 +102,14 @@ Each metric is associated with a reference figure, image, or video. The QC porta By default the QC portal displays dictionaries as tables where the values can be edited. We also support a few special cases to allow a bit more flexibility or to constrain the actions that manual annotators can take. Install the `aind-qcportal-schema` package and set the `value` field to the corresponding pydantic object to use these. -### Multi-session QC +### Multi-asset QC -[Details coming soon, this is under discussion] \ No newline at end of file +During analysis there are many situations where multiple data assets need to be pulled together, often for comparison. For example, FOVs across imaging sessions or recording sessions from a chronic probe might need to be matched up across days. When a `QCEvaluation` is calculated from multiple assets, it should be tagged with `Stage:MULTI_ASSET` and each of its `QCMetric` objects needs to track the assets that were used to generate that metric in the `evaluated_assets` list. + +**Q: Where do I store multi-asset QC?** + +You should follow the preferred/alternate workflows described above. If your multi-asset analysis pipeline generates a new data asset, put the QC there. If your pipeline does not generate an asset, push a copy of each `QCEvaluation` back to **each** individual data asset. + +**Q: I want to be able to store data about each of the evaluated assets in this metric** + +Take a look at the `MultiAssetMetric` class in `aind-qcportal-schema`. It allows you to pass a list of values which will be matched up with the `evaluated_assets` names. You can also include options which will appear as dropdowns or checkboxes. \ No newline at end of file diff --git a/docs/source/session.rst b/docs/source/session.rst index aa4d1df15..f0103b093 100644 --- a/docs/source/session.rst +++ b/docs/source/session.rst @@ -59,10 +59,11 @@ the Stimulus Class, but the trial-by-trial stimulus information belongs in the N Great question! We began defining specific classes for different stimulus and behavior modalities, but quickly found that this won't be scalable. You can currently use these classes if they work for you. However, in the long run we -would like this to move into the `script` field. This field uses the Software class, which has a field for stimulus -parameters, where users can define their own dictionary of parameters used in the script to control the stimulus/ -behavior. We recommend that you use software to define these and be consistent within your projects. Please reach out -with questions and we can help you with this. +would like this to move into the `script` field. This field uses the Software class, which has a field for +`parameters`. Users should use this to document the parameters used to control the stimulus or behavior. Parameters should have unambiguous names (e.g. 
"trial_duration" rather than "duration") and units must be provided as a separate +field (e.g. "trial_duration_unit"). We recommend that you use software to define these and be consistent within your +projects. Please reach out with questions and we can help you with this. **Q: What should I put for the `session_type`?** @@ -77,8 +78,8 @@ and SLIMS. Until this is fully functional, these files must be created manually. **Q: How do I know if my mouse platform is "active"?** There are experiments in which the mouse platform is actively controlled by the stimulus/behavior software - i.e. the -resistance of the wheel is adjusted based on the subjects activity. This is an "active" mouse platform. Most platforms -we use are not active in this way. +resistance of the wheel is adjusted based on the subject's activity. This is an "active" mouse platform. Most platforms +we use are currently not active in this way. **Q: How do I use the Calibration field?** diff --git a/examples/quality_control.json b/examples/quality_control.json index fdb50d082..352f364c8 100644 --- a/examples/quality_control.json +++ b/examples/quality_control.json @@ -80,6 +80,7 @@ "evaluated_assets": null } ], + "tags": null, "notes": "", "allow_failed_metrics": false }, @@ -121,6 +122,7 @@ "evaluated_assets": null } ], + "tags": null, "notes": "Pass when video_1_num_frames==video_2_num_frames", "allow_failed_metrics": false }, @@ -176,6 +178,7 @@ "evaluated_assets": null } ], + "tags": null, "notes": null, "allow_failed_metrics": false } diff --git a/pyproject.toml b/pyproject.toml index bfb7fa1de..2f4b7d2a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ readme = "README.md" dynamic = ["version"] dependencies = [ - 'aind-data-schema-models>=0.3.2', + 'aind-data-schema-models>=0.5.4, <1.0.0', 'dictdiffer', 'pydantic>=2.7', 'inflection', diff --git a/src/aind_data_schema/__init__.py b/src/aind_data_schema/__init__.py index c02cc6a60..5d3d4791e 100755 --- a/src/aind_data_schema/__init__.py +++ b/src/aind_data_schema/__init__.py @@ -1,4 +1,4 @@ """ imports for AindModel subclasses """ -__version__ = "1.1.0" +__version__ = "1.1.1" diff --git a/src/aind_data_schema/base.py b/src/aind_data_schema/base.py index c04cd3c74..d3ecb5c75 100644 --- a/src/aind_data_schema/base.py +++ b/src/aind_data_schema/base.py @@ -1,5 +1,6 @@ """ generic base class with supporting validators and fields for basic AIND schema """ +import json import re from pathlib import Path from typing import Any, Generic, Optional, TypeVar @@ -14,6 +15,7 @@ ValidationError, ValidatorFunctionWrapHandler, create_model, + model_validator, ) from pydantic.functional_validators import WrapValidator from typing_extensions import Annotated @@ -31,13 +33,57 @@ def _coerce_naive_datetime(v: Any, handler: ValidatorFunctionWrapHandler) -> Awa AwareDatetimeWithDefault = Annotated[AwareDatetime, WrapValidator(_coerce_naive_datetime)] +def is_dict_corrupt(input_dict: dict) -> bool: + """ + Checks that dictionary keys, included nested keys, do not contain + forbidden characters ("$" and "."). + + Parameters + ---------- + input_dict : dict + + Returns + ------- + bool + True if input_dict is not a dict, or if any keys contain + forbidden characters. False otherwise. + + """ + + def has_corrupt_keys(input) -> bool: + """Recursively checks nested dictionaries and lists""" + if isinstance(input, dict): + for key, value in input.items(): + if "$" in key or "." 
in key: + return True + elif has_corrupt_keys(value): + return True + elif isinstance(input, list): + for item in input: + if has_corrupt_keys(item): + return True + return False + + # Top-level input must be a dictionary + if not isinstance(input_dict, dict): + return True + return has_corrupt_keys(input_dict) + + class AindGeneric(BaseModel, extra="allow"): """Base class for generic types that can be used in AIND schema""" # extra="allow" is needed because BaseModel by default drops extra parameters. # Alternatively, consider using 'SerializeAsAny' once this issue is resolved # https://github.com/pydantic/pydantic/issues/6423 - pass + + @model_validator(mode="after") + def validate_fieldnames(self): + """Ensure that field names do not contain forbidden characters""" + model_dict = json.loads(self.model_dump_json(by_alias=True)) + if is_dict_corrupt(model_dict): + raise ValueError("Field names cannot contain '.' or '$'") + return self AindGenericType = TypeVar("AindGenericType", bound=AindGeneric) diff --git a/src/aind_data_schema/core/acquisition.py b/src/aind_data_schema/core/acquisition.py index 77a4f8b88..da7953386 100644 --- a/src/aind_data_schema/core/acquisition.py +++ b/src/aind_data_schema/core/acquisition.py @@ -4,7 +4,7 @@ from typing import List, Literal, Optional, Union from aind_data_schema_models.process_names import ProcessName -from pydantic import Field, field_validator +from pydantic import Field, SkipValidation, field_validator from aind_data_schema.base import AindCoreModel, AindModel, AwareDatetimeWithDefault from aind_data_schema.components.coordinates import AnatomicalDirection, AxisName, ImageAxis @@ -45,7 +45,7 @@ class Acquisition(AindCoreModel): _DESCRIBED_BY_URL = AindCoreModel._DESCRIBED_BY_BASE_URL.default + "aind_data_schema/core/acquisition.py" describedBy: str = Field(default=_DESCRIBED_BY_URL, json_schema_extra={"const": _DESCRIBED_BY_URL}) - schema_version: Literal["1.0.1"] = Field(default="1.0.1") + schema_version: SkipValidation[Literal["1.0.1"]] = Field(default="1.0.1") protocol_id: List[str] = Field(default=[], title="Protocol ID", description="DOI for protocols.io") experimenter_full_name: List[str] = Field( ..., diff --git a/src/aind_data_schema/core/data_description.py b/src/aind_data_schema/core/data_description.py index 47e3d7edc..31441cf9f 100644 --- a/src/aind_data_schema/core/data_description.py +++ b/src/aind_data_schema/core/data_description.py @@ -15,7 +15,7 @@ from aind_data_schema_models.organizations import Organization from aind_data_schema_models.pid_names import PIDName from aind_data_schema_models.platforms import Platform -from pydantic import Field, model_validator +from pydantic import Field, SkipValidation, model_validator from aind_data_schema.base import AindCoreModel, AindModel, AwareDatetimeWithDefault @@ -40,7 +40,7 @@ class DataDescription(AindCoreModel): _DESCRIBED_BY_URL = AindCoreModel._DESCRIBED_BY_BASE_URL.default + "aind_data_schema/core/data_description.py" describedBy: str = Field(default=_DESCRIBED_BY_URL, json_schema_extra={"const": _DESCRIBED_BY_URL}) - schema_version: Literal["1.0.1"] = Field(default="1.0.1") + schema_version: SkipValidation[Literal["1.0.1"]] = Field(default="1.0.1") license: Literal["CC-BY-4.0"] = Field("CC-BY-4.0", title="License") platform: Platform.ONE_OF = Field( diff --git a/src/aind_data_schema/core/instrument.py b/src/aind_data_schema/core/instrument.py index ee44f12c0..699bba10d 100644 --- a/src/aind_data_schema/core/instrument.py +++ b/src/aind_data_schema/core/instrument.py @@ -4,7 
+4,7 @@ from typing import List, Literal, Optional from aind_data_schema_models.organizations import Organization -from pydantic import Field, ValidationInfo, field_validator +from pydantic import Field, SkipValidation, ValidationInfo, field_validator from aind_data_schema.base import AindCoreModel, AindModel from aind_data_schema.components.devices import ( @@ -35,7 +35,7 @@ class Instrument(AindCoreModel): _DESCRIBED_BY_URL = AindCoreModel._DESCRIBED_BY_BASE_URL.default + "aind_data_schema/core/instrument.py" describedBy: str = Field(default=_DESCRIBED_BY_URL, json_schema_extra={"const": _DESCRIBED_BY_URL}) - schema_version: Literal["1.0.1"] = Field(default="1.0.1") + schema_version: SkipValidation[Literal["1.0.1"]] = Field(default="1.0.1") instrument_id: Optional[str] = Field( default=None, diff --git a/src/aind_data_schema/core/metadata.py b/src/aind_data_schema/core/metadata.py index eeb307249..9142140ac 100644 --- a/src/aind_data_schema/core/metadata.py +++ b/src/aind_data_schema/core/metadata.py @@ -1,6 +1,8 @@ """Generic metadata class for Data Asset Records.""" import inspect +import json +import logging from datetime import datetime from enum import Enum from typing import Dict, List, Literal, Optional, get_args @@ -8,9 +10,17 @@ from aind_data_schema_models.modalities import ExpectedFiles, FileRequirement from aind_data_schema_models.platforms import Platform -from pydantic import Field, PrivateAttr, ValidationError, ValidationInfo, field_validator, model_validator - -from aind_data_schema.base import AindCoreModel +from pydantic import ( + Field, + PrivateAttr, + SkipValidation, + ValidationError, + ValidationInfo, + field_validator, + model_validator, +) + +from aind_data_schema.base import AindCoreModel, is_dict_corrupt from aind_data_schema.core.acquisition import Acquisition from aind_data_schema.core.data_description import DataDescription from aind_data_schema.core.instrument import Instrument @@ -61,7 +71,7 @@ class Metadata(AindCoreModel): _DESCRIBED_BY_URL = AindCoreModel._DESCRIBED_BY_BASE_URL.default + "aind_data_schema/core/metadata.py" describedBy: str = Field(default=_DESCRIBED_BY_URL, json_schema_extra={"const": _DESCRIBED_BY_URL}) - schema_version: Literal["1.0.2"] = Field(default="1.0.2") + schema_version: SkipValidation[Literal["1.0.2"]] = Field(default="1.0.2") id: UUID = Field( default_factory=uuid4, alias="_id", @@ -278,3 +288,43 @@ def validate_rig_session_compatibility(self): check = RigSessionCompatibility(self.rig, self.session) check.run_compatibility_check() return self + + +def create_metadata_json( + name: str, + location: str, + core_jsons: Dict[str, Optional[dict]], + optional_created: Optional[datetime] = None, + optional_external_links: Optional[dict] = None, +) -> dict: + """Creates a Metadata dict from dictionary of core schema fields.""" + # Extract basic parameters and non-corrupt core schema fields + params = { + "name": name, + "location": location, + } + if optional_created is not None: + params["created"] = optional_created + if optional_external_links is not None: + params["external_links"] = optional_external_links + core_fields = dict() + for key, value in core_jsons.items(): + if key in CORE_FILES and value is not None: + if is_dict_corrupt(value): + logging.warning(f"Provided {key} is corrupt! 
It will be ignored.") + else: + core_fields[key] = value + # Create Metadata object and convert to JSON + # If there are any validation errors, still create it + # but set MetadataStatus as Invalid + try: + metadata = Metadata.model_validate({**params, **core_fields}) + metadata_json = json.loads(metadata.model_dump_json(by_alias=True)) + except Exception as e: + logging.warning(f"Issue with metadata construction! {e.args}") + metadata = Metadata.model_validate(params) + metadata_json = json.loads(metadata.model_dump_json(by_alias=True)) + for key, value in core_fields.items(): + metadata_json[key] = value + metadata_json["metadata_status"] = MetadataStatus.INVALID.value + return metadata_json diff --git a/src/aind_data_schema/core/procedures.py b/src/aind_data_schema/core/procedures.py index 9e9cda4fd..94ab07083 100644 --- a/src/aind_data_schema/core/procedures.py +++ b/src/aind_data_schema/core/procedures.py @@ -20,7 +20,7 @@ VolumeUnit, create_unit_with_value, ) -from pydantic import Field, field_serializer, field_validator, model_validator +from pydantic import Field, SkipValidation, field_serializer, field_validator, model_validator from pydantic_core.core_schema import ValidationInfo from typing_extensions import Annotated @@ -649,7 +649,7 @@ class Procedures(AindCoreModel): _DESCRIBED_BY_URL = AindCoreModel._DESCRIBED_BY_BASE_URL.default + "aind_data_schema/core/procedures.py" describedBy: str = Field(default=_DESCRIBED_BY_URL, json_schema_extra={"const": _DESCRIBED_BY_URL}) - schema_version: Literal["1.1.1"] = Field(default="1.1.1") + schema_version: SkipValidation[Literal["1.1.1"]] = Field(default="1.1.1") subject_id: str = Field( ..., description="Unique identifier for the subject. If this is not a Allen LAS ID, indicate this in the Notes.", diff --git a/src/aind_data_schema/core/processing.py b/src/aind_data_schema/core/processing.py index 8b52958c9..f4abe9295 100644 --- a/src/aind_data_schema/core/processing.py +++ b/src/aind_data_schema/core/processing.py @@ -5,7 +5,7 @@ from aind_data_schema_models.process_names import ProcessName from aind_data_schema_models.units import MemoryUnit, UnitlessUnit -from pydantic import Field, ValidationInfo, field_validator, model_validator +from pydantic import Field, SkipValidation, ValidationInfo, field_validator, model_validator from aind_data_schema.base import AindCoreModel, AindGeneric, AindGenericType, AindModel, AwareDatetimeWithDefault from aind_data_schema.components.tile import Tile @@ -124,7 +124,7 @@ class Processing(AindCoreModel): _DESCRIBED_BY_URL: str = AindCoreModel._DESCRIBED_BY_BASE_URL.default + "aind_data_schema/core/processing.py" describedBy: str = Field(default=_DESCRIBED_BY_URL, json_schema_extra={"const": _DESCRIBED_BY_URL}) - schema_version: Literal["1.1.1"] = Field(default="1.1.1") + schema_version: SkipValidation[Literal["1.1.1"]] = Field(default="1.1.1") processing_pipeline: PipelineProcess = Field( ..., description="Pipeline used to process data", title="Processing Pipeline" diff --git a/src/aind_data_schema/core/quality_control.py b/src/aind_data_schema/core/quality_control.py index 04661f84c..22875b804 100644 --- a/src/aind_data_schema/core/quality_control.py +++ b/src/aind_data_schema/core/quality_control.py @@ -4,7 +4,7 @@ from typing import Any, List, Literal, Optional from aind_data_schema_models.modalities import Modality -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, Field, SkipValidation, field_validator, model_validator from 
aind_data_schema.base import AindCoreModel, AindModel, AwareDatetimeWithDefault @@ -78,6 +78,9 @@ class QCEvaluation(AindModel): name: str = Field(..., title="Evaluation name") description: Optional[str] = Field(default=None, title="Evaluation description") metrics: List[QCMetric] = Field(..., title="QC metrics") + tags: Optional[List[str]] = Field( + default=None, title="Tags", description="Tags can be used to group QCEvaluation objects into groups" + ) notes: Optional[str] = Field(default=None, title="Notes") allow_failed_metrics: bool = Field( default=False, @@ -161,7 +164,7 @@ class QualityControl(AindCoreModel): _DESCRIBED_BY_URL = AindCoreModel._DESCRIBED_BY_BASE_URL.default + "aind_data_schema/core/quality_control.py" describedBy: str = Field(default=_DESCRIBED_BY_URL, json_schema_extra={"const": _DESCRIBED_BY_URL}) - schema_version: Literal["1.1.1"] = Field(default="1.1.1") + schema_version: SkipValidation[Literal["1.1.1"]] = Field(default="1.1.1") evaluations: List[QCEvaluation] = Field(..., title="Evaluations") notes: Optional[str] = Field(default=None, title="Notes") diff --git a/src/aind_data_schema/core/rig.py b/src/aind_data_schema/core/rig.py index ef26f34bf..4ab4efb0c 100644 --- a/src/aind_data_schema/core/rig.py +++ b/src/aind_data_schema/core/rig.py @@ -4,7 +4,7 @@ from typing import List, Literal, Optional, Set, Union from aind_data_schema_models.modalities import Modality -from pydantic import Field, ValidationInfo, field_serializer, field_validator, model_validator +from pydantic import Field, SkipValidation, ValidationInfo, field_serializer, field_validator, model_validator from typing_extensions import Annotated from aind_data_schema.base import AindCoreModel @@ -51,7 +51,7 @@ class Rig(AindCoreModel): _DESCRIBED_BY_URL = AindCoreModel._DESCRIBED_BY_BASE_URL.default + "aind_data_schema/core/rig.py" describedBy: str = Field(default=_DESCRIBED_BY_URL, json_schema_extra={"const": _DESCRIBED_BY_URL}) - schema_version: Literal["1.0.1"] = Field(default="1.0.1") + schema_version: SkipValidation[Literal["1.0.1"]] = Field(default="1.0.1") rig_id: str = Field( ..., description="Unique rig identifier, name convention: --", @@ -90,9 +90,9 @@ class Rig(AindCoreModel): notes: Optional[str] = Field(default=None, title="Notes") @field_serializer("modalities", when_used="json") - def serialize_modalities(modalities: Set[Modality.ONE_OF]): - """sort modalities by name when serializing to JSON""" - return sorted(modalities, key=lambda x: x.name) + def serialize_modalities(self, modalities: Set[Modality.ONE_OF]): + """Dynamically serialize modalities based on their type.""" + return sorted(modalities, key=lambda x: x.get("name") if isinstance(x, dict) else x.name) @model_validator(mode="after") def validate_cameras_other(self): diff --git a/src/aind_data_schema/core/session.py b/src/aind_data_schema/core/session.py index d5d3fef35..202b6f5db 100644 --- a/src/aind_data_schema/core/session.py +++ b/src/aind_data_schema/core/session.py @@ -17,7 +17,7 @@ TimeUnit, VolumeUnit, ) -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SkipValidation, field_validator, model_validator from pydantic_core.core_schema import ValidationInfo from typing_extensions import Annotated @@ -534,7 +534,7 @@ class Session(AindCoreModel): _DESCRIBED_BY_URL = AindCoreModel._DESCRIBED_BY_BASE_URL.default + "aind_data_schema/core/session.py" describedBy: str = Field(default=_DESCRIBED_BY_URL, json_schema_extra={"const": _DESCRIBED_BY_URL}) - schema_version: Literal["1.0.1"] = 
Field(default="1.0.1") + schema_version: SkipValidation[Literal["1.0.1"]] = Field(default="1.0.1") protocol_id: List[str] = Field(default=[], title="Protocol ID", description="DOI for protocols.io") experimenter_full_name: List[str] = Field( ..., diff --git a/src/aind_data_schema/core/subject.py b/src/aind_data_schema/core/subject.py index aedc6a9a7..db3aa0988 100644 --- a/src/aind_data_schema/core/subject.py +++ b/src/aind_data_schema/core/subject.py @@ -8,7 +8,7 @@ from aind_data_schema_models.organizations import Organization from aind_data_schema_models.pid_names import PIDName from aind_data_schema_models.species import Species -from pydantic import Field, field_validator +from pydantic import Field, SkipValidation, field_validator from pydantic_core.core_schema import ValidationInfo from aind_data_schema.base import AindCoreModel, AindModel @@ -89,7 +89,7 @@ class Subject(AindCoreModel): _DESCRIBED_BY_URL = AindCoreModel._DESCRIBED_BY_BASE_URL.default + "aind_data_schema/core/subject.py" describedBy: str = Field(default=_DESCRIBED_BY_URL, json_schema_extra={"const": _DESCRIBED_BY_URL}) - schema_version: Literal["1.0.0"] = Field(default="1.0.0") + schema_version: SkipValidation[Literal["1.0.0"]] = Field(default="1.0.0") subject_id: str = Field( ..., description="Unique identifier for the subject. If this is not a Allen LAS ID, indicate this in the Notes.", diff --git a/src/aind_data_schema/utils/schema_version_bump.py b/src/aind_data_schema/utils/schema_version_bump.py index ee2f0fe04..e3ee33fc3 100644 --- a/src/aind_data_schema/utils/schema_version_bump.py +++ b/src/aind_data_schema/utils/schema_version_bump.py @@ -101,8 +101,8 @@ def _get_updated_file(python_file_path: str, new_ver: str) -> list: with open(python_file_path, "rb") as f: file_lines = f.readlines() for line in file_lines: - if "schema_version: Literal[" in str(line): - new_line_str = f' schema_version: Literal["{new_ver}"] = Field("{new_ver}")\n' + if "schema_version: SkipValidation[Literal[" in str(line): + new_line_str = f' schema_version: SkipValidation[Literal["{new_ver}"]] = Field("{new_ver}")\n' new_line = new_line_str.encode() else: new_line = line diff --git a/tests/test_base.py b/tests/test_base.py index d122e5bfa..a95e5e161 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,13 +1,14 @@ """ tests for Subject """ +import json import unittest from datetime import datetime, timezone from pathlib import Path from unittest.mock import MagicMock, call, mock_open, patch -from pydantic import create_model +from pydantic import ValidationError, create_model -from aind_data_schema.base import AwareDatetimeWithDefault +from aind_data_schema.base import AindGeneric, AwareDatetimeWithDefault, is_dict_corrupt from aind_data_schema.core.subject import Subject @@ -55,6 +56,65 @@ def test_aware_datetime_with_setting(self): expected_json = '{"dt":"2020-10-10T01:02:03Z"}' self.assertEqual(expected_json, model_instance.model_dump_json()) + def test_is_dict_corrupt(self): + """Tests is_dict_corrupt method""" + good_contents = [ + {"a": 1, "b": {"c": 2, "d": 3}}, + {"a": 1, "b": {"c": 2, "d": 3}, "e": ["f", "g"]}, + {"a": 1, "b": {"c": 2, "d": 3}, "e": ["f.valid", "g"]}, + {"a": 1, "b": {"c": {"d": 2}, "e": 3}}, + {"a": 1, "b": [{"c": 2}, {"d": 3}], "e": 4}, + ] + bad_contents = [ + {"a.1": 1, "b": {"c": 2, "d": 3}}, + {"a": 1, "b": {"c": 2, "$d": 3}}, + {"a": 1, "b": {"c": {"d": 2}, "$e": 3}}, + {"a": 1, "b": {"c": 2, "d": 3, "e.csv": 4}}, + {"a": 1, "b": [{"c": 2}, {"d.csv": 3}], "e": 4}, + ] + invalid_types = [ + 
json.dumps({"a": 1, "b": {"c": 2, "d": 3}}), + [{"a": 1}, {"b": {"c": 2, "d": 3}}], + 1, + None, + ] + for contents in good_contents: + with self.subTest(contents=contents): + self.assertFalse(is_dict_corrupt(contents)) + for contents in bad_contents: + with self.subTest(contents=contents): + self.assertTrue(is_dict_corrupt(contents)) + for contents in invalid_types: + with self.subTest(contents=contents): + self.assertTrue(is_dict_corrupt(contents)) + + def test_aind_generic_constructor(self): + """Tests default constructor for AindGeneric""" + model = AindGeneric() + self.assertEqual("{}", model.model_dump_json()) + + params = {"foo": "bar"} + model = AindGeneric(**params) + self.assertEqual('{"foo":"bar"}', model.model_dump_json()) + + def test_aind_generic_validate_fieldnames(self): + """Tests that fieldnames are validated in AindGeneric""" + expected_error = ( + "1 validation error for AindGeneric\n" + " Value error, Field names cannot contain '.' or '$' " + ) + invalid_params = [ + {"$foo": "bar"}, + {"foo": {"foo.name": "bar"}}, + ] + for params in invalid_params: + with self.assertRaises(ValidationError) as e: + AindGeneric(**params) + self.assertIn(expected_error, repr(e.exception)) + with self.assertRaises(ValidationError) as e: + AindGeneric.model_validate(params) + self.assertIn(expected_error, repr(e.exception)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_bump_schema_versions.py b/tests/test_bump_schema_versions.py index 47709e88c..6401de7d3 100644 --- a/tests/test_bump_schema_versions.py +++ b/tests/test_bump_schema_versions.py @@ -77,18 +77,18 @@ def test_update_files(self, mock_write: MagicMock): handler._update_files({Subject: new_subject_version, Session: new_session_version}) expected_line_change0 = ( - f' schema_version: Literal["{new_subject_version}"] = Field("{new_subject_version}")\n' + f'schema_version: SkipValidation[Literal["{new_subject_version}"]] = Field("{new_subject_version}")' ) expected_line_change1 = ( - f' schema_version: Literal["{new_session_version}"] = Field("{new_session_version}")\n' + f'schema_version: SkipValidation[Literal["{new_session_version}"]] = Field("{new_session_version}")' ) mock_write_args0 = mock_write.mock_calls[0].args mock_write_args1 = mock_write.mock_calls[1].args + self.assertTrue(expected_line_change0 in str(mock_write_args0[0])) self.assertTrue("subject.py" in str(mock_write_args0[1])) - self.assertTrue(expected_line_change0.encode() in mock_write_args0[0]) + self.assertTrue(expected_line_change1 in str(mock_write_args1[0])) self.assertTrue("session.py" in str(mock_write_args1[1])) - self.assertTrue(expected_line_change1.encode() in mock_write_args1[0]) @patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_list_of_models_that_changed") @patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._update_files") diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 941a79523..faf578b7f 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -3,19 +3,21 @@ import json import re import unittest -from datetime import time +from datetime import datetime, time, timezone +from unittest.mock import MagicMock, call, patch +from aind_data_schema_models.modalities import Modality from aind_data_schema_models.organizations import Organization +from aind_data_schema_models.pid_names import PIDName from aind_data_schema_models.platforms import Platform -from aind_data_schema_models.modalities import Modality from pydantic import ValidationError from pydantic 
import __version__ as pyd_version from aind_data_schema.components.devices import MousePlatform from aind_data_schema.core.acquisition import Acquisition -from aind_data_schema.core.data_description import DataDescription +from aind_data_schema.core.data_description import DataDescription, Funding from aind_data_schema.core.instrument import Instrument -from aind_data_schema.core.metadata import Metadata, MetadataStatus +from aind_data_schema.core.metadata import ExternalPlatforms, Metadata, MetadataStatus, create_metadata_json from aind_data_schema.core.procedures import ( IontophoresisInjection, NanojectInjection, @@ -23,10 +25,10 @@ Surgery, ViralMaterial, ) -from aind_data_schema.core.processing import Processing +from aind_data_schema.core.processing import PipelineProcess, Processing from aind_data_schema.core.rig import Rig from aind_data_schema.core.session import Session -from aind_data_schema.core.subject import BreedingInfo, Sex, Species, Subject +from aind_data_schema.core.subject import BreedingInfo, Housing, Sex, Species, Subject PYD_VERSION = re.match(r"(\d+.\d+).\d+", pyd_version).group(1) @@ -34,6 +36,56 @@ class TestMetadata(unittest.TestCase): """Class to test Metadata model""" + @classmethod + def setUpClass(cls) -> None: + """Set up the test class.""" + subject = Subject( + species=Species.MUS_MUSCULUS, + subject_id="12345", + sex=Sex.MALE, + date_of_birth=datetime(2022, 11, 22, 8, 43, 00, tzinfo=timezone.utc).date(), + source=Organization.AI, + breeding_info=BreedingInfo( + breeding_group="Emx1-IRES-Cre(ND)", + maternal_id="546543", + maternal_genotype="Emx1-IRES-Cre/wt; Camk2a-tTa/Camk2a-tTA", + paternal_id="232323", + paternal_genotype="Ai93(TITL-GCaMP6f)/wt", + ), + genotype="Emx1-IRES-Cre/wt;Camk2a-tTA/wt;Ai93(TITL-GCaMP6f)/wt", + housing=Housing(home_cage_enrichment=["Running wheel"], cage_id="123"), + background_strain="C57BL/6J", + ) + dd = DataDescription( + label="test_data", + modality=[Modality.ECEPHYS], + platform=Platform.ECEPHYS, + subject_id="123456", + data_level="raw", + creation_time=datetime(2022, 11, 22, 8, 43, 00, tzinfo=timezone.utc), + institution=Organization.AIND, + funding_source=[Funding(funder=Organization.NINDS, grant_number="grant001")], + investigators=[PIDName(name="Jane Smith")], + ) + procedures = Procedures( + subject_id="12345", + ) + processing = Processing( + processing_pipeline=PipelineProcess(processor_full_name="Processor", data_processes=[]), + ) + + cls.sample_name = "ecephys_655019_2023-04-03_18-17-09" + cls.sample_location = "s3://bucket/ecephys_655019_2023-04-03_18-17-09" + cls.subject = subject + cls.dd = dd + cls.procedures = procedures + cls.processing = processing + + cls.subject_json = json.loads(subject.model_dump_json()) + cls.dd_json = json.loads(dd.model_dump_json()) + cls.procedures_json = json.loads(procedures.model_dump_json()) + cls.processing_json = json.loads(processing.model_dump_json()) + def test_valid_subject_info(self): """Tests that the record is marked as VALID if a valid subject model is present.""" @@ -326,6 +378,171 @@ def test_validate_rig_session_compatibility(self): str(context.exception), ) + def test_validate_old_schema_version(self): + """Tests that old schema versions are ignored during validation + """ + m = Metadata.model_construct( + name="name", + location="location", + id="1", + ) + + m_dict = m.model_dump() + + m_dict["schema_version"] = "0.0.0" + m_dict.pop("id") + + m2 = Metadata(**m_dict) + + self.assertIsNotNone(m2) + + def test_create_from_core_jsons(self): + """Tests metadata 
json can be created with valid inputs""" + core_jsons = { + "subject": self.subject_json, + "data_description": None, + "procedures": self.procedures_json, + "session": None, + "rig": None, + "processing": self.processing_json, + "acquisition": None, + "instrument": None, + "quality_control": None, + } + expected_md = Metadata( + name=self.sample_name, + location=self.sample_location, + subject=self.subject, + procedures=self.procedures, + processing=self.processing, + ) + expected_result = json.loads(expected_md.model_dump_json(by_alias=True)) + # there are some userwarnings when creating Subject from json + with self.assertWarns(UserWarning): + result = create_metadata_json( + name=self.sample_name, + location=self.sample_location, + core_jsons=core_jsons, + ) + # check that metadata was created with expected values + self.assertEqual(self.sample_name, result["name"]) + self.assertEqual(self.sample_location, result["location"]) + self.assertEqual(MetadataStatus.VALID.value, result["metadata_status"]) + self.assertEqual(self.subject_json, result["subject"]) + self.assertEqual(self.procedures_json, result["procedures"]) + self.assertEqual(self.processing_json, result["processing"]) + self.assertIsNone(result["acquisition"]) + # also check the other fields + # small hack to mock the _id, created, and last_modified fields + expected_result["_id"] = result["_id"] + expected_result["created"] = result["created"] + expected_result["last_modified"] = result["last_modified"] + self.assertDictEqual(expected_result, result) + + def test_create_from_core_jsons_optional_overwrite(self): + """Tests metadata json creation with created and external links""" + created = datetime(2024, 10, 31, 12, 0, 0, tzinfo=timezone.utc) + external_links = { + ExternalPlatforms.CODEOCEAN.value: ["123", "abc"], + } + # there are some userwarnings when creating from json + with self.assertWarns(UserWarning): + result = create_metadata_json( + name=self.sample_name, + location=self.sample_location, + core_jsons={ + "subject": self.subject_json, + }, + optional_created=created, + optional_external_links=external_links, + ) + self.assertEqual(self.sample_name, result["name"]) + self.assertEqual(self.sample_location, result["location"]) + self.assertEqual("2024-10-31T12:00:00Z", result["created"]) + self.assertEqual(external_links, result["external_links"]) + + @patch("logging.warning") + def test_create_from_core_jsons_invalid(self, mock_warning: MagicMock): + """Tests that metadata json is marked invalid if there are errors""" + # data_description triggers cross-validation of other fields to fail + core_jsons = { + "subject": self.subject_json, + "data_description": self.dd_json, + "procedures": self.procedures_json, + "session": None, + "rig": None, + "processing": self.processing_json, + "acquisition": None, + "instrument": None, + "quality_control": None, + } + # there are some userwarnings when creating Subject from json + with self.assertWarns(UserWarning): + result = create_metadata_json( + name=self.sample_name, + location=self.sample_location, + core_jsons=core_jsons, + ) + # check that metadata was still created + self.assertEqual(self.sample_name, result["name"]) + self.assertEqual(self.sample_location, result["location"]) + self.assertEqual(self.subject_json, result["subject"]) + self.assertEqual(self.dd_json, result["data_description"]) + self.assertEqual(self.procedures_json, result["procedures"]) + self.assertEqual(self.processing_json, result["processing"]) + self.assertIsNone(result["acquisition"]) + # 
check that metadata was marked as invalid + self.assertEqual(MetadataStatus.INVALID.value, result["metadata_status"]) + mock_warning.assert_called_once() + self.assertIn("Issue with metadata construction!", mock_warning.call_args_list[0].args[0]) + + @patch("logging.warning") + @patch("aind_data_schema.core.metadata.is_dict_corrupt") + def test_create_from_core_jsons_corrupt( + self, + mock_is_dict_corrupt: MagicMock, + mock_warning: MagicMock + ): + """Tests metadata json creation ignores corrupt core jsons""" + # mock corrupt procedures and processing + mock_is_dict_corrupt.side_effect = lambda x: ( + x == self.procedures_json or x == self.processing_json + ) + core_jsons = { + "subject": self.subject_json, + "data_description": None, + "procedures": self.procedures_json, + "session": None, + "rig": None, + "processing": self.processing_json, + "acquisition": None, + "instrument": None, + "quality_control": None, + } + # there are some userwarnings when creating Subject from json + with self.assertWarns(UserWarning): + result = create_metadata_json( + name=self.sample_name, + location=self.sample_location, + core_jsons=core_jsons, + ) + # check that metadata was still created + self.assertEqual(self.sample_name, result["name"]) + self.assertEqual(self.sample_location, result["location"]) + self.assertEqual(self.subject_json, result["subject"]) + self.assertIsNone(result["acquisition"]) + self.assertEqual(MetadataStatus.VALID.value, result["metadata_status"]) + # check that corrupt core jsons were ignored + self.assertIsNone(result["procedures"]) + self.assertIsNone(result["processing"]) + mock_warning.assert_has_calls( + [ + call("Provided processing is corrupt! It will be ignored."), + call("Provided procedures is corrupt! It will be ignored."), + ], + any_order=True, + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rig.py b/tests/test_rig.py index d1d86f454..32b243125 100644 --- a/tests/test_rig.py +++ b/tests/test_rig.py @@ -1,11 +1,13 @@ """ test Rig """ import unittest +import json from datetime import date, datetime from aind_data_schema_models.modalities import Modality from aind_data_schema_models.organizations import Organization from pydantic import ValidationError +from pydantic_core import PydanticSerializationError from aind_data_schema.components.devices import ( Calibration, @@ -821,6 +823,30 @@ def test_rig_id_validator(self): calibrations=[calibration], ) + def test_serialize_modalities(self): + """Tests that modalities serializer can handle different types""" + expected_modalities = [{"name": "Extracellular electrophysiology", "abbreviation": "ecephys"}] + # Case 1: Modality is a class instance + rig_instance_modality = Rig.model_construct( + modalities=[Modality.ECEPHYS] # Example with a valid Modality instance + ) + rig_json = rig_instance_modality.model_dump_json() + rig_data = json.loads(rig_json) + self.assertEqual(rig_data["modalities"], expected_modalities) + + # Case 2: Modality is a dictionary when Rig is constructed from JSON + rig_dict_modality = Rig.model_construct(**rig_data) + rig_dict_json = rig_dict_modality.model_dump_json() + rig_dict_data = json.loads(rig_dict_json) + self.assertEqual(rig_dict_data["modalities"], expected_modalities) + + # Case 3: Modality is an unknown type + with self.assertRaises(PydanticSerializationError) as context: + rig_unknown_modality = Rig.model_construct(modalities={"UnknownModality"}) + + rig_unknown_modality.model_dump_json() + self.assertIn("Error calling function `serialize_modalities`", 
str(context.exception)) + if __name__ == "__main__": unittest.main()
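The multi-asset workflow described in the `quality_control.md` changes above can be combined with the new `tags` field on `QCEvaluation`. The sketch below is illustrative only: `tags` and `evaluated_assets` come directly from this diff, while the `QCStatus`/`Status`/`Stage` names and the asset names are assumed from the existing `aind_data_schema.core.quality_control` module rather than introduced by this change.

```python
from datetime import datetime, timezone

from aind_data_schema_models.modalities import Modality

from aind_data_schema.core.quality_control import (
    QCEvaluation,
    QCMetric,
    QCStatus,
    QualityControl,
    Stage,
    Status,
)

t = datetime(2024, 10, 31, tzinfo=timezone.utc)

# One metric computed by comparing the same chronic probe across two sessions;
# the asset names below are placeholders.
drift_metric = QCMetric(
    name="Cross-session drift (um)",
    value=12.5,
    status_history=[QCStatus(evaluator="Automated", status=Status.PASS, timestamp=t)],
    evaluated_assets=[
        "ecephys_655019_2023-04-03_18-17-09",
        "ecephys_655019_2023-04-04_18-12-03",
    ],
)

evaluation = QCEvaluation(
    name="Chronic probe stability",
    modality=Modality.ECEPHYS,
    stage=Stage.MULTI_ASSET,  # marks the evaluation as spanning multiple assets
    metrics=[drift_metric],
    tags=["chronic-probe", "cross-session"],  # new optional grouping field
)

qc = QualityControl(evaluations=[evaluation])
print(qc.model_dump_json(indent=2))
```

Because the evaluation is staged as `MULTI_ASSET`, every metric carries an `evaluated_assets` list, matching the guidance in the documentation above.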
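The `session.rst` guidance above recommends documenting stimulus/behavior parameters in the Software class's `parameters` field with unambiguous names and separate unit fields. A minimal sketch, assuming `Software` is importable from `aind_data_schema.components.devices`; the script name, URL, and parameter names are made up for illustration.

```python
from aind_data_schema.components.devices import Software

# Hypothetical task script; only the layout of `parameters` is the point here.
script = Software(
    name="my_behavior_task",
    version="1.2.0",
    url="https://github.com/example/my_behavior_task",
    parameters={
        # unambiguous names, with units carried in a separate *_unit field
        "trial_duration": 2.5,
        "trial_duration_unit": "second",
        "reward_volume": 3.0,
        "reward_volume_unit": "microliter",
    },
)

# With the AindGeneric validator added in base.py, keys such as "trial.duration"
# or "$reward" would raise a ValidationError here.
print(script.model_dump_json())
```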
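`create_metadata_json` (added to `metadata.py` above) assembles a Metadata record from pre-serialized core files, dropping corrupt dictionaries and downgrading the record's status rather than raising on validation errors. A minimal usage sketch; the local `subject.json` path is hypothetical.

```python
import json

from aind_data_schema.core.metadata import create_metadata_json

# Hypothetical local copy of a core schema file
with open("subject.json", "r") as f:
    subject_json = json.load(f)

metadata_json = create_metadata_json(
    name="ecephys_655019_2023-04-03_18-17-09",
    location="s3://bucket/ecephys_655019_2023-04-03_18-17-09",
    core_jsons={"subject": subject_json},
)

# Corrupt core files (keys containing "." or "$") are dropped with a warning,
# and any validation error marks the record invalid instead of raising.
print(metadata_json["metadata_status"])
```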
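The repeated `SkipValidation[Literal[...]]` change relaxes the `schema_version` check so that records written against an older schema version can still be re-validated. A short sketch mirroring `test_validate_old_schema_version` above.

```python
from aind_data_schema.core.metadata import Metadata

m = Metadata.model_construct(name="name", location="location", id="1")
m_dict = m.model_dump()
m_dict["schema_version"] = "0.0.0"  # stale version string from an old record
m_dict.pop("id")

# Previously the Literal annotation rejected any schema_version other than the
# current one; with SkipValidation this round-trip no longer raises.
m2 = Metadata(**m_dict)
assert m2 is not None
```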