
Commit 2ab8ec5

generalize btag reweighting corrections
1 parent ede1cbe commit 2ab8ec5

1 file changed: +35 additions, -22 deletions

hbw/tasks/corrections.py (+35, -22)

@@ -3,6 +3,7 @@
 """
 
 import law
+import luigi
 
 from functools import cached_property
 
@@ -39,10 +40,18 @@ class GetBtagNormalizationSF(
 
     store_as_dict = False
 
-    rescale_mode = "nevents"
-
     reweighting_step = "selected_no_bjet"
 
+    rescale_mode = luigi.ChoiceParameter(
+        default="nevents",
+        choices=("nevents", "xs"),
+    )
+    base_key = luigi.ChoiceParameter(
+        default="rescaled_sum_mc_weight",
+        # NOTE: "num_events" does not work because I did not store the corresponding key in the stats :/
+        choices=("rescaled_sum_mc_weight", "sum_mc_weight", "num_events"),
+    )
+
     # default sandbox, might be overwritten by selector function
     sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox"))
 
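Note: rescale_mode and base_key are now exposed as luigi.ChoiceParameter task parameters instead of hard-coded class attributes. The following standalone sketch (DemoTask and the tried values are placeholders, not part of the repository) illustrates how a ChoiceParameter restricts accepted values to the declared choices:

import luigi


class DemoTask(luigi.Task):
    # placeholders mirroring the new GetBtagNormalizationSF parameters
    rescale_mode = luigi.ChoiceParameter(default="nevents", choices=("nevents", "xs"))
    base_key = luigi.ChoiceParameter(
        default="rescaled_sum_mc_weight",
        choices=("rescaled_sum_mc_weight", "sum_mc_weight", "num_events"),
    )

    def run(self):
        pass


# a value from the declared choices is accepted ...
task = DemoTask(rescale_mode="xs")

# ... while anything outside `choices` is rejected by ChoiceParameter (raises ValueError)
try:
    DemoTask(rescale_mode="lumi")
except ValueError as err:
    print(err)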
@@ -98,6 +107,9 @@ def store_parts(self):
         processes_repr = "__".join(self.processes)
         parts.insert_before("version", "processes", processes_repr)
 
+        significant_params = (self.rescale_mode, self.base_key)
+        parts.insert_before("version", "params", "__".join(significant_params))
+
         return parts
 
     def output(self):
@@ -143,23 +155,24 @@ def safe_div(num, den):
         )
 
         # rescale the histograms
-        for dataset, hists in hists_per_dataset.items():
-            process = self.config_inst.get_dataset(dataset).processes.get_first()
-            if self.rescale_mode == "xs":
-                # scale such that the sum of weights is the cross section
-                xs = process.get_xsec(self.config_inst.campaign.ecm).nominal
-                dataset_factor = xs / hists["sum_mc_weight"][{"steps": "Initial"}].value
-            elif self.rescale_mode == "nevents":
-                # scale such that mean weight is 1
-                n_events = hists["num_events"][{"steps": self.reweighting_step}].value
-                dataset_factor = n_events / hists["sum_mc_weight"][{"steps": self.reweighting_step}].value
-            else:
-                raise ValueError(f"Invalid rescale mode {self.rescale_mode}")
-            for key in tuple(hists.keys()):
-                if "sum" not in key:
-                    continue
-                h = hists[key].copy() * dataset_factor
-                hists[f"rescaled_{key}"] = h
+        if "rescaled" in self.base_key:
+            for dataset, hists in hists_per_dataset.items():
+                process = self.config_inst.get_dataset(dataset).processes.get_first()
+                if self.rescale_mode == "xs":
+                    # scale such that the sum of weights is the cross section
+                    xs = process.get_xsec(self.config_inst.campaign.ecm).nominal
+                    dataset_factor = xs / hists["sum_mc_weight"][{"steps": "Initial"}].value
+                elif self.rescale_mode == "nevents":
+                    # scale such that mean weight is 1
+                    n_events = hists["num_events"][{"steps": self.reweighting_step}].value
+                    dataset_factor = n_events / hists["sum_mc_weight"][{"steps": self.reweighting_step}].value
+                else:
+                    raise ValueError(f"Invalid rescale mode {self.rescale_mode}")
+                for key in tuple(hists.keys()):
+                    if "sum" not in key:
+                        continue
+                    h = hists[key].copy() * dataset_factor
+                    hists[f"rescaled_{key}"] = h
 
         # if necessary, merge the histograms across datasets
         if len(hists_per_dataset) > 1:
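The rescaling block now only runs when base_key actually requests rescaled sums. Paraphrased as a standalone helper (hypothetical function and argument names, not part of the module), the per-dataset factor computed above amounts to:

def dataset_rescale_factor(rescale_mode, xsec, sum_weight_initial, n_events_step, sum_weight_step):
    # "xs": rescale so that the summed MC weights correspond to the process cross section
    if rescale_mode == "xs":
        return xsec / sum_weight_initial
    # "nevents": rescale so that the mean MC weight at the reweighting step is 1
    if rescale_mode == "nevents":
        return n_events_step / sum_weight_step
    raise ValueError(f"Invalid rescale mode {rescale_mode}")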
@@ -183,18 +196,18 @@ def safe_div(num, den):
             # ("",),
         ):
             mode_str = "_".join(mode)
-            numerator = merged_hists["rescaled_sum_mc_weight_per_process_ht_njet_nhf"]
+            numerator = merged_hists[f"{self.base_key}_per_process_ht_njet_nhf"]
             numerator = self.reduce_hist(numerator, mode).values()
 
             for key in merged_hists.keys():
                 if (
-                    not key.startswith("rescaled_sum_mc_weight_btag_weight") or
+                    not key.startswith(f"{self.base_key}_btag_weight") or
                     not key.endswith("_per_process_ht_njet_nhf")
                 ):
                     continue
 
                 # extract the weight name
-                weight_name = key.replace("rescaled_sum_mc_weight_", "").replace("_per_process_ht_njet_nhf", "")
+                weight_name = key.replace(f"{self.base_key}_", "").replace("_per_process_ht_njet_nhf", "")
 
                 # create the scale factor histogram
                 h = merged_hists[key]