Skip to content

Commit

Permalink
refactor(graph): directed_path_exists
Browse files Browse the repository at this point in the history
  • Loading branch information
this-is-sofia committed Feb 9, 2025
1 parent 5274228 commit d70de26
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 43 deletions.
2 changes: 1 addition & 1 deletion causy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def directed_path_exists(self, u: Union[Node, str], v: Union[Node, str], visited

# Recursive DFS through neighbors
for w in self.edges.get(u, []): # Use .get() to avoid KeyError if u is not in self.edges
if self.directed_path_exists(w, v, visited):
if self.edge_of_type_exists(u, w, DirectedEdge()) and self.directed_path_exists(w, v, visited):
return True

return False
Expand Down
49 changes: 10 additions & 39 deletions tests/test_effect_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
PC_ORIENTATION_RULES,
PC_EDGE_TYPES,
PC_GRAPH_UI_EXTENSION,
PC_DEFAULT_THRESHOLD,
PC_DEFAULT_THRESHOLD, PC,
)
from causy.causal_discovery.constraint.independence_tests.common import (
CorrelationCoefficientTest,
Expand Down Expand Up @@ -85,47 +85,18 @@ def test_direct_effect_estimation_trivial_case(self):
)

def test_direct_effect_estimation_basic_example(self):
PC = graph_model_factory(
Algorithm(
pipeline_steps=[
CalculatePearsonCorrelations(
display_name="Calculate Pearson Correlations"
),
CorrelationCoefficientTest(
threshold=VariableReference(name="threshold"),
display_name="Correlation Coefficient Test",
),
PartialCorrelationTest(
threshold=VariableReference(name="threshold"),
display_name="Partial Correlation Test",
),
ExtendedPartialCorrelationTestMatrix(
threshold=VariableReference(name="threshold"),
display_name="Extended Partial Correlation Test Matrix",
),
*PC_ORIENTATION_RULES,
ComputeDirectEffectsInDAGsMultivariateRegression(
display_name="Compute Direct Effects"
),
],
edge_types=PC_EDGE_TYPES,
extensions=[PC_GRAPH_UI_EXTENSION],
name="PC",
variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)],
)
)

model = IIDSampleGenerator(
edges=[
SampleEdge(NodeReference("X"), NodeReference("Z"), 5),
SampleEdge(NodeReference("Y"), NodeReference("Z"), 6),
SampleEdge(NodeReference("Z"), NodeReference("V"), 3),
SampleEdge(NodeReference("Z"), NodeReference("W"), 4),
SampleEdge(NodeReference("X"), NodeReference("Z"), 1),
SampleEdge(NodeReference("Y"), NodeReference("Z"), 1),
SampleEdge(NodeReference("Z"), NodeReference("V"), 1),
SampleEdge(NodeReference("Z"), NodeReference("W"), 1),
],
)

tst = PC()
sample_size = 1000000
sample_size = 10000
test_data, graph = model.generate(sample_size)
tst.create_graph_from_data(test_data)
tst.create_all_possible_edges()
Expand All @@ -137,30 +108,30 @@ def test_direct_effect_estimation_basic_example(self):
tst.graph.edge_value(tst.graph.nodes["X"], tst.graph.nodes["Z"])[
"direct_effect"
],
5.0,
1.0,
0,
)
self.assertAlmostEqual(
tst.graph.edge_value(tst.graph.nodes["Y"], tst.graph.nodes["Z"])[
"direct_effect"
],
6.0,
1.0,
0,
)

self.assertAlmostEqual(
tst.graph.edge_value(tst.graph.nodes["Z"], tst.graph.nodes["V"])[
"direct_effect"
],
3.0,
1.0,
0,
)

self.assertAlmostEqual(
tst.graph.edge_value(tst.graph.nodes["Z"], tst.graph.nodes["W"])[
"direct_effect"
],
4.0,
1.0,
0,
)

