From 74a922c210014e98a4ccef333df16a2ed4bf9ed4 Mon Sep 17 00:00:00 2001 From: Lilith Wittmann Date: Wed, 5 Feb 2025 22:45:24 +0100 Subject: [PATCH] feat(core): add a step execution mode "apply_synchronous" which synchronously applies step results before the next step is processed - All pipeline steps which are not executed in parallel can receive apply_synchronous=True in their pipeline step definition, which ensures that the results of each test in this step are applied to the graph before the next test is executed. - This is disabled by default - an example can be found in tests/test_synchronous_graphs.py --- causy/graph_model.py | 28 +++++++++-- causy/interfaces.py | 21 +++++++-- tests/test_synchronous_graphs.py | 80 ++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 10 deletions(-) create mode 100644 tests/test_synchronous_graphs.py diff --git a/causy/graph_model.py b/causy/graph_model.py index 074ef81..2f7222a 100644 --- a/causy/graph_model.py +++ b/causy/graph_model.py @@ -484,6 +484,14 @@ def execute_pipeline_step( else: # this is the only mode which supports unapplied actions to be passed to the next pipeline step (for now) # which are sometimes needed for e.g. 
conflict resolution + + is_synchronous = False + + if hasattr(test_fn, "apply_synchronous"): + # ensure that the graph gets changes applied synchronously - so before the next element is executed + if test_fn.apply_synchronous: + is_synchronous = True + iterator = [ i for i in [ @@ -499,11 +507,21 @@ def execute_pipeline_step( if rn_fn.needs_unapplied_actions: i.append(local_results) local_results.append(unpack_run(i)) - actions_taken_current, all_actions_current = self._take_action( - local_results, dry_run=not apply_to_graph - ) - actions_taken.extend(actions_taken_current) - all_actions.extend(all_actions_current) + + if is_synchronous: + actions_taken_current, all_actions_current = self._take_action( + local_results, dry_run=not apply_to_graph + ) + actions_taken.extend(actions_taken_current) + all_actions.extend(all_actions_current) + local_results = [] + + if not is_synchronous: + actions_taken_current, all_actions_current = self._take_action( + local_results, dry_run=not apply_to_graph + ) + actions_taken.extend(actions_taken_current) + all_actions.extend(all_actions_current) return actions_taken, all_actions diff --git a/causy/interfaces.py b/causy/interfaces.py index 2d81c4f..0fbf88f 100644 --- a/causy/interfaces.py +++ b/causy/interfaces.py @@ -343,19 +343,27 @@ def name(self) -> str: class PipelineStepInterface(ABC, BaseModel, Generic[PipelineStepInterfaceType]): generator: Optional[GeneratorInterface] = None - threshold: Optional[FloatParameter] = DEFAULT_THRESHOLD - chunk_size_parallel_processing: IntegerParameter = 1 - parallel: BoolParameter = True + threshold: Optional[FloatParameter] = DEFAULT_THRESHOLD # threshold for the test + chunk_size_parallel_processing: IntegerParameter = ( + 1 # chunk size for parallel processing + ) + parallel: BoolParameter = True # if True, the pipeline step will be executed in parallel (only works non synchronous) - display_name: Optional[StringParameter] = None + display_name: Optional[StringParameter] = None # display 
name of the pipeline step - needs_unapplied_actions: Optional[BoolParameter] = False + needs_unapplied_actions: Optional[ + BoolParameter + ] = False # if True, the pipeline step needs unapplied actions to be passed to it + apply_synchronous: Optional[ + BoolParameter + ] = False # if True, the result of the pipeline step will be applied synchronously (only works non chunked and non parallel) def __init__( self, threshold: Optional[FloatParameter] = None, generator: Optional[GeneratorInterface] = None, chunk_size_parallel_processing: Optional[IntegerParameter] = None, + apply_synchronous: Optional[BoolParameter] = None, parallel: Optional[BoolParameter] = None, display_name: Optional[StringParameter] = None, **kwargs, @@ -370,6 +378,9 @@ def __init__( if chunk_size_parallel_processing: self.chunk_size_parallel_processing = chunk_size_parallel_processing + if apply_synchronous: + self.apply_synchronous = apply_synchronous + if parallel: self.parallel = parallel diff --git a/tests/test_synchronous_graphs.py b/tests/test_synchronous_graphs.py new file mode 100644 index 0000000..c52c211 --- /dev/null +++ b/tests/test_synchronous_graphs.py @@ -0,0 +1,80 @@ +from causy.causal_discovery.constraint.algorithms.pc import ( + PC_ORIENTATION_RULES, + PC_EDGE_TYPES, + PC_GRAPH_UI_EXTENSION, + PC_DEFAULT_THRESHOLD, +) +from causy.causal_discovery.constraint.independence_tests.common import ( + CorrelationCoefficientTest, + PartialCorrelationTest, + ExtendedPartialCorrelationTestMatrix, +) +from causy.causal_effect_estimation.multivariate_regression import ( + ComputeDirectEffectsInDAGsMultivariateRegression, +) +from causy.common_pipeline_steps.calculation import CalculatePearsonCorrelations +from causy.graph_model import graph_model_factory +from causy.models import Algorithm +from causy.sample_generator import IIDSampleGenerator, SampleEdge, NodeReference +from causy.variables import VariableReference, FloatVariable +from tests.utils import CausyTestCase + + +class 
PCTestTestCase(CausyTestCase): + SEED = 1 + + def _sample_generator(self): + rdnv = self.seeded_random.normalvariate + return IIDSampleGenerator( + edges=[ + SampleEdge(NodeReference("X"), NodeReference("Y"), 5), + SampleEdge(NodeReference("X"), NodeReference("Z"), 8), + SampleEdge(NodeReference("X"), NodeReference("W"), 4), + ], + random=lambda: rdnv(0, 1), + ) + + SYNCHRONOUS_PC = graph_model_factory( + Algorithm( + pipeline_steps=[ + CalculatePearsonCorrelations( + display_name="Calculate Pearson Correlations" + ), + CorrelationCoefficientTest( + threshold=VariableReference(name="threshold"), + display_name="Correlation Coefficient Test", + apply_synchronous=True, + ), + PartialCorrelationTest( + threshold=VariableReference(name="threshold"), + display_name="Partial Correlation Test", + apply_synchronous=True, + ), + ExtendedPartialCorrelationTestMatrix( + threshold=VariableReference(name="threshold"), + display_name="Extended Partial Correlation Test Matrix", + apply_synchronous=True, + ), + *PC_ORIENTATION_RULES, + ComputeDirectEffectsInDAGsMultivariateRegression( + display_name="Compute Direct Effects in DAGs Multivariate Regression" + ), + ], + edge_types=PC_EDGE_TYPES, + extensions=[PC_GRAPH_UI_EXTENSION], + name="PC", + variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)], + ) + ) + + def test_execute_pipeline(self): + model = self._sample_generator() + data, graph = model.generate(100) + + pc = self.SYNCHRONOUS_PC() + pc.create_graph_from_data(data) + pc.create_graph_from_data(data) + pc.create_all_possible_edges() + pc.execute_pipeline_steps() + + self.assertGraphStructureIsEqual(pc.graph, graph)