Skip to content

Commit 9844a1a

Browse files
committed
remove inconsistencies between config category and inference category names
1 parent 17a9501 commit 9844a1a

File tree

3 files changed

+72
-57
lines changed

3 files changed

+72
-57
lines changed

hbw/config/defaults_and_groups.py

+27-32
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,22 @@ def set_config_defaults_and_groups(config_inst):
156156
}
157157

158158
# category groups for conveniently looping over certain categories
159-
# (used during plotting)
159+
# (used during plotting and for rebinning)
160160
config_inst.x.category_groups = {
161161
"much": ["1mu", "1mu__resolved", "1mu__boosted"],
162162
"ech": ["1e", "1e__resolved", "1e__boosted"],
163163
"default": ["incl", "1e", "1mu"],
164164
"test": ["incl", "1e"],
165165
"dilep": ["incl", "2e", "2mu", "emu"],
166+
"SR": ("1e__ml_ggHH_kl_1_kt_1_sl_hbbhww", "1mu__ml_ggHH_kl_1_kt_1_sl_hbbhww"),
167+
"vbfSR": ("1e__ml_qqHH_CV_1_C2V_1_kl_1_sl_hbbhww", "1mu__ml_qqHH_CV_1_C2V_1_kl_1_sl_hbbhww"),
168+
"SR_resolved": ("1e__ml_resolved_ggHH_kl_1_kt_1_sl_hbbhww", "1mu__ml_resolved_ggHH_kl_1_kt_1_sl_hbbhww"),
169+
"SR_boosted": ("1e__ml_boosted_ggHH_kl_1_kt_1_sl_hbbhww", "1mu__ml_boosted_ggHH_kl_1_kt_1_sl_hbbhww"),
170+
"vbfSR_resolved": ("1e__ml_resolved_qqHH_CV_1_C2V_1_kl_1_sl_hbbhww", "1mu__ml_resolved_qqHH_CV_1_C2V_1_kl_1_sl_hbbhww"), # noqa
171+
"vbfSR_boosted": ("1e__ml_boosted_qqHH_CV_1_C2V_1_kl_1_sl_hbbhww", "1mu__ml_boosted_qqHH_CV_1_C2V_1_kl_1_sl_hbbhww"), # noqa
172+
"BR": ("1e__ml_tt", "1e__ml_st", "1e__ml_v_lep", "1mu__ml_tt", "1mu__ml_st", "1mu__ml_v_lep"),
173+
"SR_dl": ("2e__ml_ggHH_kl_5_kt_1_dl_hbbhww", "2mu__ml_ggHH_kl_5_kt_1_dl_hbbhww"),
174+
"BR_dl": ("2e__ml_t_bkg", "2e__ml_v_lep", "2mu__ml_t_bkg", "2mu__ml_v_lep"),
166175
}
167176

168177
# variable groups for conveniently looping over certain variables
@@ -265,21 +274,7 @@ def set_config_defaults_and_groups(config_inst):
265274
"cols": ["mli", "features"],
266275
}
267276

