diff --git a/causy/causal_discovery/constraint/algorithms/pc.py b/causy/causal_discovery/constraint/algorithms/pc.py index 97449a4..83075b8 100644 --- a/causy/causal_discovery/constraint/algorithms/pc.py +++ b/causy/causal_discovery/constraint/algorithms/pc.py @@ -88,6 +88,38 @@ ) ) +PCClassic = graph_model_factory( + Algorithm( + pipeline_steps=[ + CalculatePearsonCorrelations( + display_name="Calculate Pearson Correlations" + ), + CorrelationCoefficientTest( + threshold=VariableReference(name="threshold"), + display_name="Correlation Coefficient Test", + ), + ExtendedPartialCorrelationTestMatrix( + threshold=VariableReference(name="threshold"), + display_name="Extended Partial Correlation Test Matrix", + generator=PairsWithNeighboursGenerator( + comparison_settings=ComparisonSettings( + min=3, max=AS_MANY_AS_FIELDS + ), + shuffle_combinations=False, + ), + ), + *PC_ORIENTATION_RULES, + ComputeDirectEffectsInDAGsMultivariateRegression( + display_name="Compute Direct Effects in DAGs Multivariate Regression" + ), + ], + edge_types=PC_EDGE_TYPES, + extensions=[PC_GRAPH_UI_EXTENSION], + name="PC", + variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)], + ) +) + PCStable = graph_model_factory( Algorithm( pipeline_steps=[ diff --git a/tests/test_pc_e2e.py b/tests/test_pc_e2e.py index 97f72a2..009890b 100644 --- a/tests/test_pc_e2e.py +++ b/tests/test_pc_e2e.py @@ -3,7 +3,7 @@ PC, PC_ORIENTATION_RULES, PC_GRAPH_UI_EXTENSION, - PC_DEFAULT_THRESHOLD, + PC_DEFAULT_THRESHOLD, PCClassic, ) from causy.causal_effect_estimation.multivariate_regression import ( ComputeDirectEffectsInDAGsMultivariateRegression, @@ -351,37 +351,6 @@ def test_tracking_triples_four_nodes(self): self.assertEqual(len(triples), 6 + 12 + 12) def test_track_triples_three_nodes_custom_pc(self): - algo = graph_model_factory( - Algorithm( - pipeline_steps=[ - CalculatePearsonCorrelations( - display_name="Calculate Pearson Correlations" - ), - CorrelationCoefficientTest( - threshold=VariableReference(name="threshold"), - display_name="Correlation Coefficient Test", - ), - ExtendedPartialCorrelationTestMatrix( - threshold=VariableReference(name="threshold"), - display_name="Extended Partial Correlation Test Matrix", - generator=PairsWithNeighboursGenerator( - comparison_settings=ComparisonSettings( - min=3, max=AS_MANY_AS_FIELDS - ), - shuffle_combinations=False, - ), - ), - *PC_ORIENTATION_RULES, - ComputeDirectEffectsInDAGsMultivariateRegression( - display_name="Compute Direct Effects in DAGs Multivariate Regression" - ), - ], - edge_types=PC_EDGE_TYPES, - extensions=[PC_GRAPH_UI_EXTENSION], - name="PC", - variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)], - ) - ) rdnv = self.seeded_random.normalvariate sample_generator = IIDSampleGenerator( edges=[ @@ -391,7 +360,7 @@ def test_track_triples_three_nodes_custom_pc(self): random=lambda: rdnv(0, 1), ) test_data, graph = sample_generator.generate(10000) - tst = algo() + tst = PCClassic() tst.create_graph_from_data(test_data) tst.create_all_possible_edges() pc_results = tst.execute_pipeline_steps() @@ -407,37 +376,6 @@ def test_track_triples_three_nodes_custom_pc(self): self.assertIn(len(triples), [6, 7, 8]) def test_track_triples_two_nodes_custom_pc_unconditionally_independent(self): - algo = graph_model_factory( - Algorithm( - pipeline_steps=[ - CalculatePearsonCorrelations( - display_name="Calculate Pearson Correlations" - ), - CorrelationCoefficientTest( - threshold=VariableReference(name="threshold"), - display_name="Correlation Coefficient Test", - ), - ExtendedPartialCorrelationTestMatrix( - threshold=VariableReference(name="threshold"), - display_name="Extended Partial Correlation Test Matrix", - generator=PairsWithNeighboursGenerator( - comparison_settings=ComparisonSettings( - min=3, max=AS_MANY_AS_FIELDS - ), - shuffle_combinations=False, - ), - ), - *PC_ORIENTATION_RULES, - ComputeDirectEffectsInDAGsMultivariateRegression( - display_name="Compute Direct Effects in DAGs Multivariate Regression" - ), - ], - edge_types=PC_EDGE_TYPES, - extensions=[PC_GRAPH_UI_EXTENSION], - name="PC", - variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)], - ) - ) rdnv = self.seeded_random.normalvariate sample_generator = IIDSampleGenerator( edges=[ @@ -447,7 +385,7 @@ def test_track_triples_two_nodes_custom_pc_unconditionally_independent(self): random=lambda: rdnv(0, 1), ) test_data, graph = sample_generator.generate(10000) - tst = algo() + tst = PCClassic() tst.create_graph_from_data(test_data) tst.create_all_possible_edges() pc_results = tst.execute_pipeline_steps()