Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cascade DELETEs in the database #363

Draft
wants to merge 7 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions buildingmotif/database/tables.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from typing import Dict, List

from sqlalchemy import Column, ForeignKey, Integer, String, Text, UniqueConstraint
from sqlalchemy import (
Column,
ForeignKey,
Integer,
String,
Text,
UniqueConstraint,
event,
)
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Mapped, declarative_base, relationship

# from sqlalchemy.dialects.postgresql import JSON
Expand All @@ -9,6 +18,14 @@
Base = declarative_base()


# https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#foreign-key-support
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()


class DBModel(Base):
"""A Model is a metadata model of all or part of a building."""

Expand All @@ -18,12 +35,13 @@ class DBModel(Base):
description: Mapped[str] = Column(Text(), default="", nullable=False)
graph_id: Mapped[str] = Column(String())
manifest_id: Mapped[int] = Column(
Integer, ForeignKey("shape_collection.id"), nullable=False
Integer, ForeignKey("shape_collection.id", ondelete="CASCADE"), nullable=False
)
manifest: "DBShapeCollection" = relationship(
"DBShapeCollection",
uselist=False,
cascade="all,delete",
cascade="all",
passive_deletes=True,
)


Expand All @@ -45,16 +63,17 @@ class DBLibrary(Base):
name: Mapped[str] = Column(String(), nullable=False, unique=True)

templates: Mapped[List["DBTemplate"]] = relationship(
"DBTemplate", back_populates="library", cascade="all,delete"
"DBTemplate", back_populates="library", cascade="all", passive_deletes=True
)

shape_collection_id = Column(
Integer, ForeignKey("shape_collection.id"), nullable=False
Integer, ForeignKey("shape_collection.id", ondelete="CASCADE"), nullable=False
)
shape_collection: DBShapeCollection = relationship(
"DBShapeCollection",
uselist=False,
cascade="all,delete",
cascade="all",
passive_deletes=True,
)


Expand All @@ -64,8 +83,8 @@ class DepsAssociation(Base):
__tablename__ = "deps_association_table"

id: Mapped[int] = Column(Integer, primary_key=True)
dependant_id: Mapped[int] = Column(ForeignKey("template.id"))
dependee_id: Mapped[int] = Column(ForeignKey("template.id"))
dependant_id: Mapped[int] = Column(ForeignKey("template.id", ondelete="CASCADE"))
dependee_id: Mapped[int] = Column(ForeignKey("template.id", ondelete="CASCADE"))
# args are a mapping of dependee args to dependant args
args: Mapped[Dict[str, str]] = Column(JSONType) # type: ignore

Expand All @@ -89,21 +108,27 @@ class DBTemplate(Base):
body_id: Mapped[str] = Column(String())
optional_args: Mapped[List[str]] = Column(JSONType) # type: ignore

library_id: Mapped[int] = Column(Integer, ForeignKey("library.id"), nullable=False)
library_id: Mapped[int] = Column(
Integer, ForeignKey("library.id", ondelete="CASCADE"), nullable=False
)
library: Mapped[DBLibrary] = relationship("DBLibrary", back_populates="templates")
dependencies: Mapped[List["DBTemplate"]] = relationship(
"DBTemplate",
secondary="deps_association_table",
primaryjoin=id == DepsAssociation.dependant_id,
secondaryjoin=id == DepsAssociation.dependee_id,
back_populates="dependants",
cascade="all",
passive_deletes=True,
)
dependants: Mapped[List["DBTemplate"]] = relationship(
"DBTemplate",
secondary="deps_association_table",
primaryjoin=id == DepsAssociation.dependee_id,
secondaryjoin=id == DepsAssociation.dependant_id,
back_populates="dependencies",
cascade="all",
passive_deletes=True,
)

