diff --git a/buildingmotif/dataclasses/library.py b/buildingmotif/dataclasses/library.py index 6a0ff4ac..7fe42011 100644 --- a/buildingmotif/dataclasses/library.py +++ b/buildingmotif/dataclasses/library.py @@ -422,45 +422,31 @@ def _resolve_dependency( :return: the template instance this dependency points to :rtype: Template """ - # if dep is a _template_dependency, turn it into a template + dependee = None + binding_args = {} if isinstance(dep, _template_dependency): dependee = dep.to_template(template_id_lookup) - template.add_dependency(dependee, dep.bindings) - return - - # now, we know that dep is a dict - - # if dependency names a library explicitly, load that library and get the template by name - if "library" in dep: - dependee = Library.load(name=dep["library"]).get_template_by_name( - dep["template"] - ) - template.add_dependency(dependee, dep["args"]) - return - # if no library is provided, try to resolve the dependency from this library - if dep["template"] in template_id_lookup: - dependee = Template.load(template_id_lookup[dep["template"]]) - template.add_dependency(dependee, dep["args"]) - return - # check documentation for skip_uri for what URIs get skipped - if skip_uri(dep["template"]): - return - - # if the dependency is not in the local cache, then search through this library's imports - # for the template - for imp in self.graph_imports: - try: - library = Library.load(name=str(imp)) - dependee = library.get_template_by_name(dep["template"]) - template.add_dependency(dependee, dep["args"]) + binding_args = dep.bindings + elif isinstance(dep, dict): + binding_args = dep.get("args", {}) + if "library" in dep: + dependee = Library.load(name=dep["library"]).get_template_by_name(dep["template"]) + elif dep["template"] in template_id_lookup: + dependee = Template.load(template_id_lookup[dep["template"]]) + elif skip_uri(dep["template"]): return - except Exception as e: - logging.debug( - f"Could not find dependee {dep['template']} in library {imp}: {e}" - ) - logging.warning( - f"Warning: could not find dependee {dep['template']} in libraries {self.graph_imports}" - ) + else: + for imp in self.graph_imports: + try: + library = Library.load(name=str(imp)) + dependee = library.get_template_by_name(dep["template"]) + break + except Exception as e: + logging.debug(f"Could not find dependee {dep['template']} in library {imp}: {e}") + if dependee is not None: + template.add_dependency(dependee, binding_args) + else: + logging.warning(f"Warning: could not find dependee {dep['template']} in libraries {self.graph_imports}") def _resolve_template_dependencies( self, diff --git a/tests/unit/database/table_connection/test_db_library.py b/tests/unit/database/table_connection/test_db_library.py index 9eded8e2..9b992298 100644 --- a/tests/unit/database/table_connection/test_db_library.py +++ b/tests/unit/database/table_connection/test_db_library.py @@ -10,7 +10,7 @@ from buildingmotif.database.tables import DBLibrary, DBShapeCollection, DBTemplate -def test_create_db_library(table_connection, monkeypatch): +def test_create_db_library(bm, monkeypatch): mocked_uuid = uuid.uuid4() def mockreturn(): @@ -18,7 +18,7 @@ def mockreturn(): monkeypatch.setattr(uuid, "uuid4", mockreturn) - db_library = table_connection.create_db_library(name="my_db_library") + db_library = bm.table_connection.create_db_library(name="my_db_library") assert db_library.name == "my_db_library" assert db_library.templates == [] @@ -26,11 +26,11 @@ def mockreturn(): assert db_library.shape_collection.graph_id == str(mocked_uuid) -def test_get_db_libraries(table_connection): - table_connection.create_db_library(name="my_db_library") - table_connection.create_db_library(name="your_db_library") +def test_get_db_libraries(bm): + bm.table_connection.create_db_library(name="my_db_library") + bm.table_connection.create_db_library(name="your_db_library") - db_libraries = table_connection.get_all_db_libraries() + db_libraries = bm.table_connection.get_all_db_libraries() assert len(db_libraries) == 2 assert all(type(tl) == DBLibrary for tl in db_libraries) @@ -40,11 +40,11 @@ def test_get_db_libraries(table_connection): } -def test_get_db_library(table_connection): - db_library = table_connection.create_db_library(name="my_library") - table_connection.create_db_template("my_db_template", library_id=db_library.id) +def test_get_db_library(bm): + db_library = bm.table_connection.create_db_library(name="my_library") + bm.table_connection.create_db_template("my_db_template", library_id=db_library.id) - db_library = table_connection.get_db_library(id=db_library.id) + db_library = bm.table_connection.get_db_library(id=db_library.id) assert db_library.name == "my_library" assert len(db_library.templates) == 1 assert type(db_library.templates[0]) == DBTemplate diff --git a/tests/unit/database/table_connection/test_db_model.py b/tests/unit/database/table_connection/test_db_model.py index 7d305c80..974355d5 100644 --- a/tests/unit/database/table_connection/test_db_model.py +++ b/tests/unit/database/table_connection/test_db_model.py @@ -24,15 +24,15 @@ def test_create_db_model(mock_uuid4, table_connection): assert db_model.manifest.graph_id == str(mocked_manifest_uuid) -def test_get_db_models(table_connection): - table_connection.create_db_model( +def test_get_db_models(bm): + bm.table_connection.create_db_model( name="my_db_model", description="a very good model" ) - table_connection.create_db_model( + bm.table_connection.create_db_model( name="your_db_model", description="an ok good model" ) - db_models = table_connection.get_all_db_models() + db_models = bm.table_connection.get_all_db_models() assert len(db_models) == 2 assert all(type(m) == DBModel for m in db_models) @@ -43,13 +43,13 @@ def test_get_db_models(table_connection): @mock.patch("uuid.uuid4") -def test_get_db_model(mock_uuid4, table_connection): +def test_get_db_model(mock_uuid4, bm): mocked_graph_uuid = uuid.uuid4() mocked_manifest_uuid = uuid.uuid4() mock_uuid4.side_effect = [mocked_graph_uuid, mocked_manifest_uuid] - db_model = table_connection.create_db_model(name="my_db_model") - db_model = table_connection.get_db_model(id=db_model.id) + db_model = bm.table_connection.create_db_model(name="my_db_model") + db_model = bm.table_connection.get_db_model(id=db_model.id) assert db_model.name == "my_db_model" assert db_model.graph_id == str(mocked_graph_uuid) diff --git a/tests/unit/database/table_connection/test_db_shape.py b/tests/unit/database/table_connection/test_db_shape.py index e813df50..ad2435fe 100644 --- a/tests/unit/database/table_connection/test_db_shape.py +++ b/tests/unit/database/table_connection/test_db_shape.py @@ -6,7 +6,7 @@ from buildingmotif.database.tables import DBShapeCollection -def test_create_db_shape_collection(monkeypatch, table_connection): +def test_create_db_shape_collection(monkeypatch, bm): mocked_uuid = uuid.uuid4() def mockreturn(): @@ -14,16 +14,16 @@ def mockreturn(): monkeypatch.setattr(uuid, "uuid4", mockreturn) - db_shape_collection = table_connection.create_db_shape_collection() + db_shape_collection = bm.table_connection.create_db_shape_collection() assert db_shape_collection.graph_id == str(mocked_uuid) -def test_get_db_shape_collections(table_connection): - shape_collection1 = table_connection.create_db_shape_collection() - shape_collection2 = table_connection.create_db_shape_collection() +def test_get_db_shape_collections(bm): + shape_collection1 = bm.table_connection.create_db_shape_collection() + shape_collection2 = bm.table_connection.create_db_shape_collection() - db_shape_collections = table_connection.get_all_db_shape_collections() + db_shape_collections = bm.table_connection.get_all_db_shape_collections() assert len(db_shape_collections) == 2 assert all(type(m) == DBShapeCollection for m in db_shape_collections) diff --git a/tests/unit/database/test_cascading_deletes_cascades.py b/tests/unit/database/test_cascading_deletes_cascades.py new file mode 100644 index 00000000..f8b8be05 --- /dev/null +++ b/tests/unit/database/test_cascading_deletes_cascades.py @@ -0,0 +1,74 @@ +import pytest +from buildingmotif.dataclasses.template import Template +from buildingmotif.database.errors import ( + LibraryNotFound, + TemplateNotFound, + ModelNotFound, + ShapeCollectionNotFound, +) + +def test_cascade_delete_model_shape_collection(bm): + # Create a model; its manifest (shape collection) should be cascading deleted. + db_model = bm.table_connection.create_db_model( + name="cascade_model", description="test cascading delete on model" + ) + shape_collection_id = db_model.manifest.id + # assert we can get the shape collection + assert bm.table_connection.get_db_shape_collection(shape_collection_id) + # now delete the model, and assert the shape collection is gone + bm.table_connection.delete_db_model(db_model.id) + with pytest.raises(ShapeCollectionNotFound): + bm.table_connection.get_db_shape_collection(shape_collection_id) + + +def test_cascade_delete_library_cascades(bm): + # Create a library, two templates within it, and a dependency relationship. + db_library = bm.table_connection.create_db_library(name="cascade_library") + shape_collection_id = db_library.shape_collection.id + template1 = bm.table_connection.create_db_template(name="template1", library_id=db_library.id) + template2 = bm.table_connection.create_db_template(name="template2", library_id=db_library.id) + + # Add a dependency relationship between template1 and template2. + template1 = Template.load(template1.id) + template2 = Template.load(template2.id) + template1.add_dependency(template2, {"name": "dependency"}) + # Verify the dependency exists. + deps = bm.table_connection.get_db_template_dependencies(template1.id) + assert len(deps) == 1 + + # Deleting the library should cascade-delete the library, its templates, and its associated shape collection. + bm.table_connection.delete_db_library(db_library.id) + + with pytest.raises(LibraryNotFound): + bm.table_connection.get_db_library(db_library.id) + with pytest.raises(TemplateNotFound): + bm.table_connection.get_db_template(template1.id) + with pytest.raises(TemplateNotFound): + bm.table_connection.get_db_template(template2.id) + with pytest.raises(ShapeCollectionNotFound): + bm.table_connection.get_db_shape_collection(shape_collection_id) + +def test_cascade_delete_multi_library(bm): + # Create two libraries + library1 = bm.table_connection.create_db_library(name="cascade_library1") + library2 = bm.table_connection.create_db_library(name="cascade_library2") + # Create template1 in library1 and template2 in library2 + template1 = bm.table_connection.create_db_template(name="template1", library_id=library1.id) + template2 = bm.table_connection.create_db_template(name="template2", library_id=library2.id) + # Load templates to add dependency and verify dependency relationship. + template1 = Template.load(template1.id) + template2 = Template.load(template2.id) + # Add dependency: template1 depends on template2 + template1.add_dependency(template2, {"name": "dependency"}) + # Verify dependency exists. + deps = bm.table_connection.get_db_template_dependencies(template1.id) + assert len(deps) == 1 + # Delete library1 and ensure cascading deletion + bm.table_connection.delete_db_library(library1.id) + with pytest.raises(LibraryNotFound): + bm.table_connection.get_db_library(library1.id) + with pytest.raises(TemplateNotFound): + bm.table_connection.get_db_template(template1.id) + # Library2 and its template should still exist + assert bm.table_connection.get_db_library(library2.id) + assert bm.table_connection.get_db_template(template2.id)