Skip to content

Commit

Permalink
Merge pull request #77 from causy-dev/track-orientation-conflicts
Browse files Browse the repository at this point in the history
feat(orientation_rules): track orientation conflicts
  • Loading branch information
this-is-sofia authored Feb 1, 2025
2 parents c5e68b3 + 2997d29 commit ca02309
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 7 deletions.
42 changes: 35 additions & 7 deletions causy/causal_discovery/constraint/orientation_rules/pc.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,30 @@ def process(
unapplied_actions, y, z
)
if len(unapplied_actions_y_z) > 0 or len(unapplied_actions_x_z) > 0:
logger.warning(
f"Orientation conflict detected in ColliderTest stage when orienting the edge between {x.name} and {y.name}. The conflict is resolved using the strategy {self.conflict_resolution_strategy}, but orientation conflicts indicate assumption violations and can severely affect the accuracy of the results.",
)
if (
ColliderTestConflictResolutionStrategies.KEEP_FIRST
is self.conflict_resolution_strategy
):
# We keep the first edge that was removed
continue
# We prioritize the first edge that was removed, we do nothing
if len(unapplied_actions_y_z) > 0:
results.append(
TestResult(
u=z,
v=y,
action=TestResultAction.DO_NOTHING,
data={"orientation_conflict": True},
)
)
if len(unapplied_actions_x_z) > 0:
results.append(
TestResult(
u=z,
v=x,
action=TestResultAction.DO_NOTHING,
data={"orientation_conflict": True},
)
)

