Skip to content

Commit

Permalink
Merge branch 'dev' into release-v2.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
dbirman committed Jan 24, 2025
2 parents b566d49 + e347569 commit de14b78
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 29 deletions.
2 changes: 2 additions & 0 deletions src/aind_data_schema/components/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,8 @@ class Treadmill(MousePlatform):
device_type: Literal["Treadmill"] = "Treadmill"
treadmill_width: Decimal = Field(..., title="Width of treadmill (mm)")
width_unit: SizeUnit = Field(default=SizeUnit.CM, title="Width unit")
encoder: Device = Field(..., title="Encoder")
pulse_per_revolution: int = Field(..., title="Pulse per revolution")


class Arena(MousePlatform):
Expand Down
84 changes: 60 additions & 24 deletions src/aind_data_schema/utils/schema_version_bump.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import dictdiffer
import semver

from aind_data_schema.base import DataCoreModel
from aind_data_schema.base import AindCoreModel
from aind_data_schema.utils.json_writer import SchemaWriter

CURRENT_DIR = Path(os.path.dirname(os.path.realpath(__file__)))
Expand All @@ -34,51 +34,86 @@ def __init__(self, commit_message: str = "", json_schemas_location: Path = Path(
self.commit_message = commit_message
self.json_schemas_location = json_schemas_location

def _get_list_of_models_that_changed(self) -> List[DataCoreModel]:
def _get_schema_json(self, model: AindCoreModel) -> dict:
"""
Get the json schema of a model
Parameters
----------
model : AindCoreModel
The model to get the json schema of
Returns
-------
dict
The json schema of the model
"""
default_filename = model.default_filename()
if default_filename.find(".") != -1:
schema_filename = default_filename[: default_filename.find(".")] + "_schema.json"
main_branch_schema_path = self.json_schemas_location / schema_filename
if main_branch_schema_path.exists():
with open(main_branch_schema_path, "r") as f:
main_branch_schema_contents = json.load(f)
else:
raise FileNotFoundError(f"Schema file not found: {main_branch_schema_path}")
return main_branch_schema_contents

def _get_list_of_models_that_changed(self) -> List[AindCoreModel]:
"""
Get a list of core models that have been updated by comparing the json
schema of the models to the json schema in the schemas folder.
Returns
-------
List[DataCoreModel]
A list of DataCoreModels that changed.
List[AindCoreModel]
A list of AindCoreModels that changed.
"""
schemas_that_need_updating = []
for core_model in SchemaWriter.get_schemas():
core_model_json = core_model.model_json_schema()
default_filename = core_model.default_filename()
if default_filename.find(".") != -1:
schema_filename = default_filename[: default_filename.find(".")] + "_schema.json"
main_branch_schema_path = self.json_schemas_location / schema_filename
if main_branch_schema_path.exists():
with open(main_branch_schema_path, "r") as f:
main_branch_schema_contents = json.load(f)
diff = dictdiffer.diff(main_branch_schema_contents, core_model_json)
if len(list(diff)) > 0:
schemas_that_need_updating.append(core_model)
original_schema = self._get_schema_json(core_model)

diff_list = list(dictdiffer.diff(original_schema, core_model_json))

print(f"Diff for {core_model.__name__}: {diff_list}")
if len(diff_list) > 0:
schemas_that_need_updating.append(core_model)

print(f"Schemas that need updating: {[model.__name__ for model in schemas_that_need_updating]}")
return schemas_that_need_updating

@staticmethod
def _get_incremented_versions_map(models_that_changed: List[DataCoreModel]) -> Dict[DataCoreModel, str]:
def _get_incremented_versions_map(self, models_that_changed: List[AindCoreModel]) -> Dict[AindCoreModel, str]:
"""
Parameters
----------
models_that_changed : List[DataCoreModel]
models_that_changed : List[AindCoreModel]
A list of models that have been updated and need to have their version numbers incremented.
Returns
-------
Dict[DataCoreModel, str]
A mapping of the DataCoreModel to its new version number.
Dict[AindCoreModel, str]
A mapping of the AindCoreModel to its new version number.
"""
version_bump_map = {}
# TODO: Use commit message to determine version number to bump?
for model in models_that_changed:
# We only want to bump the patch if the major or minor versions didn't already change
# Load the current version of the model
original_schema = self._get_schema_json(model)
schema_version = original_schema.get("properties", {}).get("schema_version", {}).get("default")
if schema_version:
orig_ver = semver.Version.parse(schema_version)
else:
raise ValueError("Schema version not found in the schema file")

old_v = semver.Version.parse(model.model_fields["schema_version"].default)
new_v = old_v.bump_patch()
version_bump_map[model] = str(new_v)
if orig_ver.major == old_v.major and orig_ver.minor == old_v.minor:
print(f"Updating {model.__name__} from {old_v} to {old_v.bump_patch()}")
new_ver = old_v.bump_patch()
version_bump_map[model] = str(new_ver)
else:
print(f"Skipping {model.__name__}, major or minor version already updated")
new_ver = old_v
return version_bump_map

@staticmethod
Expand All @@ -98,6 +133,7 @@ def _get_updated_file(python_file_path: str, new_ver: str) -> list:
"""
new_file_contents = []
print(f"Updating {python_file_path} to version {new_ver}")
with open(python_file_path, "rb") as f:
file_lines = f.readlines()
for line in file_lines:
Expand Down Expand Up @@ -133,13 +169,13 @@ def _write_new_file(new_file_contents: list, python_file_path: str) -> None:
for line in new_file_contents:
f.write(line)

def _update_files(self, version_bump_map: Dict[DataCoreModel, str]) -> None:
def _update_files(self, version_bump_map: Dict[AindCoreModel, str]) -> None:
"""
Using the information in the version_bump_map, will update the python
files in the core directory.
Parameters
----------
version_bump_map : Dict[DataCoreModel, str]
version_bump_map : Dict[AindCoreModel, str]
The models that need updating are in the dictionary keys and the
new version number is the dictionary value.
Expand Down
88 changes: 83 additions & 5 deletions tests/test_bump_schema_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from aind_data_schema.core.session import Session
from aind_data_schema.core.subject import Subject
from aind_data_schema.core.rig import Rig
from aind_data_schema.utils.json_writer import SchemaWriter
from aind_data_schema.utils.schema_version_bump import SchemaVersionHandler

Expand Down Expand Up @@ -39,16 +40,25 @@ def test_get_list_of_models_that_changed(self, mock_exists: MagicMock, mock_json
self.assertTrue(Session in models_that_changed)
self.assertTrue(Subject in models_that_changed)

def test_get_list_of_incremented_versions(self):
@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json")
def test_get_list_of_incremented_versions(self, mock_get_schema: MagicMock):
"""Tests get_list_of_incremented_versions method"""

handler = SchemaVersionHandler(json_schemas_location=Path("."))
old_subject_version = Subject.model_fields["schema_version"].default
new_subject_version = str(Version.parse(old_subject_version).bump_patch())
old_session_version = Session.model_fields["schema_version"].default
new_session_version = str(Version.parse(old_session_version).bump_patch())
# Pycharm raises a warning about types that we can ignore
# noinspection PyTypeChecker

def side_effect(model):
"""Side effect for mock_get_schema"""
if model == Subject:
return {"properties": {"schema_version": {"default": old_subject_version}}}
elif model == Session:
return {"properties": {"schema_version": {"default": old_session_version}}}

mock_get_schema.side_effect = side_effect

model_map = handler._get_incremented_versions_map([Subject, Session])
expected_model_map = {Subject: new_subject_version, Session: new_session_version}
self.assertEqual(expected_model_map, model_map)
Expand All @@ -64,6 +74,58 @@ def test_write_new_file(self, mock_open: MagicMock):
mock_open.assert_called_once_with(file_path, "wb")
mock_open.return_value.__enter__().write.assert_has_calls([call(file_contents[0]), call(file_contents[1])])

@patch("builtins.open")
@patch("json.load")
@patch("pathlib.Path.exists")
def test_get_schema_json(self, mock_exists: MagicMock, mock_json_load: MagicMock, mock_open: MagicMock):
"""Tests _get_schema_json method"""
handler = SchemaVersionHandler(json_schemas_location=Path("."))

mock_exists.return_value = True
mock_json_load.return_value = {"properties": {"schema_version": {"default": "1.0.0"}}}

model = MagicMock()
model.default_filename.return_value = "test_model.json"

schema_json = handler._get_schema_json(model)
self.assertEqual(schema_json, {"properties": {"schema_version": {"default": "1.0.0"}}})

mock_open.assert_called_once_with(Path("./test_model_schema.json"), "r")
mock_json_load.assert_called_once()

@patch("pathlib.Path.exists")
def test_get_schema_json_file_not_found(self, mock_exists: MagicMock):
"""Tests _get_schema_json method when file is not found"""
handler = SchemaVersionHandler(json_schemas_location=Path("."))

mock_exists.return_value = False

model = MagicMock()
model.default_filename.return_value = "test_model.json"

with self.assertRaises(FileNotFoundError):
handler._get_schema_json(model)

@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json")
def test_get_incremented_versions_map_exception(self, mock_get_schema: MagicMock):
"""Test that missing schema_version field raises an error"""
handler = SchemaVersionHandler(json_schemas_location=Path("."))

mock_get_schema.return_value = {}

with self.assertRaises(ValueError):
handler._get_incremented_versions_map([Subject])

@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json")
def test_get_incremented_versions_map_skip(self, mock_get_schema: MagicMock):
"""Test that missing schema_version field raises an error"""
handler = SchemaVersionHandler(json_schemas_location=Path("."))

mock_get_schema.return_value = {"properties": {"schema_version": {"default": "0.0.0"}}}

empty_map = handler._get_incremented_versions_map([Subject])
self.assertEqual(empty_map, {})

@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._write_new_file")
def test_update_files(self, mock_write: MagicMock):
"""Tests the update_files method"""
Expand All @@ -72,9 +134,11 @@ def test_update_files(self, mock_write: MagicMock):
new_subject_version = str(Version.parse(old_subject_version).bump_patch())
old_session_version = Session.model_fields["schema_version"].default
new_session_version = str(Version.parse(old_session_version).bump_patch())
old_rig_version = Rig.model_fields["schema_version"].default
new_rig_version = str(Version.parse(old_rig_version).bump_minor())
# Pycharm raises a warning about types that we can ignore
# noinspection PyTypeChecker
handler._update_files({Subject: new_subject_version, Session: new_session_version})
handler._update_files({Subject: new_subject_version, Session: new_session_version, Rig: new_rig_version})

expected_line_change0 = (
f'schema_version: SkipValidation[Literal["{new_subject_version}"]] = Field(default="{new_subject_version}")'
Expand All @@ -90,16 +154,30 @@ def test_update_files(self, mock_write: MagicMock):
self.assertTrue(expected_line_change1 in str(mock_write_args1[0]))
self.assertTrue("session.py" in str(mock_write_args1[1]))

@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_schema_json")
@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._get_list_of_models_that_changed")
@patch("aind_data_schema.utils.schema_version_bump.SchemaVersionHandler._update_files")
def test_run_job(self, mock_update_files: MagicMock, mock_get_list_of_models: MagicMock):
def test_run_job(
self, mock_update_files: MagicMock, mock_get_list_of_models: MagicMock, mock_get_schema: MagicMock
):
"""Tests run_job method"""

old_subject_version = Subject.model_fields["schema_version"].default
new_subject_version = str(Version.parse(old_subject_version).bump_patch())
old_session_version = Session.model_fields["schema_version"].default
new_session_version = str(Version.parse(old_session_version).bump_patch())

mock_get_list_of_models.return_value = [Subject, Session]

def side_effect(model):
"""Return values for get_schema_json"""
if model == Subject:
return {"properties": {"schema_version": {"default": old_subject_version}}}
elif model == Session:
return {"properties": {"schema_version": {"default": old_session_version}}}

mock_get_schema.side_effect = side_effect

handler = SchemaVersionHandler(json_schemas_location=Path("."))
handler.run_job()
mock_update_files.assert_called_once_with({Subject: new_subject_version, Session: new_session_version})
Expand Down

0 comments on commit de14b78

Please sign in to comment.