Skip to content

Commit fc83b52

Browse files
Lara813 and mafrahm authored
Apply suggestions from code review
Co-authored-by: Mathis Frahm <49306645+mafrahm@users.noreply.github.com>
1 parent 3febce2 commit fc83b52

File tree

2 files changed

+6
-35
lines changed

2 files changed

+6
-35
lines changed

hbw/ml/data_loader.py

+5-32
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@ def get_proc_mask(
2323
events: ak.Array,
2424
proc: str | od.Process,
2525
config_inst: od.Config | None = None,
26-
) -> np.ndarray:
26+
) -> tuple(np.ndarray, list):
2727
"""
28-
Creates a list of the Ids of all subprocesses and teh corresponding mask for all events.
28+
Creates the mask selecting events belonging to the process *proc* and a list of all ids belonging to this process.
2929
30-
:param events: Event array
31-
:param config_inst: An instance of the Config, can be None if Porcess instance is given.
30+
:param events: Event array
3231
:param proc: Either string or process instance.
32+
:param config_inst: An instance of the Config, can be None if a Process instance is given.
33+
:return: The process mask and the list of the corresponding process ids.
3334
"""
3435
# get process instance
3536
if config_inst:
@@ -52,27 +53,6 @@ def get_proc_mask(
5253
return proc_mask, sub_id
5354

5455

55-
def del_sub_proc_stats(
56-
stats: dict,
57-
proc: str,
58-
sub_id: list,
59-
) -> np.ndarray:
60-
"""
61-
Function deletes dict keys which are not part of the requested process
62-
63-
:param stats: Dictionaire containing ML stats for each process.
64-
:param proc: String of the process.
65-
:param sub_id: List of ids of sub processes that should be reatined (!).
66-
"""
67-
id_list = list(stats[proc]["num_events_per_process"].keys())
68-
item_list = list(stats[proc].keys())
69-
for id in id_list:
70-
if int(id) not in sub_id:
71-
for item in item_list:
72-
if "per_process" in item:
73-
del stats[proc][item][id]
74-
75-
7656
def input_features_sanity_checks(ml_model_inst: MLModel, input_features: list[str]):
7757
"""
7858
Perform sanity checks on the input features.
@@ -134,9 +114,7 @@ def __init__(self, ml_model_inst: MLModel, process: "str", events: ak.Array, sta
134114
self._process = process
135115

136116
proc_mask, _ = get_proc_mask(events, process, ml_model_inst.config_inst)
137-
# TODO: die ohne _per_process müssen auch noch, still, per fold never make sense then anymore -> DISCUSS
138117
self._stats = stats
139-
# del_sub_proc_stats(process, sub_id)
140118
self._events = events[proc_mask]
141119

142120
def __repr__(self):
@@ -323,9 +301,6 @@ def train_weights(self) -> np.ndarray:
323301
train_weights = self.get_equal_train_weights()
324302
else:
325303
train_weights = self.get_xsec_train_weights()
326-
# self._train_weights = self.get_equal_train_weights()
327-
# else:
328-
# self._train_weights = self.get_xsec_train_weights()
329304

330305
self._train_weights = ak.to_numpy(train_weights).astype(np.float32)
331306

@@ -360,8 +335,6 @@ def equal_weights(self) -> np.ndarray:
360335
num_events_per_proc = np.sum([self.stats[proc]["num_events_per_process"][str(id)] for id in sub_id])
361336
num_events_per_process[proc] = num_events_per_proc
362337

363-
# sum_abs_weights = self.stats[self.process]["sum_abs_weights"]
364-
# num_events_per_process = {proc: self.stats[proc]["num_events"] for proc in processes}
365338
validation_weights = self.weights / sum_abs_weights * max(num_events_per_process.values())
366339
self._validation_weights = ak.to_numpy(validation_weights).astype(np.float32)
367340

hbw/tasks/ml.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,8 @@ def run(self):
404404
n_events_per_fold = len(ml_dataset.train_weights)
405405
logger.info(f"Sum of traing weights is: {sum_train_weights} for {n_events_per_fold} {process} events")
406406

407+
# check that equal weighting works as intended
407408
if self.ml_model_inst.config_inst.get_process(process).x("ml_config", None):
408-
409409
if self.ml_model_inst.config_inst.get_process(process).x.ml_config.weighting == "equal":
410410
for sub_proc in self.ml_model_inst.config_inst.get_process(process).x.ml_config.sub_processes:
411411
proc_mask, sub_id = get_proc_mask(events, sub_proc, self.ml_model_inst.config_inst)
@@ -712,8 +712,6 @@ def run(self):
712712
"test": MLProcessData(self.ml_model_inst, input_files, "test", self.ml_model_inst.processes, self.fold),
713713
})
714714

715-
# ML WEIGHTING data.train.equal_weights
716-
717715
# create plots
718716
# NOTE: this is currently hard-coded, could be made customizable and could also be parallelized since
719717
# input reading is quite fast, while producing certain plots takes a long time

0 commit comments

Comments
 (0)