elif (
ColliderTestConflictResolutionStrategies.KEEP_LAST
is self.conflict_resolution_strategy
Expand Down Expand Up @@ -237,7 +252,12 @@ def process(
breakflag = True
break
if breakflag is True:
continue
return TestResult(
u=y,
v=z,
action=TestResultAction.DO_NOTHING,
data={"orientation_conflict": True},
)
return TestResult(
u=y,
v=z,
Expand All @@ -250,7 +270,15 @@ def process(
):
for node in graph.nodes:
if graph.only_directed_edge_exists(graph.nodes[node], x):
continue
breakflag = True
break
if breakflag is True:
return TestResult(
u=x,
v=z,
action=TestResultAction.DO_NOTHING,
data={"orientation_conflict": True},
)
return TestResult(
u=x,
v=z,
Expand Down
75 changes: 75 additions & 0 deletions tests/test_pc_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ComputeDirectEffectsInDAGsMultivariateRegression,
)
from causy.common_pipeline_steps.calculation import CalculatePearsonCorrelations
from causy.edge_types import DirectedEdge
from causy.generators import PairsWithNeighboursGenerator
from causy.graph_model import graph_model_factory
from causy.causal_discovery.constraint.independence_tests.common import (
Expand Down Expand Up @@ -426,6 +427,36 @@ def test_track_triples_three_nodes_pc_unconditionally_independent(self):
# TODO: find issue with tracking in partial correlation test in this setting
pass

def test_orientation_conflict_tracking(self):
    """Check that orientation conflicts show up in the graph's action history.

    The sample model has four variables U1..U4, each of which drives two
    of the variables X, Y, Z and V. The U columns are removed from the
    generated data before PCClassic runs, so the algorithm only sees
    X, Y, Z and V. The test passes when more than one proposed action in
    the action history carries the ``"orientation_conflict"`` key in its
    ``data``.
    """
    # NOTE(review): unlike sibling tests, no seeded ``random=`` callable is
    # passed to the generator here -- confirm the sampling is deterministic
    # enough for the assertion below to be stable.
    hidden_to_observed = [
        ("U1", "X"),
        ("U1", "Y"),
        ("U2", "Y"),
        ("U2", "Z"),
        ("U3", "Z"),
        ("U3", "V"),
        ("U4", "V"),
        ("U4", "X"),
    ]
    sample_generator = IIDSampleGenerator(
        edges=[
            SampleEdge(NodeReference(source), NodeReference(target), 1)
            for source, target in hidden_to_observed
        ],
    )
    test_data, _ = sample_generator.generate(10000)

    # Drop the U columns so that only X, Y, Z and V reach the algorithm.
    for hidden_name in ("U1", "U2", "U3", "U4"):
        test_data.pop(hidden_name)

    pc = PCClassic()
    pc.create_graph_from_data(test_data)
    pc.create_all_possible_edges()
    pc.execute_pipeline_steps()

    # Count every proposed action that was flagged as an orientation conflict.
    conflict_count = sum(
        1
        for history_entry in pc.graph.action_history
        for proposed in history_entry.all_proposed_actions
        if "orientation_conflict" in proposed.data
    )
    self.assertGreater(conflict_count, 1)

def test_d_separation_on_output_of_pc(self):
rdnv = self.seeded_random.normalvariate
sample_generator = IIDSampleGenerator(
Expand All @@ -447,3 +478,47 @@ def test_d_separation_on_output_of_pc(self):
z = tst.graph.node_by_id("Z")
self.assertEqual(tst.graph.are_nodes_d_separated_cpdag(x, z, []), False)
self.assertEqual(tst.graph.are_nodes_d_separated_cpdag(x, z, [y]), True)

def test_pc_faithfulness_violation(self):
    """Run PCClassic on a model whose direct and indirect effects cancel.

    The generating model is the chain X -> V -> W -> Y with coefficients
    2, 2 and -2 (product -8) plus a direct edge X -> Y with coefficient 8.
    The test pins which edges exist in the resulting graph: the chain
    edges X-V, V-W and W-Y are expected, while X-Y, V-Y and W-X are not.
    """
    draw_normal = self.seeded_random.normalvariate
    generator = IIDSampleGenerator(
        edges=[
            SampleEdge(NodeReference("X"), NodeReference("V"), 2),
            SampleEdge(NodeReference("V"), NodeReference("W"), 2),
            SampleEdge(NodeReference("W"), NodeReference("Y"), -2),
            SampleEdge(NodeReference("X"), NodeReference("Y"), 8),
        ],
        random=lambda: draw_normal(0, 1),
    )
    test_data, _ = generator.generate(10000)

    pc = PCClassic()
    pc.create_graph_from_data(test_data)
    pc.create_all_possible_edges()
    pc.execute_pipeline_steps()

    # (u, v, expected result of edge_exists), checked in this order.
    expected_edges = (
        ("X", "Y", False),
        ("V", "Y", False),
        ("W", "X", False),
        ("W", "Y", True),
        ("V", "W", True),
        ("X", "V", True),
    )
    for u, v, should_exist in expected_edges:
        self.assertEqual(pc.graph.edge_exists(u, v), should_exist)

def test_noncollider_triple_rule_e2e(self):
    """Run PCClassic end to end and check that all three edges are directed.

    The generating model is X -> Y <- Z together with Y -> W, each with
    coefficient 2. After the pipeline has run, the graph is expected to
    contain a directed edge for X -> Y, Z -> Y and Y -> W.
    """
    draw_normal = self.seeded_random.normalvariate
    generator = IIDSampleGenerator(
        edges=[
            SampleEdge(NodeReference("X"), NodeReference("Y"), 2),
            SampleEdge(NodeReference("Z"), NodeReference("Y"), 2),
            SampleEdge(NodeReference("Y"), NodeReference("W"), 2),
        ],
        random=lambda: draw_normal(0, 1),
    )
    test_data, _ = generator.generate(10000)

    pc = PCClassic()
    pc.create_graph_from_data(test_data)
    pc.create_all_possible_edges()
    pc.execute_pipeline_steps()

    # Each (source, target) pair must be present as a directed edge.
    for source, target in (("X", "Y"), ("Z", "Y"), ("Y", "W")):
        self.assertEqual(
            pc.graph.edge_of_type_exists(source, target, DirectedEdge()),
            True,
        )

0 comments on commit ca02309

Please sign in to comment.