__table_args__ = (
Expand Down
64 changes: 28 additions & 36 deletions buildingmotif/dataclasses/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,45 +422,37 @@ def _resolve_dependency(
:return: the template instance this dependency points to
:rtype: Template
"""
# if dep is a _template_dependency, turn it into a template
dependee = None
binding_args = {}
if isinstance(dep, _template_dependency):
dependee = dep.to_template(template_id_lookup)
template.add_dependency(dependee, dep.bindings)
return

# now, we know that dep is a dict

# if dependency names a library explicitly, load that library and get the template by name
if "library" in dep:
dependee = Library.load(name=dep["library"]).get_template_by_name(
dep["template"]
)
template.add_dependency(dependee, dep["args"])
return
# if no library is provided, try to resolve the dependency from this library
if dep["template"] in template_id_lookup:
dependee = Template.load(template_id_lookup[dep["template"]])
template.add_dependency(dependee, dep["args"])
return
# check documentation for skip_uri for what URIs get skipped
if skip_uri(dep["template"]):
return

# if the dependency is not in the local cache, then search through this library's imports
# for the template
for imp in self.graph_imports:
try:
library = Library.load(name=str(imp))
dependee = library.get_template_by_name(dep["template"])
template.add_dependency(dependee, dep["args"])
return
except Exception as e:
logging.debug(
f"Could not find dependee {dep['template']} in library {imp}: {e}"
binding_args = dep.bindings
elif isinstance(dep, dict):
binding_args = dep.get("args", {})
if "library" in dep:
dependee = Library.load(name=dep["library"]).get_template_by_name(
dep["template"]
)
logging.warning(
f"Warning: could not find dependee {dep['template']} in libraries {self.graph_imports}"
)
elif dep["template"] in template_id_lookup:
dependee = Template.load(template_id_lookup[dep["template"]])
elif skip_uri(dep["template"]):
return
else:
for imp in self.graph_imports:
try:
library = Library.load(name=str(imp))
dependee = library.get_template_by_name(dep["template"])
break
except Exception as e:
logging.debug(
f"Could not find dependee {dep['template']} in library {imp}: {e}"
)
if dependee is not None:
template.add_dependency(dependee, binding_args)
else:
logging.warning(
f"Warning: could not find dependee {dep} in libraries {self.graph_imports}"
)

def _resolve_template_dependencies(
self,
Expand Down
1 change: 1 addition & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Generates tests automatically
"""

import glob
from pathlib import Path

Expand Down
20 changes: 10 additions & 10 deletions tests/unit/database/table_connection/test_db_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,27 @@
from buildingmotif.database.tables import DBLibrary, DBShapeCollection, DBTemplate


def test_create_db_library(table_connection, monkeypatch):
def test_create_db_library(bm, monkeypatch):
mocked_uuid = uuid.uuid4()

def mockreturn():
return mocked_uuid

monkeypatch.setattr(uuid, "uuid4", mockreturn)

db_library = table_connection.create_db_library(name="my_db_library")
db_library = bm.table_connection.create_db_library(name="my_db_library")

assert db_library.name == "my_db_library"
assert db_library.templates == []
assert isinstance(db_library.shape_collection, DBShapeCollection)
assert db_library.shape_collection.graph_id == str(mocked_uuid)


def test_get_db_libraries(table_connection):
table_connection.create_db_library(name="my_db_library")
table_connection.create_db_library(name="your_db_library")
def test_get_db_libraries(bm):
bm.table_connection.create_db_library(name="my_db_library")
bm.table_connection.create_db_library(name="your_db_library")

db_libraries = table_connection.get_all_db_libraries()
db_libraries = bm.table_connection.get_all_db_libraries()

assert len(db_libraries) == 2
assert all(type(tl) == DBLibrary for tl in db_libraries)
Expand All @@ -40,11 +40,11 @@ def test_get_db_libraries(table_connection):
}


def test_get_db_library(table_connection):
db_library = table_connection.create_db_library(name="my_library")
table_connection.create_db_template("my_db_template", library_id=db_library.id)
def test_get_db_library(bm):
db_library = bm.table_connection.create_db_library(name="my_library")
bm.table_connection.create_db_template("my_db_template", library_id=db_library.id)

db_library = table_connection.get_db_library(id=db_library.id)
db_library = bm.table_connection.get_db_library(id=db_library.id)
assert db_library.name == "my_library"
assert len(db_library.templates) == 1
assert type(db_library.templates[0]) == DBTemplate
Expand Down
14 changes: 7 additions & 7 deletions tests/unit/database/table_connection/test_db_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ def test_create_db_model(mock_uuid4, table_connection):
assert db_model.manifest.graph_id == str(mocked_manifest_uuid)


def test_get_db_models(table_connection):
table_connection.create_db_model(
def test_get_db_models(bm):
bm.table_connection.create_db_model(
name="my_db_model", description="a very good model"
)
table_connection.create_db_model(
bm.table_connection.create_db_model(
name="your_db_model", description="an ok good model"
)

db_models = table_connection.get_all_db_models()
db_models = bm.table_connection.get_all_db_models()

assert len(db_models) == 2
assert all(type(m) == DBModel for m in db_models)
Expand All @@ -43,13 +43,13 @@ def test_get_db_models(table_connection):


@mock.patch("uuid.uuid4")
def test_get_db_model(mock_uuid4, table_connection):
def test_get_db_model(mock_uuid4, bm):
mocked_graph_uuid = uuid.uuid4()
mocked_manifest_uuid = uuid.uuid4()
mock_uuid4.side_effect = [mocked_graph_uuid, mocked_manifest_uuid]

db_model = table_connection.create_db_model(name="my_db_model")
db_model = table_connection.get_db_model(id=db_model.id)
db_model = bm.table_connection.create_db_model(name="my_db_model")
db_model = bm.table_connection.get_db_model(id=db_model.id)

assert db_model.name == "my_db_model"
assert db_model.graph_id == str(mocked_graph_uuid)
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/database/table_connection/test_db_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,24 @@
from buildingmotif.database.tables import DBShapeCollection


def test_create_db_shape_collection(monkeypatch, table_connection):
def test_create_db_shape_collection(monkeypatch, bm):
mocked_uuid = uuid.uuid4()

def mockreturn():
return mocked_uuid

monkeypatch.setattr(uuid, "uuid4", mockreturn)

db_shape_collection = table_connection.create_db_shape_collection()
db_shape_collection = bm.table_connection.create_db_shape_collection()

assert db_shape_collection.graph_id == str(mocked_uuid)


def test_get_db_shape_collections(table_connection):
shape_collection1 = table_connection.create_db_shape_collection()
shape_collection2 = table_connection.create_db_shape_collection()
def test_get_db_shape_collections(bm):
shape_collection1 = bm.table_connection.create_db_shape_collection()
shape_collection2 = bm.table_connection.create_db_shape_collection()

db_shape_collections = table_connection.get_all_db_shape_collections()
db_shape_collections = bm.table_connection.get_all_db_shape_collections()

assert len(db_shape_collections) == 2
assert all(type(m) == DBShapeCollection for m in db_shape_collections)
Expand Down
Loading
Loading