Skip to content

Commit f6f3580

Browse files
committed
update cf (refactor taf_init) and first set of changes
1 parent 95c8198 commit f6f3580

31 files changed

+162
-127
lines changed

hbw/analysis/create_analysis.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def analysis_factory(configs: od.UniqueObjectIndex):
191191
"analysis", "task_family", "config", "configs", "dataset", "shift", "version",
192192
"calibrator", "calibrators", "selector", "producer", "producers",
193193
"ml_model", "ml_data", "ml_models",
194-
"weightprod", "inf_model",
194+
"weight_producer", "inf_model",
195195
"plot", "shift_sources", "shifts", "datasets",
196196
# MLTraining
197197
"calib", "sel", "prod",
@@ -249,7 +249,7 @@ def reorganize_parts(task, store_parts):
249249
"config", "configs",
250250
"producers", "prod",
251251
"ml_data", "ml_model", "ml_models",
252-
"weightprod", "inf_model",
252+
"weight_producer", "inf_model",
253253
"task_family",
254254
"calibrator", "producer",
255255
"shift", "dataset",

hbw/calibration/default.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ def ele_init(self: Calibrator) -> None:
7171
# TODO: deterministic FatJet seeds
7272
produces={"FatJet.pt"},
7373
)
74-
def fatjet(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array:
74+
def fatjet(self: Calibrator, events: ak.Array, task, **kwargs) -> ak.Array:
7575
"""
7676
FatJet calibrator, combining JEC and JER.
7777
Uses as JER uncertainty either only "Total" for MC or no uncertainty for data.
7878
"""
79-
if self.task.local_shift != "nominal":
79+
if task.local_shift != "nominal":
8080
raise Exception("FatJet Calibrator should not be run for shifts other than nominal")
8181

8282
# apply the fatjet JEC and JER
@@ -89,10 +89,6 @@ def fatjet(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array:
8989

9090
@fatjet.init
9191
def fatjet_init(self: Calibrator) -> None:
92-
if not self.task or self.task.task_family != "cf.CalibrateEvents":
93-
# init only required for task itself
94-
return
95-
9692
# derive calibrators to add settings once
9793
flag = f"custom_fatjet_calibs_registered_{self.cls_name}"
9894
if not self.config_inst.x(flag, False):
@@ -185,9 +181,6 @@ def jet_base(self: Calibrator, events: ak.Array, **kwargs) -> ak.Array:
185181

186182
@jet_base.init
187183
def jet_base_init(self: Calibrator) -> None:
188-
if not self.task or self.task.task_family != "cf.CalibrateEvents":
189-
# init only required for task itself
190-
return
191184

192185
# derive calibrators to add settings once
193186
flag = f"custom_jet_calibs_registered_{self.cls_name}"

hbw/calibration/jet.py

+19-20
Original file line numberDiff line numberDiff line change
@@ -88,23 +88,22 @@ def bjet_regression(
8888
@bjet_regression.init
8989
def bjet_regression_init(self: Calibrator):
9090
# setup only required by CalibrateEvents task itself
91-
if self.task and self.task.task_family == "cf.CalibrateEvents":
92-
self.b_tagger = {
93-
2: "deepjet",
94-
3: "particlenet",
95-
}[self.config_inst.x.run]
96-
97-
self.b_score_column = {
98-
"particlenet": "btagPNetB",
99-
"deepjet": "btagDeepFlavB",
100-
}[self.b_tagger]
101-
102-
self.b_reg_column = {
103-
"particlenet": "PNetRegPtRawCorr",
104-
"deepjet": "bRegCorr",
105-
}[self.b_tagger]
106-
107-
self.btag_wp = self.config_inst.x("btag_wp", "medium")
108-
109-
self.uses.add(f"Jet.{self.b_score_column}")
110-
self.uses.add(f"Jet.{self.b_reg_column}")
91+
self.b_tagger = {
92+
2: "deepjet",
93+
3: "particlenet",
94+
}[self.config_inst.x.run]
95+
96+
self.b_score_column = {
97+
"particlenet": "btagPNetB",
98+
"deepjet": "btagDeepFlavB",
99+
}[self.b_tagger]
100+
101+
self.b_reg_column = {
102+
"particlenet": "PNetRegPtRawCorr",
103+
"deepjet": "bRegCorr",
104+
}[self.b_tagger]
105+
106+
self.btag_wp = self.config_inst.x("btag_wp", "medium")
107+
108+
self.uses.add(f"Jet.{self.b_score_column}")
109+
self.uses.add(f"Jet.{self.b_reg_column}")

hbw/categorization/categories.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def catid_lep(
112112
muon = events.Muon
113113

114114
mask = (
115-
(ak.sum(electron.pt > 0, axis=-1) == self.n_electron) &
116-
(ak.sum(muon.pt > 0, axis=-1) == self.n_muon)
115+
(ak.sum(electron["pt"] > 0, axis=-1) == self.n_electron) &
116+
(ak.sum(muon["pt"] > 0, axis=-1) == self.n_muon)
117117
)
118118
return events, mask
119119

@@ -138,7 +138,7 @@ def catid_ge3lep(
138138
electron = events.Electron
139139
muon = events.Muon
140140

141-
mask = ak.sum(electron.pt > 0, axis=-1) + ak.sum(muon.pt > 0, axis=-1) >= 3
141+
mask = ak.sum(electron["pt"] > 0, axis=-1) + ak.sum(muon["pt"] > 0, axis=-1) >= 3
142142
return events, mask
143143

144144

hbw/columnflow_patches.py

+9
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,22 @@ def patch_csp_versioning():
8484
Patches the TaskArrayFunction to add the version to the string representation of the task.
8585
"""
8686

87+
from columnflow.tasks.framework.mixins import ArrayFunctionClassMixin
88+
8789
def TaskArrayFunction_str(self):
8890
version = self.version() if callable(getattr(self, "version", None)) else getattr(self, "version", None)
8991
if version and not isinstance(version, (int, str)):
9092
raise Exception(f"version must be an integer or string, but is {version} ({type(version)})")
9193
version_str = f"V{version}" if version is not None else ""
9294
return f"{self.cls_name}{version_str}"
9395

96+
def array_function_cls_repr(self, array_function):
97+
# NOTE: this might be a problem when we have identical names between different types of
98+
# TaskArrayFunctions...
99+
array_function_cls = TaskArrayFunction.get_cls(array_function)
100+
return TaskArrayFunction_str(array_function_cls)
101+
102+
ArrayFunctionClassMixin.array_function_cls_repr = array_function_cls_repr
94103
TaskArrayFunction.__str__ = TaskArrayFunction_str
95104
logger.info(
96105
"patched TaskArrayFunction.__str__ to include the CSP version attribute",

hbw/config/defaults_and_groups.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ def default_ml_model(cls, container, task_params):
5151
if inference_model in (None, law.NO_STR, RESOLVE_DEFAULT):
5252
inference_model = container.x.default_inference_model
5353

54-
# get the default_ml_model from the inference_model_inst
55-
inference_model_inst = InferenceModel.get_cls(inference_model)
56-
default_ml_model = getattr(inference_model_inst, "ml_model_name", default_ml_model)
54+
# get the default_ml_model from the inference_model_cls
55+
inference_model_cls = InferenceModel.get_cls(inference_model)
56+
default_ml_model = getattr(inference_model_cls, "ml_model_name", default_ml_model)
5757

5858
return default_ml_model
5959

hbw/config/styling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def stylize_processes(config: od.Config) -> None:
205205
if "hh_" in proc.name.lower():
206206
proc.add_tag("is_signal")
207207
proc.unstack = True
208-
proc.scale = "stack"
208+
# proc.scale = "stack"
209209

210210
# labels used for ML categories
211211
proc.x.ml_label = ml_labels.get(proc.name, proc.name)

hbw/config/variables.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -113,21 +113,21 @@ def add_variables(config: od.Config) -> None:
113113
expression="event",
114114
binning=(1, 0.0, 1.0e9),
115115
x_title="Event number",
116-
discrete_x=True,
116+
# discrete_x=True,
117117
)
118118
config.add_variable(
119119
name="run",
120120
expression="run",
121121
binning=(1, 100000.0, 500000.0),
122122
x_title="Run number",
123-
discrete_x=True,
123+
# discrete_x=True,
124124
)
125125
config.add_variable(
126126
name="lumi",
127127
expression="luminosityBlock",
128128
binning=(1, 0.0, 5000.0),
129129
x_title="Luminosity block",
130-
discrete_x=True,
130+
# discrete_x=True,
131131
)
132132

133133
#

hbw/ml/stats.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def del_sub_proc_stats(
5050
def ml_preparation(
5151
self: Producer,
5252
events: ak.Array,
53+
task,
5354
stats: dict = {},
5455
fold_indices: ak.Array | None = None,
5556
ml_model_inst: MLModel | None = None,
@@ -58,7 +59,7 @@ def ml_preparation(
5859
"""
5960
Producer that is run as part of PrepareMLEvents to collect relevant stats
6061
"""
61-
if self.task.task_family == "cf.PrepareMLEvents":
62+
if task.task_family == "cf.PrepareMLEvents":
6263
# pass category mask to only use events that belong to the main "signal region"
6364
# NOTE: we could also just require the pre_ml_cats Producer here
6465
sr_categorizer = catid_sr if self.config_inst.has_tag("is_sl") else catid_mll_low
@@ -70,7 +71,7 @@ def ml_preparation(
7071
"num_events": Ellipsis, # all events
7172
}
7273

73-
if self.task.dataset_inst.is_mc:
74+
if task.dataset_inst.is_mc:
7475
# full event weight
7576
events, weight = self[default_weight_producer](events, **kwargs)
7677
events = set_ak_column_f32(events, "event_weight", weight)
@@ -121,8 +122,6 @@ def ml_preparation(
121122

122123
@ml_preparation.init
123124
def ml_preparation_init(self):
124-
# TODO: we access self.task.dataset_inst instead of self.dataset_inst due to an issue
125-
# with the preparation producer initialization
126125
if not getattr(self, "dataset_inst", None) or self.dataset_inst.is_data:
127126
return
128127

hbw/production/categories.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import law
88

99
from columnflow.production import Producer, producer
10-
from columnflow.util import maybe_import, InsertableDict
10+
from columnflow.util import maybe_import
1111
from columnflow.production.categories import category_ids
1212

1313
from hbw.config.categories import add_categories_production, add_categories_ml
@@ -58,16 +58,18 @@ def cats_ml(self: Producer, events: ak.Array, **kwargs) -> ak.Array:
5858

5959

6060
@cats_ml.requires
61-
def cats_ml_reqs(self: Producer, reqs: dict) -> None:
61+
def cats_ml_reqs(self: Producer, task: law.Task, reqs: dict) -> None:
6262
if "ml" in reqs:
6363
return
6464

6565
from columnflow.tasks.ml import MLEvaluation
66-
reqs["ml"] = MLEvaluation.req(self.task, ml_model=self.ml_model_name)
66+
reqs["ml"] = MLEvaluation.req(task, ml_model=self.ml_model_name)
6767

6868

6969
@cats_ml.setup
70-
def cats_ml_setup(self: Producer, reqs: dict, inputs: dict, reader_targets: InsertableDict) -> None:
70+
def cats_ml_setup(
71+
self: Producer, task: law.Task, reqs: dict, inputs: dict, reader_targets: law.util.InsertableDict,
72+
) -> None:
7173
reader_targets["mlcolumns"] = inputs["ml"]["mlcolumns"]
7274

7375

hbw/production/dataset_normalization.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
"""
44
Column production methods related to sample normalization event weights.
55
"""
6+
import law
67

78
from columnflow.production import Producer, producer
8-
from columnflow.util import maybe_import, InsertableDict
9+
from columnflow.util import maybe_import
910
from columnflow.columnar_util import set_ak_column
1011

1112
np = maybe_import("numpy")
@@ -39,7 +40,7 @@ def dataset_normalization_weight(self: Producer, events: ak.Array, **kwargs) ->
3940

4041

4142
@dataset_normalization_weight.requires
42-
def dataset_normalization_weight_requires(self: Producer, reqs: dict) -> None:
43+
def dataset_normalization_weight_requires(self: Producer, task: law.Task, reqs: dict) -> None:
4344
"""
4445
Adds the requirements needed by the underlying py:attr:`task` to access selection stats into
4546
*reqs*.
@@ -49,17 +50,18 @@ def dataset_normalization_weight_requires(self: Producer, reqs: dict) -> None:
4950
# (i.e. all datasets that might contain any of the sub processes found in a dataset)
5051
from columnflow.tasks.selection import MergeSelectionStats
5152
reqs["selection_stats"] = MergeSelectionStats.req(
52-
self.task,
53+
task,
5354
branch=-1,
5455
)
5556

5657

5758
@dataset_normalization_weight.setup
5859
def dataset_normalization_weight_setup(
5960
self: Producer,
61+
task: law.Task,
6062
reqs: dict,
6163
inputs: dict,
62-
reader_targets: InsertableDict,
64+
reader_targets: law.util.InsertableDict,
6365
) -> None:
6466
"""
6567
Load inclusive selection stats and cross sections for the normalization weight calculation.

hbw/production/gen_v.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
from __future__ import annotations
88

9+
import law
10+
911
from columnflow.production import Producer, producer
10-
from columnflow.util import maybe_import, InsertableDict, DotDict
12+
from columnflow.util import maybe_import, DotDict
1113
from columnflow.columnar_util import set_ak_column
1214

1315

@@ -206,24 +208,19 @@ def vjets_weight_skip(self: Producer) -> bool:
206208
)
207209

208210

209-
@vjets_weight.init
210-
def vjets_weight_init(self: Producer) -> None:
211-
shift_inst = getattr(self, "local_shift_inst", None)
212-
if not shift_inst:
213-
return
214-
215-
216211
@vjets_weight.requires
217-
def vjets_weight_requires(self: Producer, reqs: dict) -> None:
212+
def vjets_weight_requires(self: Producer, task: law.Task, reqs: dict) -> None:
218213
if "external_files" in reqs:
219214
return
220215

221216
from columnflow.tasks.external import BundleExternalFiles
222-
reqs["external_files"] = BundleExternalFiles.req(self.task)
217+
reqs["external_files"] = BundleExternalFiles.req(task)
223218

224219

225220
@vjets_weight.setup
226-
def vjets_weight_setup(self: Producer, reqs: dict, inputs: dict, reader_targets: InsertableDict) -> None:
221+
def vjets_weight_setup(
222+
self: Producer, task: law.Task, reqs: dict, inputs: dict, reader_targets: law.util.InsertableDict,
223+
) -> None:
227224
bundle = reqs["external_files"]
228225

229226
# create the L1 prefiring weight evaluator

hbw/production/jets.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def jetId(self: Producer, events: ak.Array, **kwargs) -> ak.Array:
2929
https://twiki.cern.ch/twiki/bin/view/CMS/JetID13p6TeV?rev=21
3030
"""
3131
abseta = abs(events.Jet.eta)
32-
print("start")
32+
3333
# baseline mask (abseta < 2.7)
3434
passJetId_Tight = (events.Jet.jetId & 2 == 2)
3535

0 commit comments

Comments
 (0)