Skip to content

Commit

Permalink
fest(graph): get parents
Browse files Browse the repository at this point in the history
also: tests for parents_of_nodes and directed_paths, fixed directed_paths
  • Loading branch information
this-is-sofia committed Nov 13, 2023
1 parent 704ba7c commit a77454e
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 2 deletions.
41 changes: 39 additions & 2 deletions causy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,29 @@ def add_edge(self, u: Node, v: Node, value: Dict):
self.edge_history[(u.id, v.id)] = []
self.edge_history[(v.id, u.id)] = []

def add_directed_edge(self, u: Node, v: Node, value: Dict):
"""
Add a directed edge from u to v to the graph
:param u: u node
:param v: v node
:return:
"""

if u.id not in self.nodes:
raise GraphError(f"Node {u} does not exist")
if v.id not in self.nodes:
raise GraphError(f"Node {v} does not exist")

if u.id == v.id:
raise GraphError("Self loops are currently not allowed")

if u.id not in self.edges:
self.edges[u.id] = {}

self.edges[u.id][v.id] = value

self.edge_history[(u.id, v.id)] = []

def retrieve_edge_history(
self, u, v, action: TestResultAction = None
) -> List[TestResult]:
Expand Down Expand Up @@ -334,14 +357,28 @@ def directed_paths(self, u: Node, v: Node):
:param v: node v
:return: list of directed paths
"""
# TODO: try a better data structure for this
if self.directed_edge_exists(u, v):
return [[(u, v)]]
paths = []
for w in self.edges[u.id]:
for path in self.directed_paths(self.nodes[w], v):
paths.append([(u, w)] + path)
if self.directed_edge_exists(u, self.nodes[w]):
for path in self.directed_paths(self.nodes[w], v):
paths.append([(u, self.nodes[w])] + path)
return paths

def parents_of_node(self, u: Node):
"""
Return all parents of a node u
:param u: node u
:return: list of nodes (parents)
"""
parents = []
for w in self.edges:
if self.directed_edge_exists(self.nodes[w], u):
parents.append(self.nodes[w])
return parents

def inducing_path_exists(self, u: Node, v: Node):
"""
Check if an inducing path from u to v exists.
Expand Down
45 changes: 45 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ def test_add_edge(self):
self.assertEqual(graph.edge_value(node2, node1), {"test": "test"})
self.assertTrue(graph.edge_exists(node1, node2))

def test_add_directed_edge(self):
graph = Graph()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
self.assertEqual(len(graph.nodes), 2)
self.assertEqual(len(graph.edges), 1)
self.assertEqual(graph.edge_value(node1, node2), {"test": "test"})
self.assertTrue(graph.directed_edge_exists(node1, node2))
self.assertFalse(graph.directed_edge_exists(node2, node1))

def test_add_edge_with_non_existing_node(self):
graph = Graph()
node1 = graph.add_node("test1", [1, 2, 3])
Expand Down Expand Up @@ -223,3 +234,37 @@ def test_undirected_edge_exists_with_non_existing_edge(self):

self.assertFalse(graph.undirected_edge_exists(node1, node2))
self.assertFalse(graph.undirected_edge_exists(node2, node1))

def test_parents_of_node_two_nodes(self):
graph = Graph()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
self.assertEqual(graph.parents_of_node(node2), [node1])

def test_parents_of_node_three_nodes(self):
graph = Graph()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test2", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
graph.add_directed_edge(node3, node2, {"test": "test"})
self.assertEqual(graph.parents_of_node(node2), [node1, node3])

def test_directed_paths_two_nodes(self):
graph = Graph()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
self.assertEqual(graph.directed_paths(node1, node2), [[(node1, node2)]])

def test_directed_paths_three_nodes(self):
graph = Graph()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test2", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
graph.add_directed_edge(node2, node3, {"test": "test"})
self.assertEqual(
graph.directed_paths(node1, node3), [[(node1, node2), (node2, node3)]]
)

0 comments on commit a77454e

Please sign in to comment.