268-
# configuration regarding rebinning
269-
config_inst.x.inference_category_groups = {
270-
"SR": ("cat_1e_ggHH_kl_1_kt_1_sl_hbbhww", "cat_1mu_ggHH_kl_1_kt_1_sl_hbbhww"),
271-
"vbfSR": ("cat_1e_qqHH_CV_1_C2V_1_kl_1_sl_hbbhww", "cat_1mu_qqHH_CV_1_C2V_1_kl_1_sl_hbbhww"),
272-
"SR_resolved": ("cat_1e_resolved_ggHH_kl_1_kt_1_sl_hbbhww", "cat_1mu_resolved_ggHH_kl_1_kt_1_sl_hbbhww"),
273-
"SR_boosted": ("cat_1e_boosted_ggHH_kl_1_kt_1_sl_hbbhww", "cat_1mu_boosted_ggHH_kl_1_kt_1_sl_hbbhww"),
274-
"vbfSR_resolved": ("cat_1e_resolved_qqHH_CV_1_C2V_1_kl_1_sl_hbbhww",
275-
"cat_1mu_resolved_qqHH_CV_1_C2V_1_kl_1_sl_hbbhww"),
276-
"vbfSR_boosted": ("cat_1e_boosted_qqHH_CV_1_C2V_1_kl_1_sl_hbbhww",
277-
"cat_1mu_boosted_qqHH_CV_1_C2V_1_kl_1_sl_hbbhww"),
278-
"BR": ("cat_1e_tt", "cat_1e_st", "cat_1e_v_lep", "cat_1mu_tt", "cat_1mu_st", "cat_1mu_v_lep"),
279-
"SR_dl": ("cat_2e_ggHH_kl_5_kt_1_dl_hbbhww", "cat_2mu_ggHH_kl_5_kt_1_dl_hbbhww"),
280-
"BR_dl": ("cat_2e_t_bkg", "cat_2e_v_lep", "cat_2mu_t_bkg", "cat_2mu_v_lep"),
281-
}
282-
277+
# groups are defined via config.x.category_groups
283278
config_inst.x.default_bins_per_category = {
284279
"SR": 10,
285280
"vbfSR": 5,
@@ -290,14 +285,14 @@ def set_config_defaults_and_groups(config_inst):
290285
"vbfSR_boosted": 3,
291286
# "SR_dl": 10,
292287
# "BR_dl": 3,
293-
# "cat_1e_ggHH_kl_1_kt_1_sl_hbbhww": 10,
294-
# "cat_1e_tt": 3,
295-
# "cat_1e_st": 3,
296-
# "cat_1e_v_lep": 3,
297-
# "cat_1mu_ggHH_kl_1_kt_1_sl_hbbhww": 10,
298-
# "cat_1mu_tt": 3,
299-
# "cat_1mu_st": 3,
300-
# "cat_1mu_v_lep": 3,
288+
# "1e__ml_ggHH_kl_1_kt_1_sl_hbbhww": 10,
289+
# "1e__ml_tt": 3,
290+
# "1e__ml_st": 3,
291+
# "1e__ml_v_lep": 3,
292+
# "1mu__ml_ggHH_kl_1_kt_1_sl_hbbhww": 10,
293+
# "1mu__ml_tt": 3,
294+
# "1mu__ml_st": 3,
295+
# "1mu__ml_v_lep": 3,
301296
}
302297

303298
config_inst.x.inference_category_rebin_processes = {
@@ -310,12 +305,12 @@ def set_config_defaults_and_groups(config_inst):
310305
"BR": lambda proc_name: "hbbhww" not in proc_name,
311306
# "SR_dl": ("ggHH_kl_5_kt_1_dl_hbbhww",),
312307
# "BR_dl": lambda proc_name: "hbbhww" not in proc_name,
313-
# "cat_1e_ggHH_kl_1_kt_1_sl_hbbhww": ("ggHH_kl_1_kt_1_sl_hbbhww", "qqHH_CV_1_C2V_1_kl_1_sl_hbbhww"),
314-
# "cat_1e_tt": lambda proc_name: "hbbhww" not in proc_name,
315-
# "cat_1e_st": lambda proc_name: "hbbhww" not in proc_name,
316-
# "cat_1e_v_lep": lambda proc_name: "hbbhww" not in proc_name,
317-
# "cat_1mu_ggHH_kl_1_kt_1_sl_hbbhww": ("ggHH_kl_1_kt_1_sl_hbbhww", "qqHH_CV_1_C2V_1_kl_1_sl_hbbhww"),
318-
# "cat_1mu_tt": lambda proc_name: "hbbhww" not in proc_name,
319-
# "cat_1mu_st": lambda proc_name: "hbbhww" not in proc_name,
320-
# "cat_1mu_v_lep": lambda proc_name: "hbbhww" not in proc_name,
308+
# "1e__ml_ggHH_kl_1_kt_1_sl_hbbhww": ("ggHH_kl_1_kt_1_sl_hbbhww", "qqHH_CV_1_C2V_1_kl_1_sl_hbbhww"),
309+
# "1e__ml_tt": lambda proc_name: "hbbhww" not in proc_name,
310+
# "1e__ml_st": lambda proc_name: "hbbhww" not in proc_name,
311+
# "1e__ml_v_lep": lambda proc_name: "hbbhww" not in proc_name,
312+
# "1mu__ml_ggHH_kl_1_kt_1_sl_hbbhww": ("ggHH_kl_1_kt_1_sl_hbbhww", "qqHH_CV_1_C2V_1_kl_1_sl_hbbhww"),
313+
# "1mu__ml_tt": lambda proc_name: "hbbhww" not in proc_name,
314+
# "1mu__ml_st": lambda proc_name: "hbbhww" not in proc_name,
315+
# "1mu__ml_v_lep": lambda proc_name: "hbbhww" not in proc_name,
321316
}

hbw/inference/base.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ def print_model(self):
6565

6666
def cat_name(self: InferenceModel, config_cat_inst: od.Category):
6767
""" Function to determine inference category name from config category """
68-
root_cats = config_cat_inst.x.root_cats
69-
return "cat_" + "_".join(root_cats.values())
68+
# Note: the name of the inference category cannot start with a Number
69+
# -> use config category with single letter added at the start?
70+
return f"cat_{config_cat_inst.name}"
7071

