From 98f7275737be26b7aa07306b33ae1c9b136f0664 Mon Sep 17 00:00:00 2001 From: Lilith Wittmann Date: Fri, 7 Feb 2025 19:47:20 +0100 Subject: [PATCH] allow to create edges after directed edges --- causy/graph.py | 42 ++++++++++++++++++++++++------------------ tests/test_graph.py | 14 +++++++++++++- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/causy/graph.py b/causy/graph.py index 901a861..9e22e05 100644 --- a/causy/graph.py +++ b/causy/graph.py @@ -650,6 +650,28 @@ def get_edge(self, u: Node, v: Node) -> Edge: raise GraphError(f"Edge {u} -> {v} does not exist") return self.edges[u.id][v.id] + def _init_edge(self, u: Node, v: Node): + """ + Initialize an edge between two nodes + :param u: + :param v: + :return: + """ + + if u.id not in self.edges: + self.edges[u.id] = self.__init_dict() + self._reverse_edges[u.id] = self.__init_dict() + self._deleted_edges[u.id] = self.__init_dict() + if v.id not in self.edges: + self.edges[v.id] = self.__init_dict() + self._reverse_edges[v.id] = self.__init_dict() + self._deleted_edges[v.id] = self.__init_dict() + + if (u.id, v.id) not in self.edge_history: + self.edge_history[(u.id, v.id)] = [] + if (v.id, u.id) not in self.edge_history: + self.edge_history[(v.id, u.id)] = [] + def add_edge( self, u: Node, @@ -674,14 +696,7 @@ def add_edge( if u.id == v.id: raise GraphError("Self loops are currently not allowed") - if u.id not in self.edges: - self.edges[u.id] = self.__init_dict() - self._reverse_edges[u.id] = self.__init_dict() - self._deleted_edges[u.id] = self.__init_dict() - if v.id not in self.edges: - self.edges[v.id] = self.__init_dict() - self._reverse_edges[v.id] = self.__init_dict() - self._deleted_edges[v.id] = self.__init_dict() + self._init_edge(u, v) a_edge = Edge(u=u, v=v, edge_type=edge_type, metadata=metadata) self.edges[u.id][v.id] = a_edge @@ -691,9 +706,6 @@ def add_edge( self.edges[v.id][u.id] = b_edge self._reverse_edges[u.id][v.id] = b_edge - self.edge_history[(u.id, v.id)] = [] - self.edge_history[(v.id, u.id)] = [] - def add_directed_edge( self, u: Node, @@ -718,19 +730,13 @@ def add_directed_edge( if u.id == v.id: raise GraphError("Self loops are currently not allowed") - if u.id not in self.edges: - self.edges[u.id] = self.__init_dict() - self._deleted_edges[u.id] = self.__init_dict() - if v.id not in self._reverse_edges: - self._reverse_edges[v.id] = self.__init_dict() + self._init_edge(u, v) edge = Edge(u=u, v=v, edge_type=edge_type, metadata=metadata) self.edges[u.id][v.id] = edge self._reverse_edges[v.id][u.id] = edge - self.edge_history[(u.id, v.id)] = [] - def add_edge_history(self, u, v, action: TestResult): """ Add an action to the edge history diff --git a/tests/test_graph.py b/tests/test_graph.py index 0122d0e..63b55f8 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -57,11 +57,23 @@ def test_add_directed_edge(self): node2 = graph.add_node("test2", [1, 2, 3]) graph.add_directed_edge(node1, node2, {"test": "test"}) self.assertEqual(len(graph.nodes), 2) - self.assertEqual(len(graph.edges), 1) self.assertEqual(graph.edge_value(node1, node2), {"test": "test"}) self.assertTrue(graph.directed_edge_exists(node1, node2)) self.assertFalse(graph.directed_edge_exists(node2, node1)) + def test_create_undirected_edge_after_directed_edge(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_edge(node1, node2, {"test": "test"}) + self.assertEqual(len(graph.nodes), 2) + self.assertEqual(graph.edge_value(node1, node2), {"test": "test"}) + self.assertTrue(graph.undirected_edge_exists(node1, node2)) + self.assertTrue(graph.undirected_edge_exists(node2, node1)) + self.assertTrue(graph.edge_exists(node1, node2)) + self.assertTrue(graph.edge_exists(node2, node1)) + def test_add_edge_with_non_existing_node(self): graph = GraphManager() node1 = graph.add_node("test1", [1, 2, 3])