Skip to content

Commit a2a8396

Browse files
committed
minor fixes in custom tasks
1 parent 2410e46 commit a2a8396

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

hbw/tasks/corrections.py

+25-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from columnflow.tasks.framework.mixins import (
1111
SelectorMixin, CalibratorsMixin,
1212
)
13+
from columnflow.tasks.framework.remote import RemoteWorkflow
1314
from columnflow.tasks.selection import MergeSelectionStats
1415
from columnflow.util import maybe_import, dev_sandbox
1516
from columnflow.config_util import get_datasets_from_process
@@ -25,12 +26,16 @@
2526

2627
class GetBtagNormalizationSF(
2728
HBWTask,
28-
ConfigTask,
2929
SelectorMixin,
3030
CalibratorsMixin,
31-
# law.LocalWorkflow,
31+
ConfigTask,
32+
law.LocalWorkflow,
33+
RemoteWorkflow,
3234
):
33-
reqs = Requirements(MergeSelectionStats=MergeSelectionStats)
35+
reqs = Requirements(
36+
RemoteWorkflow.reqs,
37+
MergeSelectionStats=MergeSelectionStats,
38+
)
3439

3540
store_as_dict = False
3641

@@ -47,6 +52,10 @@ class GetBtagNormalizationSF(
4752
description="Processes to consider for the scale factors",
4853
)
4954

55+
def create_branch_map(self):
56+
# single branch without payload
57+
return {0: None}
58+
5059
@cached_property
5160
def process_insts(self):
5261
processes = [self.config_inst.get_process(process) for process in self.processes]
@@ -59,10 +68,22 @@ def dataset_insts(self):
5968
datasets.update(get_datasets_from_process(self.config_inst, process_inst))
6069
return list(datasets)
6170

71+
def workflow_requires(self):
72+
reqs = super().workflow_requires()
73+
reqs["selection_stats"] = {
74+
dataset.name: self.reqs.MergeSelectionStats.req_different_branching(
75+
self,
76+
dataset=dataset.name,
77+
branch=-1,
78+
)
79+
for dataset in self.dataset_insts
80+
}
81+
return reqs
82+
6283
def requires(self):
6384
reqs = {}
6485
reqs["selection_stats"] = {
65-
dataset.name: self.reqs.MergeSelectionStats.req(
86+
dataset.name: self.reqs.MergeSelectionStats.req_different_branching(
6687
self,
6788
dataset=dataset.name,
6889
branch=-1,

hbw/tasks/plotting.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import law
1111
import order as od
1212

13-
from columnflow.tasks.framework.base import Requirements, ShiftTask
13+
from columnflow.tasks.framework.base import Requirements, ShiftTask, ConfigTask
1414
from columnflow.tasks.framework.mixins import (
1515
CalibratorsMixin, SelectorStepsMixin, ProducersMixin, MLModelsMixin,
1616
CategoriesMixin, DatasetsProcessesMixin,
@@ -120,6 +120,7 @@ class PlotVariablesMultiWeightProducer(
120120
SelectorStepsMixin,
121121
CalibratorsMixin,
122122
ShiftTask,
123+
ConfigTask,
123124
law.LocalWorkflow,
124125
RemoteWorkflow,
125126
):

0 commit comments

Comments
 (0)