Skip to content

Commit

Permalink
allow to create edges after directed edges
Browse files Browse the repository at this point in the history
  • Loading branch information
LilithWittmann committed Feb 7, 2025
1 parent a14b12b commit 98f7275
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 19 deletions.
42 changes: 24 additions & 18 deletions causy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,28 @@ def get_edge(self, u: Node, v: Node) -> Edge:
raise GraphError(f"Edge {u} -> {v} does not exist")
return self.edges[u.id][v.id]

def _init_edge(self, u: Node, v: Node):
"""
Initialize an edge between two nodes
:param u:
:param v:
:return:
"""

if u.id not in self.edges:
self.edges[u.id] = self.__init_dict()
self._reverse_edges[u.id] = self.__init_dict()
self._deleted_edges[u.id] = self.__init_dict()
if v.id not in self.edges:
self.edges[v.id] = self.__init_dict()
self._reverse_edges[v.id] = self.__init_dict()
self._deleted_edges[v.id] = self.__init_dict()

if (u.id, v.id) not in self.edge_history:
self.edge_history[(u.id, v.id)] = []
if (v.id, u.id) not in self.edge_history:
self.edge_history[(v.id, u.id)] = []

def add_edge(
self,
u: Node,
Expand All @@ -674,14 +696,7 @@ def add_edge(
if u.id == v.id:
raise GraphError("Self loops are currently not allowed")

if u.id not in self.edges:
self.edges[u.id] = self.__init_dict()
self._reverse_edges[u.id] = self.__init_dict()
self._deleted_edges[u.id] = self.__init_dict()
if v.id not in self.edges:
self.edges[v.id] = self.__init_dict()
self._reverse_edges[v.id] = self.__init_dict()
self._deleted_edges[v.id] = self.__init_dict()
self._init_edge(u, v)

a_edge = Edge(u=u, v=v, edge_type=edge_type, metadata=metadata)
self.edges[u.id][v.id] = a_edge
Expand All @@ -691,9 +706,6 @@ def add_edge(
self.edges[v.id][u.id] = b_edge
self._reverse_edges[u.id][v.id] = b_edge

self.edge_history[(u.id, v.id)] = []
self.edge_history[(v.id, u.id)] = []

def add_directed_edge(
self,
u: Node,
Expand All @@ -718,19 +730,13 @@ def add_directed_edge(
if u.id == v.id:
raise GraphError("Self loops are currently not allowed")

if u.id not in self.edges:
self.edges[u.id] = self.__init_dict()
self._deleted_edges[u.id] = self.__init_dict()
if v.id not in self._reverse_edges:
self._reverse_edges[v.id] = self.__init_dict()
self._init_edge(u, v)

edge = Edge(u=u, v=v, edge_type=edge_type, metadata=metadata)

self.edges[u.id][v.id] = edge
self._reverse_edges[v.id][u.id] = edge

self.edge_history[(u.id, v.id)] = []

def add_edge_history(self, u, v, action: TestResult):
"""
Add an action to the edge history
Expand Down
14 changes: 13 additions & 1 deletion tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,23 @@ def test_add_directed_edge(self):
node2 = graph.add_node("test2", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
self.assertEqual(len(graph.nodes), 2)
self.assertEqual(len(graph.edges), 1)
self.assertEqual(graph.edge_value(node1, node2), {"test": "test"})
self.assertTrue(graph.directed_edge_exists(node1, node2))
self.assertFalse(graph.directed_edge_exists(node2, node1))

def test_create_undirected_edge_after_directed_edge(self):
graph = GraphManager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
graph.add_edge(node1, node2, {"test": "test"})
self.assertEqual(len(graph.nodes), 2)
self.assertEqual(graph.edge_value(node1, node2), {"test": "test"})
self.assertTrue(graph.undirected_edge_exists(node1, node2))
self.assertTrue(graph.undirected_edge_exists(node2, node1))
self.assertTrue(graph.edge_exists(node1, node2))
self.assertTrue(graph.edge_exists(node2, node1))

def test_add_edge_with_non_existing_node(self):
graph = GraphManager()
node1 = graph.add_node("test1", [1, 2, 3])
Expand Down

0 comments on commit 98f7275

Please sign in to comment.