Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat fix schema bump #1215

Merged
merged 10 commits into from
Jan 23, 2025
7 changes: 1 addition & 6 deletions src/aind_data_schema/core/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,12 +714,7 @@ class Procedures(AindCoreModel):
)
subject_procedures: List[
Annotated[
Union[
Surgery,
TrainingProtocol,
WaterRestriction,
OtherSubjectProcedure
],
Union[Surgery, TrainingProtocol, WaterRestriction, OtherSubjectProcedure],
Field(discriminator="procedure_type"),
]
] = Field(default=[], title="Subject Procedures")
Expand Down
66 changes: 51 additions & 15 deletions src/aind_data_schema/utils/schema_version_bump.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,30 @@ def __init__(self, commit_message: str = "", json_schemas_location: Path = Path(
self.commit_message = commit_message
self.json_schemas_location = json_schemas_location

def _get_schema_json(self, model: AindCoreModel) -> dict:
"""
Get the json schema of a model
Parameters
----------
model : AindCoreModel
The model to get the json schema of

Returns
-------
dict
The json schema of the model
"""
default_filename = model.default_filename()
if default_filename.find(".") != -1:
schema_filename = default_filename[: default_filename.find(".")] + "_schema.json"
main_branch_schema_path = self.json_schemas_location / schema_filename
if main_branch_schema_path.exists():
with open(main_branch_schema_path, "r") as f:
main_branch_schema_contents = json.load(f)
else:
raise FileNotFoundError(f"Schema file not found: {main_branch_schema_path}")
return main_branch_schema_contents

def _get_list_of_models_that_changed(self) -> List[AindCoreModel]:
"""
Get a list of core models that have been updated by comparing the json
Expand All @@ -46,20 +70,18 @@ def _get_list_of_models_that_changed(self) -> List[AindCoreModel]:
schemas_that_need_updating = []
for core_model in SchemaWriter.get_schemas():
core_model_json = core_model.model_json_schema()
default_filename = core_model.default_filename()
if default_filename.find(".") != -1:
schema_filename = default_filename[: default_filename.find(".")] + "_schema.json"
main_branch_schema_path = self.json_schemas_location / schema_filename
if main_branch_schema_path.exists():
with open(main_branch_schema_path, "r") as f:
main_branch_schema_contents = json.load(f)
diff = dictdiffer.diff(main_branch_schema_contents, core_model_json)
if len(list(diff)) > 0:
schemas_that_need_updating.append(core_model)
original_schema = self._get_schema_json(core_model)

diff_list = list(dictdiffer.diff(original_schema, core_model_json))

print(f"Diff for {core_model.__name__}: {diff_list}")
if len(diff_list) > 0:
schemas_that_need_updating.append(core_model)

print(f"Schemas that need updating: {[model.__name__ for model in schemas_that_need_updating]}")
return schemas_that_need_updating

@staticmethod
def _get_incremented_versions_map(models_that_changed: List[AindCoreModel]) -> Dict[AindCoreModel, str]:
def _get_incremented_versions_map(self, models_that_changed: List[AindCoreModel]) -> Dict[AindCoreModel, str]:
"""

Parameters
Expand All @@ -74,11 +96,24 @@ def _get_incremented_versions_map(models_that_changed: List[AindCoreModel]) -> D

"""
version_bump_map = {}
# TODO: Use commit message to determine version number to bump?
for model in models_that_changed:
# We only want to bump the patch if the major or minor versions didn't already change
# Load the current version of the model
original_schema = self._get_schema_json(model)
schema_version = original_schema.get("properties", {}).get("schema_version", {}).get("default")
if schema_version:
orig_ver = semver.Version.parse(schema_version)
else:
raise ValueError("Schema version not found in the schema file")

old_v = semver.Version.parse(model.model_fields["schema_version"].default)
new_v = old_v.bump_patch()
version_bump_map[model] = str(new_v)
if orig_ver.major == old_v.major and orig_ver.minor == old_v.minor:
print(f"Updating {model.__name__} from {old_v} to {old_v.bump_patch()}")
new_ver = old_v.bump_patch()
version_bump_map[model] = str(new_ver)
else:
print(f"Skipping {model.__name__}, major or minor version already updated")
new_ver = old_v
return version_bump_map

@staticmethod
Expand All @@ -98,6 +133,7 @@ def _get_updated_file(python_file_path: str, new_ver: str) -> list:

"""
new_file_contents = []
print(f"Updating {python_file_path} to version {new_ver}")
with open(python_file_path, "rb") as f:
file_lines = f.readlines()
for line in file_lines:
Expand Down
88 changes: 83 additions & 5 deletions tests/test_bump_schema_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from aind_data_schema.core.session import Session
from aind_data_schema.core.subject import Subject
from aind_data_schema.core.rig import Rig
from aind_data_schema.utils.json_writer import SchemaWriter
from aind_data_schema.utils.schema_version_bump import SchemaVersionHandler

Expand Down Expand Up @@ -39,16 +40,25 @@ def test_get_list_of_models_that_changed(self, mock_exists: MagicMock, mock_json
self.assertTrue(Session in models_that_changed)
self.assertTrue(Subject in models_that_changed)

def test_get_list_of_incremented_versions(self):
@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json")
def test_get_list_of_incremented_versions(self, mock_get_schema: MagicMock):
"""Tests get_list_of_incremented_versions method"""

handler = SchemaVersionHandler(json_schemas_location=Path("."))
old_subject_version = Subject.model_fields["schema_version"].default
new_subject_version = str(Version.parse(old_subject_version).bump_patch())
old_session_version = Session.model_fields["schema_version"].default
new_session_version = str(Version.parse(old_session_version).bump_patch())
# Pycharm raises a warning about types that we can ignore
# noinspection PyTypeChecker

def side_effect(model):
"""Side effect for mock_get_schema"""
if model == Subject:
return {"properties": {"schema_version": {"default": old_subject_version}}}
elif model == Session:
return {"properties": {"schema_version": {"default": old_session_version}}}

mock_get_schema.side_effect = side_effect

model_map = handler._get_incremented_versions_map([Subject, Session])
expected_model_map = {Subject: new_subject_version, Session: new_session_version}
self.assertEqual(expected_model_map, model_map)
Expand All @@ -64,6 +74,58 @@ def test_write_new_file(self, mock_open: MagicMock):
mock_open.assert_called_once_with(file_path, "wb")
mock_open.return_value.__enter__().write.assert_has_calls([call(file_contents[0]), call(file_contents[1])])

@patch("builtins.open")
@patch("json.load")
@patch("pathlib.Path.exists")
def test_get_schema_json(self, mock_exists: MagicMock, mock_json_load: MagicMock, mock_open: MagicMock):
"""Tests _get_schema_json method"""
handler = SchemaVersionHandler(json_schemas_location=Path("."))

mock_exists.return_value = True
mock_json_load.return_value = {"properties": {"schema_version": {"default": "1.0.0"}}}

model = MagicMock()
model.default_filename.return_value = "test_model.json"

schema_json = handler._get_schema_json(model)
self.assertEqual(schema_json, {"properties": {"schema_version": {"default": "1.0.0"}}})

mock_open.assert_called_once_with(Path("./test_model_schema.json"), "r")
mock_json_load.assert_called_once()

@patch("pathlib.Path.exists")
def test_get_schema_json_file_not_found(self, mock_exists: MagicMock):
"""Tests _get_schema_json method when file is not found"""
handler = SchemaVersionHandler(json_schemas_location=Path("."))

mock_exists.return_value = False

model = MagicMock()
model.default_filename.return_value = "test_model.json"

with self.assertRaises(FileNotFoundError):
handler._get_schema_json(model)

@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json")
def test_get_incremented_versions_map_exception(self, mock_get_schema: MagicMock):
"""Test that missing schema_version field raises an error"""
handler = SchemaVersionHandler(json_schemas_location=Path("."))

mock_get_schema.return_value = {}

with self.assertRaises(ValueError):
handler._get_incremented_versions_map([Subject])

@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json")
def test_get_incremented_versions_map_skip(self, mock_get_schema: MagicMock):
"""Test that missing schema_version field raises an error"""
handler = SchemaVersionHandler(json_schemas_location=Path("."))

mock_get_schema.return_value = {"properties": {"schema_version": {"default": "0.0.0"}}}

empty_map = handler._get_incremented_versions_map([Subject])
self.assertEqual(empty_map, {})

@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._write_new_file")
def test_update_files(self, mock_write: MagicMock):
"""Tests the update_files method"""
Expand All @@ -72,9 +134,11 @@ def test_update_files(self, mock_write: MagicMock):
new_subject_version = str(Version.parse(old_subject_version).bump_patch())
old_session_version = Session.model_fields["schema_version"].default
new_session_version = str(Version.parse(old_session_version).bump_patch())
old_rig_version = Rig.model_fields["schema_version"].default
new_rig_version = str(Version.parse(old_rig_version).bump_minor())
# Pycharm raises a warning about types that we can ignore
# noinspection PyTypeChecker
handler._update_files({Subject: new_subject_version, Session: new_session_version})
handler._update_files({Subject: new_subject_version, Session: new_session_version, Rig: new_rig_version})

expected_line_change0 = (
f'schema_version: SkipValidation[Literal["{new_subject_version}"]] = Field(default="{new_subject_version}")'
Expand All @@ -90,16 +154,30 @@ def test_update_files(self, mock_write: MagicMock):
self.assertTrue(expected_line_change1 in str(mock_write_args1[0]))
self.assertTrue("session.py" in str(mock_write_args1[1]))

@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json")
@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_list_of_models_that_changed")
@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._update_files")
def test_run_job(self, mock_update_files: MagicMock, mock_get_list_of_models: MagicMock):
def test_run_job(
self, mock_update_files: MagicMock, mock_get_list_of_models: MagicMock, mock_get_schema: MagicMock
):
"""Tests run_job method"""

old_subject_version = Subject.model_fields["schema_version"].default
new_subject_version = str(Version.parse(old_subject_version).bump_patch())
old_session_version = Session.model_fields["schema_version"].default
new_session_version = str(Version.parse(old_session_version).bump_patch())

mock_get_list_of_models.return_value = [Subject, Session]

def side_effect(model):
"""Return values for get_schema_json"""
if model == Subject:
return {"properties": {"schema_version": {"default": old_subject_version}}}
elif model == Session:
return {"properties": {"schema_version": {"default": old_session_version}}}

mock_get_schema.side_effect = side_effect

handler = SchemaVersionHandler(json_schemas_location=Path("."))
handler.run_job()
mock_update_files.assert_called_once_with({Subject: new_subject_version, Session: new_session_version})
Expand Down
Loading