7172
def config_variable(self: InferenceModel, config_cat_inst: od.Config):
7273
""" Function to determine inference variable name from config category """

hbw/tasks/inference.py

+42-23
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
# import luigi
1212
import law
13+
import order as od
1314

1415
from columnflow.tasks.framework.base import Requirements, RESOLVE_DEFAULT
1516
from columnflow.tasks.framework.parameters import SettingsParameter
@@ -24,6 +25,9 @@
2425
array = maybe_import("array")
2526

2627

28+
logger = law.logger.get_logger(__name__)
29+
30+
2731
def get_hist_name(cat_name: str, proc_name: str, syst_name: str | None = None) -> str:
2832
hist_name = f"{cat_name}/{proc_name}"
2933
if syst_name:
@@ -72,7 +76,7 @@ def get_rebin_values(hist, N_bins_final: int = 10):
7276

7377
# determine events per bin the final histogram should have
7478
events_per_bin = hist.Integral() / N_bins_final
75-
print(f"============ {round(events_per_bin, 3)} events per bin")
79+
logger.info(f"============ {round(events_per_bin, 3)} events per bin")
7680

7781
# bookkeeping number of bins and number of events
7882
bin_count = 1
@@ -88,17 +92,17 @@ def get_rebin_values(hist, N_bins_final: int = 10):
8892

8993
N_events += hist.GetBinContent(i)
9094
if i % 100 == 0:
91-
print(f"========== Bin {i} of {N_bins_input}, {N_events} events")
95+
logger.info(f"========== Bin {i} of {N_bins_input}, {N_events} events")
9296
if N_events >= events_per_bin * bin_count:
9397
# when *N_events* surpasses threshold, append the corresponding bin edge and count
94-
print(f"++++++++++ Append bin edge {bin_count} of {N_bins_final} at edge {hist.GetBinLowEdge(i)}")
98+
logger.info(f"++++++++++ Append bin edge {bin_count} of {N_bins_final} at edge {hist.GetBinLowEdge(i)}")
9599
rebin_values.append(hist.GetBinLowEdge(i + 1))
96100
bin_count += 1
97101

98102
# final bin is x_max
99103
x_max = hist.GetBinLowEdge(N_bins_input + 1)
100104
rebin_values.append(x_max)
101-
print(f"final bin edges: {rebin_values}")
105+
logger.info(f"final bin edges: {rebin_values}")
102106
return rebin_values
103107

104108

@@ -127,28 +131,28 @@ def check_empty_bins(hist, fill_empty: float = 1e-5, required_entries: int = 3)
127131
value = hist.GetBinContent(i)
128132
error = hist.GetBinError(i)
129133
if value <= 0:
130-
print(f"==== Found empty or negative bin {i}, (value: {value}, error: {error})")
134+
logger.info(f"==== Found empty or negative bin {i}, (value: {value}, error: {error})")
131135
count += 1
132136
if fill_empty >= 0:
133-
print(f" Bin {i} value + error will be filled with {fill_empty}")
137+
logger.info(f" Bin {i} value + error will be filled with {fill_empty}")
134138
hist.SetBinContent(i, fill_empty)
135139
hist.SetBinError(i, fill_empty)
136140

