
Commit 4b437be

Merge pull request #96 from uhh-cms/ML_combine_proc
Ml combine proc
2 parents: 1ae1663 + 4e3b009

13 files changed: +479 −21 lines

hbw/categorization/categories.py (+1)

@@ -271,6 +271,7 @@ def catid_2b(self: Categorizer, events: ak.Array, **kwargs) -> tuple[ak.Array, a

     # TODO: not hard-coded -> use config?
     ml_processes = [
+        "signal_ggf", "signal_ggf2", "signal_vbf", "signal_vbf2",
         "hh_ggf_hbb_hvv_kl1_kt1", "hh_vbf_hbb_hvv_kv1_k2v1_kl1",
         "hh_ggf_hbb_hvvqqlnu_kl1_kt1", "hh_vbf_hbb_hvvqqlnu_kv1_k2v1_kl1",
         "hh_ggf_hbb_hvv2l2nu_kl1_kt1", "hh_vbf_hbb_hvv2l2nu_kv1_k2v1_kl1",

hbw/config/categories.py (+3 −2)

@@ -323,16 +323,17 @@ def add_categories_ml(config, ml_model_inst):
     # add ml categories directly to the config
     # NOTE: this is a bit dangerous, because our ID depends on the MLModel, but
     # we can reconfigure our MLModel after having created these categories
+    # TODO: config is empty and therefore fails
     ml_categories = []
     for i, proc in enumerate(ml_model_inst.processes):
-        cat_label = config.get_process(proc).x.ml_label
+        # cat_label = config.get_process(proc).x.ml_label
         ml_categories.append(config.add_category(
             # NOTE: name and ID is unique as long as we don't use
             # multiple ml_models simutaneously
             name=f"ml_{proc}",
             id=(i + 1) * 1000,
             selection=f"catid_ml_{proc}",
-            label=f"{cat_label} category",
+            # label=f"{cat_label} category",
             aux={"ml_proc": proc},
         ))
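The commented-out labels point at a follow-up; a minimal sketch (not part of this commit) that keeps the labels while tolerating configs that lack the ml_label aux, using the same accessors (get_process(..., default=None) and x(key, default)) that appear elsewhere in this diff:

    for i, proc in enumerate(ml_model_inst.processes):
        proc_inst = config.get_process(proc, default=None)
        # fall back to the plain process name if the process or its ml_label aux is missing
        cat_label = proc_inst.x("ml_label", proc) if proc_inst else proc
        ml_categories.append(config.add_category(
            name=f"ml_{proc}",
            id=(i + 1) * 1000,
            selection=f"catid_ml_{proc}",
            label=f"{cat_label} category",
            aux={"ml_proc": proc},
        ))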

hbw/config/processes.py (+21)

@@ -10,6 +10,7 @@
 from scinum import Number

 from cmsdb.util import add_decay_process
+from columnflow.util import DotDict

 from hbw.config.styling import color_palette

@@ -199,3 +200,23 @@ def configure_hbw_processes(config: od.Config):
         if config.has_process(bg):
             bg = config.get_process(bg)
             background.add_process(bg)
+
+
+from random import randint
+
+
+def create_combined_proc_forML(config: od.Config, proc_name: str, proc_dict: dict, color=None):
+
+    combining_proc = []
+    for proc in proc_dict.sub_processes:
+        combining_proc.append(config.get_process(proc, default=None))
+    proc_name = add_parent_process(config,
+        combining_proc,
+        name=proc_name,
+        id=randint(10000000, 99999999),
+        # TODO: random number (could by chance be a already used number --> should be checked)
+        label=proc_dict.get("label", "combined custom process"),
+        color=proc_dict.get("color", None),
+    )
+    ml_config = DotDict({"weighting": proc_dict.get("weighting", None), "sub_processes": proc_dict.sub_processes})
+    proc_name.x.ml_config = ml_config
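A hedged usage sketch of the new helper: the process names and settings below are made up, chosen to mirror the combined signal processes used in the ML categories above. Note that the function reads proc_dict.sub_processes attribute-style, so a DotDict (or similar) is expected rather than a plain dict, and the randomly drawn id can in principle collide with an existing process id, as the TODO notes.

    from columnflow.util import DotDict

    # hypothetical: merge two ggF HH samples into one training process "signal_ggf"
    proc_dict = DotDict({
        "sub_processes": ["hh_ggf_hbb_hvv2l2nu_kl1_kt1", "hh_ggf_hbb_hvv2l2nu_kl2p45_kt1"],
        "label": "HH ggF signal",
        "weighting": "equal",
    })
    create_combined_proc_forML(config, "signal_ggf", proc_dict)

    # the combined process now carries its ML settings as aux data
    ml_cfg = config.get_process("signal_ggf").x.ml_config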

hbw/inference/dl.py (+84)

@@ -142,6 +142,90 @@

 dl = HBWInferenceModelBase.derive("dl", cls_dict=default_cls_dict)

+# "hh_vbf_hbb_hvv2l2nu_kv1_k2v1_kl1",
+# "hh_vbf_hbb_hvv2l2nu_kv1_k2v0_kl1",
+# "hh_vbf_hbb_hvv2l2nu_kvm0p962_k2v0p959_klm1p43",
+# "hh_vbf_hbb_hvv2l2nu_kvm1p21_k2v1p94_klm0p94",
+# "hh_vbf_hbb_hvv2l2nu_kvm1p6_k2v2p72_klm1p36",
+# "hh_vbf_hbb_hvv2l2nu_kvm1p83_k2v3p57_klm3p39",
+# "hh_ggf_hbb_hvv2l2nu_kl0_kt1",
+# "hh_ggf_hbb_hvv2l2nu_kl1_kt1",
+# "hh_ggf_hbb_hvv2l2nu_kl2p45_kt1",
+# "hh_ggf_hbb_hvv2l2nu_kl5_kt1",
+
+
+dl.derive("dl_ml_study_1", cls_dict={
+    "ml_model_name": "dl_22post_ml_study_1",
+    "config_categories": [
+        "sr__1b__ml_signal_ggf",
+        "sr__1b__ml_signal_vbf",
+        "sr__1b__ml_tt",
+        "sr__1b__ml_st",
+        "sr__1b__ml_dy",
+        "sr__1b__ml_h",
+        "sr__2b__ml_signal_ggf",
+        "sr__2b__ml_signal_vbf",
+        "sr__2b__ml_tt",
+        "sr__2b__ml_st",
+        "sr__2b__ml_dy",
+        "sr__2b__ml_h",
+    ],
+    "processes": [
+        "hh_vbf_hbb_hvv2l2nu_kv1_k2v1_kl1",
+        "hh_vbf_hbb_hvv2l2nu_kv1_k2v0_kl1",
+        "hh_vbf_hbb_hvv2l2nu_kvm0p962_k2v0p959_klm1p43",
+        "hh_vbf_hbb_hvv2l2nu_kvm1p21_k2v1p94_klm0p94",
+        "hh_vbf_hbb_hvv2l2nu_kvm1p6_k2v2p72_klm1p36",
+        "hh_vbf_hbb_hvv2l2nu_kvm1p83_k2v3p57_klm3p39",
+        "hh_ggf_hbb_hvv2l2nu_kl0_kt1",
+        "hh_ggf_hbb_hvv2l2nu_kl1_kt1",
+        "hh_ggf_hbb_hvv2l2nu_kl2p45_kt1",
+        "hh_ggf_hbb_hvv2l2nu_kl5_kt1",
+        "tt",
+        "dy",
+        "w_lnu",
+        "vv",
+        "h_ggf", "h_vbf", "zh", "wh", "zh_gg", "tth",
+    ],
+    "systematics": rate_systematics,
+})
+
+dl.derive("dl_ml_study_2", cls_dict={
+    "ml_model_name": "dl_22post_ml_study_2",
+    "config_categories": [
+        "sr__1b__ml_signal_ggf2",
+        "sr__1b__ml_signal_vbf2",
+        "sr__1b__ml_tt",
+        "sr__1b__ml_st",
+        "sr__1b__ml_dy",
+        "sr__1b__ml_h",
+        "sr__2b__ml_signal_ggf2",
+        "sr__2b__ml_signal_vbf2",
+        "sr__2b__ml_tt",
+        "sr__2b__ml_st",
+        "sr__2b__ml_dy",
+        "sr__2b__ml_h",
+    ],
+    "processes": [
+        "hh_vbf_hbb_hvv2l2nu_kv1_k2v1_kl1",
+        "hh_vbf_hbb_hvv2l2nu_kv1_k2v0_kl1",
+        "hh_vbf_hbb_hvv2l2nu_kvm0p962_k2v0p959_klm1p43",
+        "hh_vbf_hbb_hvv2l2nu_kvm1p21_k2v1p94_klm0p94",
+        "hh_vbf_hbb_hvv2l2nu_kvm1p6_k2v2p72_klm1p36",
+        "hh_vbf_hbb_hvv2l2nu_kvm1p83_k2v3p57_klm3p39",
+        "hh_ggf_hbb_hvv2l2nu_kl0_kt1",
+        "hh_ggf_hbb_hvv2l2nu_kl1_kt1",
+        "hh_ggf_hbb_hvv2l2nu_kl2p45_kt1",
+        "hh_ggf_hbb_hvv2l2nu_kl5_kt1",
+        "tt",
+        "dy",
+        "w_lnu",
+        "vv",
+        "h_ggf", "h_vbf", "zh", "wh", "zh_gg", "tth",
+    ],
+    "systematics": rate_systematics,
+})
+
 dl.derive("dl_hww_and_hzz", cls_dict={
     "processes": [
         "hh_ggf_hbb_hww_kl0_kt1",

hbw/ml/base.py (+1)

@@ -273,6 +273,7 @@ def uses(self, config_inst: od.Config) -> set[Route | str]:
         columns = {"mli_*"}
         # TODO: switch to full event weight
         # TODO: this might not work with data, to be checked
+        columns.add("process_id")
        columns.add("normalization_weight")
         columns.add("stitched_normalization_weight")
         columns.add("event_weight")

hbw/ml/data_loader.py (+129 −9)

@@ -19,6 +19,40 @@
 logger = law.logger.get_logger(__name__)


+def get_proc_mask(
+    events: ak.Array,
+    proc: str | od.Process,
+    config_inst: od.Config | None = None,
+) -> tuple(np.ndarray, list):
+    """
+    Creates the mask selecting events belonging to the process *proc* and a list of all ids belonging to this process.
+
+    :param events: Event array
+    :param proc: Either string or process instance.
+    :param config_inst: An instance of the Config, can be None if Porcess instance is given.
+    :return process mask and the corresponding process ids
+    """
+    # get process instance
+    if config_inst:
+        proc_inst = config_inst.get_process(proc)
+    elif isinstance(proc, od.Process):
+        proc_inst = proc
+
+    proc_id = events.process_id
+    unique_proc_ids = set(proc_id)
+
+    # get list of Ids that are belonging to the process and are present in the event array
+    sub_id = [
+        proc_inst.id
+        for proc_inst, _, _ in proc_inst.walk_processes(include_self=True)
+        if proc_inst.id in unique_proc_ids
+    ]
+
+    # Create process mask
+    proc_mask = np.isin(proc_id, sub_id)
+    return proc_mask, sub_id
+
+
 def input_features_sanity_checks(ml_model_inst: MLModel, input_features: list[str]):
     """
     Perform sanity checks on the input features.
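For orientation, a small usage sketch of the new helper; the ids and the process name are invented, and in this commit the helper is called from the dataset constructor and the weight methods below.

    import awkward as ak
    import numpy as np

    # toy events with a process_id column; 4001 and 4002 are assumed to be sub-process ids of "tt"
    events = ak.Array({"process_id": np.array([4001, 4002, 9000, 4001])})

    proc_mask, sub_ids = get_proc_mask(events, "tt", config_inst)
    tt_events = events[proc_mask]  # keeps only the events whose id belongs to "tt"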
@@ -78,8 +112,10 @@ def __init__(self, ml_model_inst: MLModel, process: "str", events: ak.Array, sta
         """
         self._ml_model_inst = ml_model_inst
         self._process = process
+
+        proc_mask, _ = get_proc_mask(events, process, ml_model_inst.config_inst)
         self._stats = stats
-        self._events = events
+        self._events = events[proc_mask]

     def __repr__(self):
         return f"{self.__class__.__name__}({self.ml_model_inst.cls_name}, {self.process})"
@@ -185,21 +221,89 @@ def shuffle_indices(self) -> np.ndarray:
         self._shuffle_indices = np.random.permutation(self.n_events)
         return self._shuffle_indices

+    def get_xsec_train_weights(self) -> np.ndarray:
+        """
+        Weighting such that each event has roughly the same weight,
+        sub processes are weighted accoridng to their cross section
+        """
+        if hasattr(self, "_xsec_train_weights"):
+            return self._xsec_train_weights
+
+        if not self.stats:
+            raise Exception("cannot determine train weights without stats")
+
+        _, sub_id = get_proc_mask(self._events, self.process, self.ml_model_inst.config_inst)
+        sum_abs_weights = np.sum([self.stats[self.process]["sum_abs_weights_per_process"][str(id)] for id in sub_id])
+        num_events = np.sum([self.stats[self.process]["num_events_per_process"][str(id)] for id in sub_id])
+
+        xsec_train_weights = self.weights / sum_abs_weights * num_events
+
+        return xsec_train_weights
+
+    def get_equal_train_weights(self) -> np.ndarray:
+        """
+        Weighting such that events of each sub processes are weighted equally
+        """
+        if hasattr(self, "_equally_train_weights"):
+            return self._equal_train_weights
+
+        if not self.stats:
+            raise Exception("cannot determine train weights without stats")
+
+        combined_proc_inst = self.ml_model_inst.config_inst.get_process(self.process)
+        _, sub_id_proc = get_proc_mask(self._events, self.process, self.ml_model_inst.config_inst)
+        num_events = np.sum([self.stats[self.process]["num_events_per_process"][str(id)] for id in sub_id_proc])
+        targeted_sum_of_weights_per_process = (
+            num_events / len(combined_proc_inst.x.ml_config.sub_processes)
+        )
+        equal_train_weights = ak.full_like(self.weights, 1.)
+        sub_class_factors = {}
+
+        for proc in combined_proc_inst.x.ml_config.sub_processes:
+            proc_mask, sub_id = get_proc_mask(self._events, proc, self.ml_model_inst.config_inst)
+            sum_pos_weights_per_sub_proc = 0.
+            sum_pos_weights_per_proc = self.stats[self.process]["sum_pos_weights_per_process"]
+
+            for id in sub_id:
+                id = str(id)
+                if id in self.stats[self.process]["num_events_per_process"]:
+                    sum_pos_weights_per_sub_proc += sum_pos_weights_per_proc[id]
+
+            if sum_pos_weights_per_sub_proc == 0:
+                norm_const_per_proc = 1.
+                logger.info(
+                    f"No weight sum found in stats for sub process {proc}."
+                    f"Normalization constant set to 1 but results are probably not correct.")
+            else:
+                norm_const_per_proc = targeted_sum_of_weights_per_process / sum_pos_weights_per_sub_proc
+                logger.info(f"Normalizing constant for {proc} is {norm_const_per_proc}")
+
+            sub_class_factors[proc] = norm_const_per_proc
+            equal_train_weights = np.where(proc_mask, self.weights * norm_const_per_proc, equal_train_weights)
+
+        return equal_train_weights
+
     @property
     def train_weights(self) -> np.ndarray:
         """
-        Weighting such that each event has roughly the same weight
+        Weighting according to the parameters set in the ML model config
         """
         if hasattr(self, "_train_weights"):
             return self._train_weights

         if not self.stats:
             raise Exception("cannot determine train weights without stats")

-        sum_abs_weights = self.stats[self.process]["sum_abs_weights"]
-        num_events = self.stats[self.process]["num_events"]
+        # TODO: this needs to be converted to np.float here
+        proc = self.process
+        proc_inst = self.ml_model_inst.config_inst.get_process(proc)
+        if proc_inst.x("ml_config", None) and proc_inst.x.ml_config.weighting == "equal":
+            train_weights = self.get_equal_train_weights()
+        else:
+            train_weights = self.get_xsec_train_weights()
+
+        self._train_weights = ak.to_numpy(train_weights).astype(np.float32)

-        self._train_weights = self.weights / sum_abs_weights * num_events
         return self._train_weights

     @property
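As a rough illustration of the two schemes that train_weights now chooses between (all numbers invented):

    # toy stats for a combined process with two sub-processes A and B:
    # A: 1000 events, summed |weight| 50, summed positive weight 48
    # B:  100 events, summed |weight|  5, summed positive weight  5
    num_events = 1000 + 100
    sum_abs_weights = 50. + 5.

    # xsec weighting: one common factor num_events / sum_abs_weights (= 20) for all events,
    # so the relative cross sections of A and B are kept and the mean weight is about 1
    xsec_scale = num_events / sum_abs_weights

    # equal weighting: each sub-process is scaled so that its summed positive weight becomes
    # num_events / n_sub_processes (= 550), i.e. A and B then contribute equally to the training
    scale_A = (num_events / 2) / 48.
    scale_B = (num_events / 2) / 5.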
@@ -213,11 +317,26 @@ def equal_weights(self) -> np.ndarray:
         if not self.stats:
             raise Exception("cannot determine val weights without stats")

+        # TODO: per process pls [done] and now please tidy up
         processes = self.ml_model_inst.processes
-        sum_abs_weights = self.stats[self.process]["sum_abs_weights"]
-        num_events_per_process = {proc: self.stats[proc]["num_events"] for proc in processes}
-
-        self._validation_weights = self.weights / sum_abs_weights * max(num_events_per_process.values())
+        num_events_per_process = {}
+        for proc in processes:
+            id_list = list(self.stats[proc]["num_events_per_process"].keys())
+            proc_inst = self.ml_model_inst.config_inst.get_process(proc)
+            sub_id = [
+                p_inst.id
+                for p_inst, _, _ in proc_inst.walk_processes(include_self=True)
+                if str(p_inst.id) in id_list
+            ]
+            if proc == self.process:
+                sum_abs_weights = np.sum([
+                    self.stats[self.process]["sum_abs_weights_per_process"][str(id)] for id in sub_id
+                ])
+            num_events_per_proc = np.sum([self.stats[proc]["num_events_per_process"][str(id)] for id in sub_id])
+            num_events_per_process[proc] = num_events_per_proc
+
+        validation_weights = self.weights / sum_abs_weights * max(num_events_per_process.values())
+        self._validation_weights = ak.to_numpy(validation_weights).astype(np.float32)

         return self._validation_weights

@@ -544,6 +663,7 @@ def target(self) -> np.ndarray:
         if self._ml_model_inst.negative_weights == "handle":
             target[self.m_negative_weights] = 1 - target[self.m_negative_weights]

+        # NOTE: I think here the targets are somehow 64floats... Maybe check that
         self._target = target
         return self._target
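The NOTE above could, as a sketch, be addressed by an explicit cast, mirroring the float32 conversion this commit already applies to the train and validation weights:

    self._target = target.astype(np.float32)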
