Skip to content

Commit 477b9ec

Browse files
committed
add custom yields task (temporarily)
1 parent a2a8396 commit 477b9ec

File tree

1 file changed

+316
-0
lines changed

1 file changed

+316
-0
lines changed

hbw/tasks/yields.py

+316
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
# coding: utf-8
2+
3+
"""
4+
Tasks to produce yield tables.
5+
Taken from columnflow and customized.
6+
TODO: after merging MultiConfigPlotting, we should propagate the changes back to columnflow
7+
"""
8+
9+
import math
10+
from collections import defaultdict, OrderedDict
11+
12+
import law
13+
import luigi
14+
import order as od
15+
from scinum import Number
16+
17+
from columnflow.tasks.framework.base import Requirements, ConfigTask
18+
from columnflow.tasks.framework.mixins import (
19+
CalibratorsMixin, SelectorStepsMixin, ProducersMixin,
20+
DatasetsProcessesMixin, CategoriesMixin, WeightProducerMixin,
21+
)
22+
from columnflow.tasks.framework.remote import RemoteWorkflow
23+
from columnflow.tasks.histograms import MergeHistograms
24+
from columnflow.util import dev_sandbox, try_int
25+
26+
from hbw.tasks.base import HBWTask
27+
28+
logger = law.logger.get_logger(__name__)
29+
30+
31+
class CustomCreateYieldTable(
    HBWTask,
    DatasetsProcessesMixin,
    CategoriesMixin,
    WeightProducerMixin,
    ProducersMixin,
    # MLModelsMixin,
    SelectorStepsMixin,
    CalibratorsMixin,
    ConfigTask,
    law.LocalWorkflow,
    RemoteWorkflow,
):
    """
    Task that creates yield tables per process and category.

    Yields are obtained by summing the histogram of a single fixed variable
    (:py:attr:`yields_variable`) over all bins, shifts and (leaf) categories.
    Results are written as a formatted text table, a csv file, and a json file
    with the raw (normalized) nominal yields.

    Taken from columnflow and customized.
    """

    sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox"))

    # variable whose merged histogram is integrated to obtain the yields
    yields_variable = "mll"

    table_format = luigi.Parameter(
        default="fancy_grid",
        significant=False,
        description="format of the yield table; accepts all formats of the tabulate package; "
        "default: fancy_grid",
    )
    number_format = luigi.Parameter(
        default="pdg",
        significant=False,
        description="rounding format of each number in the yield table; accepts all formats "
        "understood by scinum.Number.str(), e.g. 'pdg', 'publication', '%.1f' or an integer "
        "(number of signficant digits); default: pdg",
    )
    skip_uncertainties = luigi.BoolParameter(
        default=False,
        significant=False,
        description="when True, uncertainties are not displayed in the table; default: False",
    )
    normalize_yields = luigi.ChoiceParameter(
        choices=(law.NO_STR, "per_process", "per_category", "all"),
        default=law.NO_STR,
        significant=False,
        description="string parameter to define the normalization of the yields; "
        "choices: '', per_process, per_category, all; empty default",
    )
    output_suffix = luigi.Parameter(
        default=law.NO_STR,
        description="Adds a suffix to the output name of the yields table; empty default",
    )
    transpose = luigi.BoolParameter(
        default=False,
        significant=False,
        description="Transpose the yield table; default: False",
    )
    ratio = law.CSVParameter(
        default=("data", "background"),
        significant=False,
        description="Ratio of two processes to be calculated and added to the table",
    )

    # upstream requirements
    reqs = Requirements(
        RemoteWorkflow.reqs,
        MergeHistograms=MergeHistograms,
    )

    def create_branch_map(self):
        # dummy branch map: a single branch does all the work
        return [0]

    def requires(self):
        # one MergeHistograms requirement per dataset, restricted to the yields variable
        return {
            d: self.reqs.MergeHistograms.req(
                self,
                dataset=d,
                variables=(self.yields_variable,),
                _prefer_cli={"variables"},
            )
            for d in self.datasets
        }

    def workflow_requires(self):
        reqs = super().workflow_requires()

        reqs["merged_hists"] = [
            self.reqs.MergeHistograms.req(
                self,
                dataset=d,
                variables=(self.yields_variable,),
                _exclude={"branches"},
            )
            for d in self.datasets
        ]

        return reqs

    @classmethod
    def resolve_param_values(cls, params):
        params = super().resolve_param_values(params)

        # convert 'number_format' to an integer if possible, as scinum.Number.str()
        # interprets integers as a number of significant digits
        if "number_format" in params and try_int(params["number_format"]):
            params["number_format"] = int(params["number_format"])

        return params

    def output(self):
        suffix = ""
        if self.output_suffix and self.output_suffix != law.NO_STR:
            suffix = f"__{self.output_suffix}"

        return {
            "table": self.target(f"table__proc_{self.processes_repr}__cat_{self.categories_repr}{suffix}.txt"),
            "csv": self.target(f"table__proc_{self.processes_repr}__cat_{self.categories_repr}{suffix}.csv"),
            "yields": self.target(f"yields__proc_{self.processes_repr}__cat_{self.categories_repr}{suffix}.json"),
        }

    @law.decorator.notify
    @law.decorator.log
    def run(self):
        # sandboxed / runtime-only imports
        import csv

        import hist
        from tabulate import tabulate

        inputs = self.input()
        outputs = self.output()

        category_insts = list(map(self.config_inst.get_category, sorted(self.categories)))
        process_insts = list(map(self.config_inst.get_process, self.processes))
        # map each requested process to itself plus all of its sub processes
        sub_process_insts = {
            proc: [sub for sub, _, _ in proc.walk_processes(include_self=True)]
            for proc in process_insts
        }

        # histogram data per process
        hists = {}

        with self.publish_step(f"Creating yields for processes {self.processes}, categories {self.categories}"):
            for dataset, inp in inputs.items():
                dataset_inst = self.config_inst.get_dataset(dataset)

                # load the histogram of the variable named self.yields_variable
                h_in = inp["hists"][self.yields_variable].load(formatter="pickle")

                # loop and extract one histogram per process
                for process_inst in process_insts:
                    # skip when the dataset is already known to not contain any sub process
                    if not any(map(dataset_inst.has_process, sub_process_insts[process_inst])):
                        continue

                    # work on a copy
                    h = h_in.copy()

                    # axis selections: keep only the sub processes present in the histogram
                    h = h[{
                        "process": [
                            hist.loc(p.id)
                            for p in sub_process_insts[process_inst]
                            if p.id in h.axes["process"]
                        ],
                    }]

                    # axis reductions: integrate over process, shift and the variable itself
                    h = h[{"process": sum, "shift": sum, self.yields_variable: sum}]

                    # add the histogram
                    if process_inst in hists:
                        hists[process_inst] += h
                    else:
                        hists[process_inst] = h

            # there should be hists to plot
            if not hists:
                raise Exception("no histograms found to plot")

            # sort hists by process order
            hists = OrderedDict(
                (process_inst, hists[process_inst])
                for process_inst in sorted(hists, key=process_insts.index)
            )

            yields, processes = defaultdict(list), []

            # read out yields per category and per process
            for process_inst, h in hists.items():
                processes.append(process_inst)

                for category_inst in category_insts:
                    leaf_category_insts = category_inst.get_leaf_categories() or [category_inst]

                    # select and merge the leaf categories present in the histogram
                    h_cat = h[{"category": [
                        hist.loc(c.name)
                        for c in leaf_category_insts
                        if c.name in h.axes["category"]
                    ]}]
                    h_cat = h_cat[{"category": sum}]

                    value = Number(h_cat.value)
                    if not self.skip_uncertainties:
                        # set a unique uncertainty name for correct propagation below
                        value.set_uncertainty(
                            f"mcstat_{process_inst.name}_{category_inst.name}",
                            math.sqrt(h_cat.variance),
                        )
                    yields[category_inst].append(value)

            # optionally append a ratio row of two of the requested processes
            if self.ratio:
                processes.append(od.Process("Ratio", id=-9871, label=f"{self.ratio[0]} / {self.ratio[1]}"))
                # resolve row indices of numerator and denominator before the ratio rows are added
                num_idx, den_idx = [processes.index(self.config_inst.get_process(_p)) for _p in self.ratio]
                for category_inst in category_insts:
                    num = yields[category_inst][num_idx]
                    den = yields[category_inst][den_idx]
                    yields[category_inst].append(num / den)

            # obtain normalizaton factors
            # NOTE(review): when a ratio row was added above, it is included in the "all" and
            # "per_process" sums below — confirm this is intended
            norm_factors = 1
            if self.normalize_yields == "all":
                norm_factors = sum(
                    sum(category_yields)
                    for category_yields in yields.values()
                )
            elif self.normalize_yields == "per_process":
                norm_factors = [
                    sum(yields[category][i] for category in yields)
                    for i in range(len(yields[category_insts[0]]))
                ]
            elif self.normalize_yields == "per_category":
                norm_factors = {
                    category: sum(category_yields)
                    for category, category_yields in yields.items()
                }

            # initialize dicts
            main_label = "Category" if self.transpose else "Process"
            yields_str = defaultdict(list, {main_label: [proc.label for proc in processes]})
            raw_yields = defaultdict(dict)

            # apply normalization and format
            for category, category_yields in yields.items():
                for i, value in enumerate(category_yields):
                    # get correct norm factor per category and process
                    if self.normalize_yields == "per_process":
                        norm_factor = norm_factors[i]
                    elif self.normalize_yields == "per_category":
                        norm_factor = norm_factors[category]
                    else:
                        norm_factor = norm_factors

                    raw_yield = (value / norm_factor).nominal
                    raw_yields[category.name][processes[i].name] = raw_yield

                    # format yields into strings
                    yield_str = (value / norm_factor).str(
                        combine_uncs="all",
                        format=self.number_format,
                        style="latex" if "latex" in self.table_format else "plain",
                    )
                    if "latex" in self.table_format:
                        yield_str = f"${yield_str}$"
                    cat_label = category.name.replace("__", " ")
                    # cat_label = category.label
                    yields_str[cat_label].append(yield_str)

            # transposing the table
            data = [list(yields_str.keys())] + list(zip(*yields_str.values()))
            if self.transpose:
                data = list(zip(*data))

            headers = data[0]

            # create, print and save the yield table
            yield_table = tabulate(data[1:], headers=headers, tablefmt=self.table_format)

            with_grid = True
            if with_grid and self.table_format == "latex_raw":
                # identify line breaks and add hlines after every line break
                yield_table = yield_table.replace("\\\\", "\\\\ \\hline")
                # TODO: lll -> |l|l|l|, etc.
            self.publish_message(yield_table)

            outputs["table"].dump(yield_table, formatter="text")
            outputs["yields"].dump(raw_yields, formatter="json")

            outputs["csv"].touch()
            with open(outputs["csv"].abspath, "w", newline="") as csvfile:
                writer = csv.writer(csvfile)
                writer.writerows(data)

0 commit comments

Comments
 (0)