137141
if error > max_error(value):
138-
print(
142+
logger.warning(
139143
f"==== Bin {i} has less than {required_entries} entries (value: {value}, error: {error}); "
140144
f"Rebinning procedure might have to be restarted with less bins than {hist.GetNbinsX()}",
141145
)
142146
return count
143147

144148

145149
def print_hist(hist, max_bins: int = 20):
146-
print("Printing bin number, lower edge and bin content")
150+
logger.info("Printing bin number, lower edge and bin content")
147151
for i in range(0, hist.GetNbinsX() + 2):
148152
if i > max_bins:
149153
return
150154

151-
print(f"{i} \t {hist.GetBinLowEdge(i)} \t {hist.GetBinContent(i)}")
155+
logger.info(f"{i} \t {hist.GetBinLowEdge(i)} \t {hist.GetBinContent(i)}")
152156

153157

154158
class ModifyDatacardsFlatRebin(
@@ -185,11 +189,17 @@ def resolve_param_values(cls, params):
185189
params = super().resolve_param_values(params)
186190

187191
if config_inst := params.get("config_inst"):
188-
def resolve_category_groups(param, group_str):
192+
def resolve_category_groups(param):
193+
outp_param = {}
189194
for cat_name in list(param.keys()):
190-
if resolved_cats := config_inst.x(group_str, {}).get(cat_name, None):
195+
resolved_cats = cls.find_config_objects(
196+
(cat_name,), config_inst, od.Category,
197+
object_groups=config_inst.x.category_groups, deep=True,
198+
)
199+
if resolved_cats:
191200
for resolved_cat in law.util.make_tuple(resolved_cats):
192-
param[resolved_cat] = param[cat_name]
201+
outp_param[resolved_cat] = param[cat_name]
202+
return outp_param
193203

194204
# resolve default and groups for `bins_per_category`
195205
params["bins_per_category"] = cls.resolve_config_default(
@@ -198,7 +208,7 @@ def resolve_category_groups(param, group_str):
198208
container=config_inst,
199209
default_str="default_bins_per_category",
200210
)
201-
resolve_category_groups(params["bins_per_category"], "inference_category_groups")
211+
params["bins_per_category"] = resolve_category_groups(params["bins_per_category"])
202212

203213
# set `inference_category_rebin_processes` as parameter and resolve groups
204214
params["inference_category_rebin_processes"] = cls.resolve_config_default(
@@ -207,25 +217,34 @@ def resolve_category_groups(param, group_str):
207217
container=config_inst,
208218
default_str="inference_category_rebin_processes",
209219
)
210-
resolve_category_groups(params["inference_category_rebin_processes"], "inference_category_groups")
211-
220+
params["inference_category_rebin_processes"] = resolve_category_groups(
221+
params["inference_category_rebin_processes"],
222+
)
212223
return params
213224

214225
def get_n_bins(self, DEFAULT_N_BINS=8):
215226
""" Method to get the requested number of bins for the current category. Defaults to *DEFAULT_N_BINS*"""
216-
cat_name = self.branch_data.name
217-
return int(self.bins_per_category.get(cat_name, DEFAULT_N_BINS))
227+
config_category = self.branch_data.config_category
228+
n_bins = self.bins_per_category.get(config_category, None)
229+
if not n_bins:
230+
logger.warning(f"No number of bins setup for category {config_category}; will default to {DEFAULT_N_BINS}.")
231+
n_bins = DEFAULT_N_BINS
232+
return int(n_bins)
218233

219234
def get_rebin_processes(self):
220235
"""
221236
Method to resolve the requested processes on which to flatten the histograms of the current category.
222237
Defaults to all processes of the current category.
223238
"""
224-
cat_name = self.branch_data.name
239+
config_category = self.branch_data.config_category
225240
proc_names = [proc.name for proc in self.branch_data.processes]
226241

227-
rebin_process_condition = self.inference_category_rebin_processes.get(cat_name, None)
242+
rebin_process_condition = self.inference_category_rebin_processes.get(config_category, None)
228243
if not rebin_process_condition:
244+
logger.warning(
245+
f"No rebin condition found for category {config_category}; rebinning will be flat "
246+
f"on all processes {proc_names}",
247+
)
229248
return proc_names
230249

231250
# transform `rebin_process_condition` into Callable if required
@@ -237,7 +256,7 @@ def get_rebin_processes(self):
237256
# check for each process if the *rebin_process_condition* is fulfilled
238257
if not rebin_process_condition(proc_name):
239258
proc_names.remove(proc_name)
240-
259+
logger.info(f"Category {config_category} will be rebinned flat in processes {proc_names}")
241260
return proc_names
242261

243262
def create_branch_map(self):
@@ -286,7 +305,7 @@ def run(self):
286305
outputs["card"].dump(datacard, formatter="text")
287306

288307
with uproot.open(inp_shapes.fn) as file:
289-
print(f"File keys: {file.keys()}")
308+
logger.info(f"File keys: {file.keys()}")
290309
# determine which histograms are present
291310
cat_names, proc_names, syst_names = get_cat_proc_syst_names(file)
292311

@@ -319,7 +338,7 @@ def run(self):
319338
for h in hists[1:]:
320339
hist += h
321340

322-
print(f"Finding rebin values for category {cat_name} using processes {rebin_processes}")
341+
logger.info(f"Finding rebin values for category {cat_name} using processes {rebin_processes}")
323342
rebin_values = get_rebin_values(hist, self.get_n_bins())
324343
outputs["edges"].dump(rebin_values, formatter="json")
325344

@@ -339,5 +358,5 @@ def run(self):
339358

340359
h_rebin = apply_binning(h, rebin_values)
341360
problematic_bin_count = check_empty_bins(h_rebin) # noqa
342-
print(f"Inserting histogram with name {key}")
361+
logger.info(f"Inserting histogram with name {key}")
343362
out_file[key] = uproot.from_pyroot(h_rebin)

0 commit comments

Comments
 (0)