Expand Down
88 changes: 85 additions & 3 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,8 +697,12 @@ def test_directed_path_exists_cycle(self):
graph.add_directed_edge(node3, node1, {"test": "test"})
self.assertTrue(graph.directed_path_exists(node1, node2))
self.assertTrue(graph.directed_path_exists(node2, node1))
self.assertTrue(graph.directed_path_exists(node1, node3))
self.assertTrue(graph.directed_path_exists(node3, node1))
self.assertTrue(graph.directed_path_exists(node2, node3))
self.assertTrue(graph.directed_path_exists(node3, node2))

def test_directed_path_mediated_path(self):
def test_directed_path_exists_mediated_path(self):
graph = GraphManager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
Expand All @@ -716,7 +720,7 @@ def test_directed_path_mediated_path(self):
self.assertTrue(graph.directed_path_exists(node2, node3))
self.assertFalse(graph.directed_path_exists(node3, node2))

def test_directed_path_mediated_path_several_mediated_paths(self):
def test_directed_path_exists_mediated_path_several_mediated_paths(self):
graph = GraphManager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
Expand All @@ -742,7 +746,7 @@ def test_directed_path_mediated_path_several_mediated_paths(self):
self.assertTrue(graph.directed_path_exists(node1, node4))
self.assertFalse(graph.directed_path_exists(node4, node1))

def test_directed_path_mediated_path_undirected_edges(self):
def test_directed_path_exists_mediated_path_undirected_edges(self):
graph = GraphManager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
Expand All @@ -759,3 +763,81 @@ def test_directed_path_mediated_path_undirected_edges(self):
self.assertFalse(graph.directed_path_exists(node4, node2))
self.assertFalse(graph.directed_path_exists(node2, node3))
self.assertFalse(graph.directed_path_exists(node3, node2))

def test_directed_path_exists_mediated_path_undirected_and_directed_edges(self):
graph = GraphManager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test3", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
graph.add_edge(node2, node3, {"test": "test"})

self.assertFalse(graph.directed_path_exists(node1, node3))

def test_directed_path_exists_mediated_path_undirected_and_directed_edges_2(self):
graph = GraphManager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test3", [1, 2, 3])
graph.add_edge(node1, node2, {"test": "test"})
graph.add_directed_edge(node2, node3, {"test": "test"})

self.assertFalse(graph.directed_path_exists(node1, node3))

