From a4c2d05bf11f8a700fb0412c7df37ff5edfa9e6f Mon Sep 17 00:00:00 2001 From: Dan Birman Date: Mon, 6 Jan 2025 13:58:41 -0800 Subject: [PATCH 1/7] fix: changing how schema version bumps are discovered, to avoid re-bumping a major/minor change --- .../utils/schema_version_bump.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/aind_data_schema/utils/schema_version_bump.py b/src/aind_data_schema/utils/schema_version_bump.py index 36e78e7f2..cdbfbe728 100644 --- a/src/aind_data_schema/utils/schema_version_bump.py +++ b/src/aind_data_schema/utils/schema_version_bump.py @@ -54,6 +54,7 @@ def _get_list_of_models_that_changed(self) -> List[AindCoreModel]: with open(main_branch_schema_path, "r") as f: main_branch_schema_contents = json.load(f) diff = dictdiffer.diff(main_branch_schema_contents, core_model_json) + print(f"Diff for {core_model.__name__}: {list(diff)}") if len(list(diff)) > 0: schemas_that_need_updating.append(core_model) return schemas_that_need_updating @@ -74,11 +75,24 @@ def _get_incremented_versions_map(models_that_changed: List[AindCoreModel]) -> D """ version_bump_map = {} - # TODO: Use commit message to determine version number to bump? for model in models_that_changed: + # We only want to bump the patch if the major or minor versions didn't already change + # Load the current version of the model + default_filename = model.default_filename() + if default_filename.find(".") != -1: + schema_filename = default_filename[: default_filename.find(".")] + "_schema.json" + main_branch_schema_path = Path(OLD_SCHEMA_DIR) / schema_filename + if main_branch_schema_path.exists(): + with open(main_branch_schema_path, "r") as f: + main_branch_schema_contents = json.load(f) + orig_ver = semver.Version.parse(main_branch_schema_contents.get("schema_version", model.model_fields["schema_version"].default)) + old_v = semver.Version.parse(model.model_fields["schema_version"].default) - new_v = old_v.bump_patch() - version_bump_map[model] = str(new_v) + if orig_ver.major == old_v.major and orig_ver.minor == old_v.minor: + new_ver = old_v.bump_patch() + else: + new_ver = old_v + version_bump_map[model] = str(new_ver) return version_bump_map @staticmethod From dd47d41740799c5d300f1391628adc2edcdda38a Mon Sep 17 00:00:00 2001 From: Dan Birman Date: Tue, 7 Jan 2025 11:59:23 -0800 Subject: [PATCH 2/7] fix: refactor and fix schema bump code to pull original version properly --- .../utils/schema_version_bump.py | 62 ++++++++++++------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/src/aind_data_schema/utils/schema_version_bump.py b/src/aind_data_schema/utils/schema_version_bump.py index cdbfbe728..b9e1ba5a8 100644 --- a/src/aind_data_schema/utils/schema_version_bump.py +++ b/src/aind_data_schema/utils/schema_version_bump.py @@ -34,6 +34,28 @@ def __init__(self, commit_message: str = "", json_schemas_location: Path = Path( self.commit_message = commit_message self.json_schemas_location = json_schemas_location + def _get_schema_json(self, model: AindCoreModel) -> dict: + """ + Get the json schema of a model + Parameters + ---------- + model : AindCoreModel + The model to get the json schema of + + Returns + ------- + dict + The json schema of the model + """ + default_filename = model.default_filename() + if default_filename.find(".") != -1: + schema_filename = default_filename[: default_filename.find(".")] + "_schema.json" + main_branch_schema_path = self.json_schemas_location / schema_filename + if main_branch_schema_path.exists(): + with open(main_branch_schema_path, "r") as f: + main_branch_schema_contents = json.load(f) + return main_branch_schema_contents + def _get_list_of_models_that_changed(self) -> List[AindCoreModel]: """ Get a list of core models that have been updated by comparing the json @@ -46,21 +68,18 @@ def _get_list_of_models_that_changed(self) -> List[AindCoreModel]: schemas_that_need_updating = [] for core_model in SchemaWriter.get_schemas(): core_model_json = core_model.model_json_schema() - default_filename = core_model.default_filename() - if default_filename.find(".") != -1: - schema_filename = default_filename[: default_filename.find(".")] + "_schema.json" - main_branch_schema_path = self.json_schemas_location / schema_filename - if main_branch_schema_path.exists(): - with open(main_branch_schema_path, "r") as f: - main_branch_schema_contents = json.load(f) - diff = dictdiffer.diff(main_branch_schema_contents, core_model_json) - print(f"Diff for {core_model.__name__}: {list(diff)}") - if len(list(diff)) > 0: - schemas_that_need_updating.append(core_model) + original_schema = self._get_schema_json(core_model) + + diff_list = list(dictdiffer.diff(original_schema, core_model_json)) + + print(f"Diff for {core_model.__name__}: {diff_list}") + if len(diff_list) > 0: + schemas_that_need_updating.append(core_model) + + print(f"Schemas that need updating: {[model.__name__ for model in schemas_that_need_updating]}") return schemas_that_need_updating - @staticmethod - def _get_incremented_versions_map(models_that_changed: List[AindCoreModel]) -> Dict[AindCoreModel, str]: + def _get_incremented_versions_map(self, models_that_changed: List[AindCoreModel]) -> Dict[AindCoreModel, str]: """ Parameters @@ -78,19 +97,19 @@ def _get_incremented_versions_map(models_that_changed: List[AindCoreModel]) -> D for model in models_that_changed: # We only want to bump the patch if the major or minor versions didn't already change # Load the current version of the model - default_filename = model.default_filename() - if default_filename.find(".") != -1: - schema_filename = default_filename[: default_filename.find(".")] + "_schema.json" - main_branch_schema_path = Path(OLD_SCHEMA_DIR) / schema_filename - if main_branch_schema_path.exists(): - with open(main_branch_schema_path, "r") as f: - main_branch_schema_contents = json.load(f) - orig_ver = semver.Version.parse(main_branch_schema_contents.get("schema_version", model.model_fields["schema_version"].default)) + original_schema = self._get_schema_json(model) + schema_version = original_schema.get("properties", {}).get("schema_version").get("default") + if schema_version: + orig_ver = semver.Version.parse(schema_version) + else: + raise Exception("Schema version not found in the schema file") old_v = semver.Version.parse(model.model_fields["schema_version"].default) if orig_ver.major == old_v.major and orig_ver.minor == old_v.minor: + print(f"Updating {model.__name__} from {old_v} to {old_v.bump_patch()}") new_ver = old_v.bump_patch() else: + print(f"Skipping {model.__name__}, major or minor version already updated") new_ver = old_v version_bump_map[model] = str(new_ver) return version_bump_map @@ -112,6 +131,7 @@ def _get_updated_file(python_file_path: str, new_ver: str) -> list: """ new_file_contents = [] + print(f"Updating {python_file_path} to version {new_ver}") with open(python_file_path, "rb") as f: file_lines = f.readlines() for line in file_lines: From 671d4bd52b00e9135da1fba314ed28518e8fb109 Mon Sep 17 00:00:00 2001 From: Dan Birman Date: Tue, 7 Jan 2025 13:20:23 -0800 Subject: [PATCH 3/7] fix: raise errors appropriately and skip properly --- src/aind_data_schema/utils/schema_version_bump.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/aind_data_schema/utils/schema_version_bump.py b/src/aind_data_schema/utils/schema_version_bump.py index b9e1ba5a8..b69c40475 100644 --- a/src/aind_data_schema/utils/schema_version_bump.py +++ b/src/aind_data_schema/utils/schema_version_bump.py @@ -54,6 +54,8 @@ def _get_schema_json(self, model: AindCoreModel) -> dict: if main_branch_schema_path.exists(): with open(main_branch_schema_path, "r") as f: main_branch_schema_contents = json.load(f) + else: + raise FileNotFoundError(f"Schema file not found: {main_branch_schema_path}") return main_branch_schema_contents def _get_list_of_models_that_changed(self) -> List[AindCoreModel]: @@ -108,10 +110,10 @@ def _get_incremented_versions_map(self, models_that_changed: List[AindCoreModel] if orig_ver.major == old_v.major and orig_ver.minor == old_v.minor: print(f"Updating {model.__name__} from {old_v} to {old_v.bump_patch()}") new_ver = old_v.bump_patch() + version_bump_map[model] = str(new_ver) else: print(f"Skipping {model.__name__}, major or minor version already updated") new_ver = old_v - version_bump_map[model] = str(new_ver) return version_bump_map @staticmethod From 44dd304654ce809d7096886b408b7e7fa03ccc50 Mon Sep 17 00:00:00 2001 From: Dan Birman Date: Tue, 7 Jan 2025 13:20:33 -0800 Subject: [PATCH 4/7] test: coverage on new get_schema_json function --- tests/test_bump_schema_versions.py | 65 +++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 5 deletions(-) diff --git a/tests/test_bump_schema_versions.py b/tests/test_bump_schema_versions.py index fec955bb1..9f41b260a 100644 --- a/tests/test_bump_schema_versions.py +++ b/tests/test_bump_schema_versions.py @@ -8,6 +8,7 @@ from aind_data_schema.core.session import Session from aind_data_schema.core.subject import Subject +from aind_data_schema.core.rig import Rig from aind_data_schema.utils.json_writer import SchemaWriter from aind_data_schema.utils.schema_version_bump import SchemaVersionHandler @@ -39,7 +40,8 @@ def test_get_list_of_models_that_changed(self, mock_exists: MagicMock, mock_json self.assertTrue(Session in models_that_changed) self.assertTrue(Subject in models_that_changed) - def test_get_list_of_incremented_versions(self): + @patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json") + def test_get_list_of_incremented_versions(self, mock_get_schema: MagicMock): """Tests get_list_of_incremented_versions method""" handler = SchemaVersionHandler(json_schemas_location=Path(".")) @@ -47,8 +49,15 @@ def test_get_list_of_incremented_versions(self): new_subject_version = str(Version.parse(old_subject_version).bump_patch()) old_session_version = Session.model_fields["schema_version"].default new_session_version = str(Version.parse(old_session_version).bump_patch()) - # Pycharm raises a warning about types that we can ignore - # noinspection PyTypeChecker + + def side_effect(model): + if model == Subject: + return {"properties": {"schema_version": {"default": old_subject_version}}} + elif model == Session: + return {"properties": {"schema_version": {"default": old_session_version}}} + + mock_get_schema.side_effect = side_effect + model_map = handler._get_incremented_versions_map([Subject, Session]) expected_model_map = {Subject: new_subject_version, Session: new_session_version} self.assertEqual(expected_model_map, model_map) @@ -64,6 +73,39 @@ def test_write_new_file(self, mock_open: MagicMock): mock_open.assert_called_once_with(file_path, "wb") mock_open.return_value.__enter__().write.assert_has_calls([call(file_contents[0]), call(file_contents[1])]) + """Tests get_schema_json method""" + @patch("builtins.open") + @patch("json.load") + @patch("pathlib.Path.exists") + def test_get_schema_json(self, mock_exists: MagicMock, mock_json_load: MagicMock, mock_open: MagicMock): + """Tests _get_schema_json method""" + handler = SchemaVersionHandler(json_schemas_location=Path(".")) + + mock_exists.return_value = True + mock_json_load.return_value = {"properties": {"schema_version": {"default": "1.0.0"}}} + + model = MagicMock() + model.default_filename.return_value = "test_model.json" + + schema_json = handler._get_schema_json(model) + self.assertEqual(schema_json, {"properties": {"schema_version": {"default": "1.0.0"}}}) + + mock_open.assert_called_once_with(Path("./test_model_schema.json"), "r") + mock_json_load.assert_called_once() + + @patch("pathlib.Path.exists") + def test_get_schema_json_file_not_found(self, mock_exists: MagicMock): + """Tests _get_schema_json method when file is not found""" + handler = SchemaVersionHandler(json_schemas_location=Path(".")) + + mock_exists.return_value = False + + model = MagicMock() + model.default_filename.return_value = "test_model.json" + + with self.assertRaises(FileNotFoundError): + handler._get_schema_json(model) + @patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._write_new_file") def test_update_files(self, mock_write: MagicMock): """Tests the update_files method""" @@ -72,9 +114,11 @@ def test_update_files(self, mock_write: MagicMock): new_subject_version = str(Version.parse(old_subject_version).bump_patch()) old_session_version = Session.model_fields["schema_version"].default new_session_version = str(Version.parse(old_session_version).bump_patch()) + old_rig_version = Rig.model_fields["schema_version"].default + new_rig_version = str(Version.parse(old_rig_version).bump_minor()) # Pycharm raises a warning about types that we can ignore # noinspection PyTypeChecker - handler._update_files({Subject: new_subject_version, Session: new_session_version}) + handler._update_files({Subject: new_subject_version, Session: new_session_version, Rig: new_rig_version}) expected_line_change0 = ( f'schema_version: SkipValidation[Literal["{new_subject_version}"]] = Field(default="{new_subject_version}")' @@ -90,16 +134,27 @@ def test_update_files(self, mock_write: MagicMock): self.assertTrue(expected_line_change1 in str(mock_write_args1[0])) self.assertTrue("session.py" in str(mock_write_args1[1])) + @patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json") @patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_list_of_models_that_changed") @patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._update_files") - def test_run_job(self, mock_update_files: MagicMock, mock_get_list_of_models: MagicMock): + def test_run_job(self, mock_update_files: MagicMock, mock_get_list_of_models: MagicMock, mock_get_schema: MagicMock): """Tests run_job method""" old_subject_version = Subject.model_fields["schema_version"].default new_subject_version = str(Version.parse(old_subject_version).bump_patch()) old_session_version = Session.model_fields["schema_version"].default new_session_version = str(Version.parse(old_session_version).bump_patch()) + mock_get_list_of_models.return_value = [Subject, Session] + + def side_effect(model): + if model == Subject: + return {"properties": {"schema_version": {"default": old_subject_version}}} + elif model == Session: + return {"properties": {"schema_version": {"default": old_session_version}}} + + mock_get_schema.side_effect = side_effect + handler = SchemaVersionHandler(json_schemas_location=Path(".")) handler.run_job() mock_update_files.assert_called_once_with({Subject: new_subject_version, Session: new_session_version}) From 6f936fcf6e680f333b5ebebc8b47b5b48c7bb1d2 Mon Sep 17 00:00:00 2001 From: Dan Birman Date: Tue, 7 Jan 2025 13:23:05 -0800 Subject: [PATCH 5/7] chore: lint --- src/aind_data_schema/core/procedures.py | 7 +------ tests/test_bump_schema_versions.py | 5 ++++- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/aind_data_schema/core/procedures.py b/src/aind_data_schema/core/procedures.py index 42fec7119..b35cc0142 100644 --- a/src/aind_data_schema/core/procedures.py +++ b/src/aind_data_schema/core/procedures.py @@ -714,12 +714,7 @@ class Procedures(AindCoreModel): ) subject_procedures: List[ Annotated[ - Union[ - Surgery, - TrainingProtocol, - WaterRestriction, - OtherSubjectProcedure - ], + Union[Surgery, TrainingProtocol, WaterRestriction, OtherSubjectProcedure], Field(discriminator="procedure_type"), ] ] = Field(default=[], title="Subject Procedures") diff --git a/tests/test_bump_schema_versions.py b/tests/test_bump_schema_versions.py index 9f41b260a..59ee50cf8 100644 --- a/tests/test_bump_schema_versions.py +++ b/tests/test_bump_schema_versions.py @@ -74,6 +74,7 @@ def test_write_new_file(self, mock_open: MagicMock): mock_open.return_value.__enter__().write.assert_has_calls([call(file_contents[0]), call(file_contents[1])]) """Tests get_schema_json method""" + @patch("builtins.open") @patch("json.load") @patch("pathlib.Path.exists") @@ -137,7 +138,9 @@ def test_update_files(self, mock_write: MagicMock): @patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json") @patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_list_of_models_that_changed") @patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._update_files") - def test_run_job(self, mock_update_files: MagicMock, mock_get_list_of_models: MagicMock, mock_get_schema: MagicMock): + def test_run_job( + self, mock_update_files: MagicMock, mock_get_list_of_models: MagicMock, mock_get_schema: MagicMock + ): """Tests run_job method""" old_subject_version = Subject.model_fields["schema_version"].default From 0ff23f6e46e52d1a541172e898ed43f85f4ed43c Mon Sep 17 00:00:00 2001 From: Dan Birman Date: Tue, 7 Jan 2025 13:25:14 -0800 Subject: [PATCH 6/7] chore: docstrings --- tests/test_bump_schema_versions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_bump_schema_versions.py b/tests/test_bump_schema_versions.py index 59ee50cf8..13d3c19ee 100644 --- a/tests/test_bump_schema_versions.py +++ b/tests/test_bump_schema_versions.py @@ -51,6 +51,7 @@ def test_get_list_of_incremented_versions(self, mock_get_schema: MagicMock): new_session_version = str(Version.parse(old_session_version).bump_patch()) def side_effect(model): + """Side effect for mock_get_schema""" if model == Subject: return {"properties": {"schema_version": {"default": old_subject_version}}} elif model == Session: @@ -151,6 +152,7 @@ def test_run_job( mock_get_list_of_models.return_value = [Subject, Session] def side_effect(model): + """Return values for get_schema_json""" if model == Subject: return {"properties": {"schema_version": {"default": old_subject_version}}} elif model == Session: From ad64ec708f53a113672188b29796f32f8dd1cb94 Mon Sep 17 00:00:00 2001 From: Dan Birman Date: Tue, 7 Jan 2025 13:37:59 -0800 Subject: [PATCH 7/7] tests: coverage on missing lines --- .../utils/schema_version_bump.py | 4 ++-- tests/test_bump_schema_versions.py | 22 +++++++++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/aind_data_schema/utils/schema_version_bump.py b/src/aind_data_schema/utils/schema_version_bump.py index b69c40475..201ff13e9 100644 --- a/src/aind_data_schema/utils/schema_version_bump.py +++ b/src/aind_data_schema/utils/schema_version_bump.py @@ -100,11 +100,11 @@ def _get_incremented_versions_map(self, models_that_changed: List[AindCoreModel] # We only want to bump the patch if the major or minor versions didn't already change # Load the current version of the model original_schema = self._get_schema_json(model) - schema_version = original_schema.get("properties", {}).get("schema_version").get("default") + schema_version = original_schema.get("properties", {}).get("schema_version", {}).get("default") if schema_version: orig_ver = semver.Version.parse(schema_version) else: - raise Exception("Schema version not found in the schema file") + raise ValueError("Schema version not found in the schema file") old_v = semver.Version.parse(model.model_fields["schema_version"].default) if orig_ver.major == old_v.major and orig_ver.minor == old_v.minor: diff --git a/tests/test_bump_schema_versions.py b/tests/test_bump_schema_versions.py index 13d3c19ee..8f1c2e542 100644 --- a/tests/test_bump_schema_versions.py +++ b/tests/test_bump_schema_versions.py @@ -74,8 +74,6 @@ def test_write_new_file(self, mock_open: MagicMock): mock_open.assert_called_once_with(file_path, "wb") mock_open.return_value.__enter__().write.assert_has_calls([call(file_contents[0]), call(file_contents[1])]) - """Tests get_schema_json method""" - @patch("builtins.open") @patch("json.load") @patch("pathlib.Path.exists") @@ -108,6 +106,26 @@ def test_get_schema_json_file_not_found(self, mock_exists: MagicMock): with self.assertRaises(FileNotFoundError): handler._get_schema_json(model) + @patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json") + def test_get_incremented_versions_map_exception(self, mock_get_schema: MagicMock): + """Test that missing schema_version field raises an error""" + handler = SchemaVersionHandler(json_schemas_location=Path(".")) + + mock_get_schema.return_value = {} + + with self.assertRaises(ValueError): + handler._get_incremented_versions_map([Subject]) + + @patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json") + def test_get_incremented_versions_map_skip(self, mock_get_schema: MagicMock): + """Test that missing schema_version field raises an error""" + handler = SchemaVersionHandler(json_schemas_location=Path(".")) + + mock_get_schema.return_value = {"properties": {"schema_version": {"default": "0.0.0"}}} + + empty_map = handler._get_incremented_versions_map([Subject]) + self.assertEqual(empty_map, {}) + @patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._write_new_file") def test_update_files(self, mock_write: MagicMock): """Tests the update_files method"""