diff --git a/causy/graph.py b/causy/graph.py index bd010bc..bb27d29 100644 --- a/causy/graph.py +++ b/causy/graph.py @@ -86,6 +86,29 @@ def add_edge(self, u: Node, v: Node, value: Dict): self.edge_history[(u.id, v.id)] = [] self.edge_history[(v.id, u.id)] = [] + def add_directed_edge(self, u: Node, v: Node, value: Dict): + """ + Add a directed edge from u to v to the graph + :param u: u node + :param v: v node + :return: + """ + + if u.id not in self.nodes: + raise GraphError(f"Node {u} does not exist") + if v.id not in self.nodes: + raise GraphError(f"Node {v} does not exist") + + if u.id == v.id: + raise GraphError("Self loops are currently not allowed") + + if u.id not in self.edges: + self.edges[u.id] = {} + + self.edges[u.id][v.id] = value + + self.edge_history[(u.id, v.id)] = [] + def retrieve_edge_history( self, u, v, action: TestResultAction = None ) -> List[TestResult]: @@ -334,14 +357,28 @@ def directed_paths(self, u: Node, v: Node): :param v: node v :return: list of directed paths """ + # TODO: try a better data structure for this if self.directed_edge_exists(u, v): return [[(u, v)]] paths = [] for w in self.edges[u.id]: - for path in self.directed_paths(self.nodes[w], v): - paths.append([(u, w)] + path) + if self.directed_edge_exists(u, self.nodes[w]): + for path in self.directed_paths(self.nodes[w], v): + paths.append([(u, self.nodes[w])] + path) return paths + def parents_of_node(self, u: Node): + """ + Return all parents of a node u + :param u: node u + :return: list of nodes (parents) + """ + parents = [] + for w in self.edges: + if self.directed_edge_exists(self.nodes[w], u): + parents.append(self.nodes[w]) + return parents + def inducing_path_exists(self, u: Node, v: Node): """ Check if an inducing path from u to v exists. diff --git a/tests/test_graph.py b/tests/test_graph.py index ce69977..2a4b952 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -50,6 +50,17 @@ def test_add_edge(self): self.assertEqual(graph.edge_value(node2, node1), {"test": "test"}) self.assertTrue(graph.edge_exists(node1, node2)) + def test_add_directed_edge(self): + graph = Graph() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + self.assertEqual(len(graph.nodes), 2) + self.assertEqual(len(graph.edges), 1) + self.assertEqual(graph.edge_value(node1, node2), {"test": "test"}) + self.assertTrue(graph.directed_edge_exists(node1, node2)) + self.assertFalse(graph.directed_edge_exists(node2, node1)) + def test_add_edge_with_non_existing_node(self): graph = Graph() node1 = graph.add_node("test1", [1, 2, 3]) @@ -223,3 +234,37 @@ def test_undirected_edge_exists_with_non_existing_edge(self): self.assertFalse(graph.undirected_edge_exists(node1, node2)) self.assertFalse(graph.undirected_edge_exists(node2, node1)) + + def test_parents_of_node_two_nodes(self): + graph = Graph() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + self.assertEqual(graph.parents_of_node(node2), [node1]) + + def test_parents_of_node_three_nodes(self): + graph = Graph() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test2", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node3, node2, {"test": "test"}) + self.assertEqual(graph.parents_of_node(node2), [node1, node3]) + + def test_directed_paths_two_nodes(self): + graph = Graph() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + self.assertEqual(graph.directed_paths(node1, node2), [[(node1, node2)]]) + + def test_directed_paths_three_nodes(self): + graph = Graph() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test2", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + self.assertEqual( + graph.directed_paths(node1, node3), [[(node1, node2), (node2, node3)]] + )