def test_directed_path_exists_mediated_path_undirected_and_directed_edges_3(self):
graph = GraphManager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test3", [1, 2, 3])
node4 = graph.add_node("test4", [1, 2, 3])
node5 = graph.add_node("test5", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
graph.add_edge(node2, node3, {"test": "test"})
graph.add_directed_edge(node3, node4, {"test": "test"})
graph.add_directed_edge(node4, node5, {"test": "test"})

self.assertFalse(graph.directed_path_exists(node1, node5))
self.assertFalse(graph.directed_path_exists(node1, node4))
self.assertFalse(graph.directed_path_exists(node1, node3))
self.assertTrue(graph.directed_path_exists(node1, node2))
self.assertTrue(graph.directed_path_exists(node3, node5))

def test_directed_path_exists_mediated_path_undirected_and_directed_edges_4(self):
graph = GraphManager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test3", [1, 2, 3])
node4 = graph.add_node("test4", [1, 2, 3])
node5 = graph.add_node("test5", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
graph.add_directed_edge(node2, node3, {"test": "test"})
graph.add_edge(node3, node4, {"test": "test"})
graph.add_directed_edge(node4, node5, {"test": "test"})

self.assertFalse(graph.directed_path_exists(node1, node5))
self.assertFalse(graph.directed_path_exists(node1, node4))
self.assertFalse(graph.directed_path_exists(node2, node5))
self.assertTrue(graph.directed_path_exists(node1, node2))
self.assertTrue(graph.directed_path_exists(node1, node3))

def test_directed_path_exists_mediated_paths_undirected_and_directed_edges(self):
graph = GraphManager()
X = graph.add_node("X", [1, 2, 3])
Y = graph.add_node("Y", [1, 2, 3])
Z = graph.add_node("Z", [1, 2, 3])
W = graph.add_node("W", [1, 2, 3])
V = graph.add_node("V", [1, 2, 3])

graph.add_directed_edge(X, Y, {"test": "test"})
graph.add_edge(Y, Z, {"test": "test"})
graph.add_edge(Z, W, {"test": "test"})
graph.add_directed_edge(W, X, {"test": "test"})
graph.add_directed_edge(V, X, {"test": "test"})

self.assertTrue(graph.directed_path_exists(V, X))
self.assertTrue(graph.directed_path_exists(V, Y))
self.assertFalse(graph.directed_path_exists(V, Z))
self.assertFalse(graph.directed_path_exists(Z, X))



72 changes: 72 additions & 0 deletions tests/test_orientation_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,3 +788,75 @@ def test_further_orient_quadruple_test(self):
self.assertFalse(model.graph.undirected_edge_exists(x, z))
self.assertFalse(model.graph.only_directed_edge_exists(z, x))
self.assertTrue(model.graph.only_directed_edge_exists(x, z))

def test_avoid_cycles_four_nodes(self):
pipeline = [Loop(
pipeline_steps=[
NonColliderTest(),
FurtherOrientTripleTest(),
OrientQuadrupleTest(),
FurtherOrientQuadrupleTest(),
],
exit_condition=ExitOnNoActions(),
),]
model = graph_model_factory(
Algorithm(
pipeline_steps=pipeline,
edge_types=[DirectedEdge(), UndirectedEdge()],
name="TestNonColliderAvoidCycles",
)
)()
model.graph = GraphManager()
x = model.graph.add_node("X", [])
y = model.graph.add_node("Y", [])
z = model.graph.add_node("Z", [])
w = model.graph.add_node("W", [])
model.graph.add_edge(x, y, {})
model.graph.add_edge(y, z, {})
model.graph.add_directed_edge(z, x, {})
model.graph.add_directed_edge(w, x, {})
model.execute_pipeline_steps()
# sanity check
self.assertTrue(model.graph.edge_of_type_exists(z, x, DirectedEdge()))
self.assertTrue(model.graph.edge_of_type_exists(w, x, DirectedEdge()))

self.assertTrue(model.graph.edge_of_type_exists(x, y, DirectedEdge()))
self.assertTrue(model.graph.edge_of_type_exists(z, y, DirectedEdge()))


def test_avoid_cycles_five_nodes(self):
pipeline = [Loop(
pipeline_steps=[
NonColliderTest(),
FurtherOrientTripleTest(),
OrientQuadrupleTest(),
FurtherOrientQuadrupleTest(),
],
exit_condition=ExitOnNoActions(),
),]
model = graph_model_factory(
Algorithm(
pipeline_steps=pipeline,
edge_types=[DirectedEdge(), UndirectedEdge()],
name="TestNonColliderAvoidCycles",
)
)()
model.graph = GraphManager()
x = model.graph.add_node("X", [])
y = model.graph.add_node("Y", [])
z = model.graph.add_node("Z", [])
w = model.graph.add_node("W", [])
v = model.graph.add_node("V", [])
model.graph.add_edge(x, y, {})
model.graph.add_edge(y, z, {})
model.graph.add_edge(z, w, {})
model.graph.add_directed_edge(w, x, {})
model.graph.add_directed_edge(v, x, {})
model.execute_pipeline_steps()
# sanity check
self.assertTrue(model.graph.edge_of_type_exists(w, x, DirectedEdge()))
self.assertTrue(model.graph.edge_of_type_exists(v, x, DirectedEdge()))

self.assertTrue(model.graph.edge_of_type_exists(x, y, DirectedEdge()))
self.assertTrue(model.graph.edge_of_type_exists(y, z, DirectedEdge()))
self.assertFalse(model.graph.edge_of_type_exists(z, w, DirectedEdge()))

0 comments on commit d70de26

Please sign in to comment.