Skip to content

Commit 6933274

Browse files
committed
implement PlotPostfitShapes task
1 parent 869055a commit 6933274

File tree

2 files changed

+275
-3
lines changed

2 files changed

+275
-3
lines changed

hbw/tasks/inference.py

+81-1
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,6 @@ def requires(self):
273273
reqs = {
274274
"datacards": self.reqs.CreateDatacards.req(self),
275275
}
276-
277276
return reqs
278277

279278
def output(self):
@@ -360,3 +359,84 @@ def run(self):
360359
problematic_bin_count = check_empty_bins(h_rebin) # noqa
361360
logger.info(f"Inserting histogram with name {key}")
362361
out_file[key] = uproot.from_pyroot(h_rebin)
362+
363+
364+
class PrepareInferenceTaskCalls(
    HBWTask,
    InferenceModelMixin,
    MLModelsMixin,
    ProducersMixin,
    SelectorStepsMixin,
    CalibratorsMixin,
):
    """
    Helper task that assembles the ``law run ...`` command strings for a few tasks of the
    inference framework, based on the rebinned datacards. Every command is printed and
    additionally written to its own text output.
    """

    # upstream requirements
    reqs = Requirements(
        ModifyDatacardsFlatRebin=ModifyDatacardsFlatRebin,
    )

    def workflow_requires(self):
        reqs = super().workflow_requires()
        reqs["rebinned_datacards"] = self.reqs.ModifyDatacardsFlatRebin.req(self)
        return reqs

    def requires(self):
        return {"rebinned_datacards": self.reqs.ModifyDatacardsFlatRebin.req(self)}

    def output(self):
        # one text file per prepared command
        return {
            "PlotUpperLimitsAtPoint": self.target("PlotUpperLimitsAtPoint.txt"),
            "PlotUpperLimitsPoint": self.target("PlotUpperLimitsPoint.txt"),
            "FitDiagnostics": self.target("FitDiagnostics.txt"),
        }

    def run(self):
        inputs = self.input()
        targets = self.output()

        # version string under which the inference tasks are to be run
        identifier = "__".join((self.config, self.selector, self.inference_model, self.version))

        # file names of the rebinned datacards
        card_collection = inputs["rebinned_datacards"]["collection"]
        card_paths = [card_collection[key]["card"].fn for key in card_collection.keys()]

        # category names as defined by the inference model
        category_names = [cat.name for cat in self.inference_model_inst.categories]

        # "<category>=<card>" pairs, comma-separated; pairing is purely positional
        datacards = ",".join(
            f"{name}={path}" for name, path in zip(category_names, card_paths)
        )

        # commands by output key, in the order in which they are printed and stored
        commands = {
            # creating upper limits for kl=1
            "PlotUpperLimitsAtPoint": (
                f"law run PlotUpperLimitsAtPoint --version {identifier} --multi-datacards {datacards} "
                f"--datacard-names {identifier}"
            ),
            # creating kl scan
            "PlotUpperLimitsPoint": (
                f"law run PlotUpperLimits --version {identifier} --datacards {datacards} "
                f"--xsec fb --y-log"
            ),
            # running FitDiagnostics for Pre+Postfit plots
            "FitDiagnostics": (
                f"law run FitDiagnostics --version {identifier} --datacards {datacards} "
                f"--skip-b-only"
            ),
        }

        print("\n\n")
        for key, cmd in commands.items():
            print(cmd, "\n\n")
            targets[key].dump(cmd, formatter="text")

hbw/tasks/plotting.py

+194-2
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,32 @@
55
e.g. default sets of plots or datacards
66
"""
77

8+
from __future__ import annotations
9+
10+
from collections import OrderedDict
11+
812
import law
913
import luigi
14+
import order as od
1015

1116
from columnflow.tasks.framework.base import Requirements
1217
from columnflow.tasks.framework.mixins import (
1318
InferenceModelMixin, MLModelsMixin, ProducersMixin, SelectorStepsMixin,
1419
CalibratorsMixin,
1520
)
1621
from columnflow.tasks.framework.plotting import (
17-
PlotBase1D, VariablePlotSettingMixin, ProcessPlotSettingMixin,
22+
PlotBase, PlotBase1D, VariablePlotSettingMixin, ProcessPlotSettingMixin,
1823
)
1924
from columnflow.tasks.plotting import PlotVariables1D
2025
# from columnflow.tasks.framework.remote import RemoteWorkflow
2126
from hbw.tasks.base import HBWTask
2227

23-
from columnflow.util import dev_sandbox
28+
from columnflow.util import dev_sandbox, DotDict, maybe_import
29+
30+
uproot = maybe_import("uproot")
31+
32+
33+
logger = law.logger.get_logger(__name__)
2434

2535

2636
class InferencePlots(
@@ -106,3 +116,185 @@ def output(self):
106116

107117
def run(self):
108118
pass
119+
120+
121+
def load_hists_uproot(fit_diagnostics_path):
    """
    Helper to load histograms from a fit_diagnostics file into a nested dictionary.

    Only objects whose key has exactly three "/"-separated parts are considered; the parts are
    interpreted as fit, channel and process. Everything from the first ";" on is dropped from
    the process name. Objects whose process name does not contain "data" are converted with
    ``to_hist()``; the remaining ones are stored as returned by uproot.

    :param fit_diagnostics_path: Path of the file, handed to ``uproot.open``.
    :return: ``DotDict`` of the form ``hists[fit][channel][process]``.
    """
    # prepare output dict
    hists = DotDict()
    with uproot.open(fit_diagnostics_path) as tfile:
        for full_key in tfile.keys():
            parts = full_key.split("/")
            if len(parts) != 3:
                continue

            # get the histogram from the tfile
            h_in = tfile[full_key]

            # unpack key
            fit, channel, process = parts
            process = process.split(";")[0]

            if "data" not in process:
                # transform TH1F to hist
                h_in = h_in.to_hist()

            # insert directly into the nested structure instead of re-merging the whole
            # accumulated dictionary for every single histogram
            hists.setdefault(fit, DotDict()).setdefault(channel, DotDict())[process] = h_in

    return hists
146+
147+
148+
# imports regarding plot function
149+
mpl = maybe_import("matplotlib")
150+
plt = maybe_import("matplotlib.pyplot")
151+
mplhep = maybe_import("mplhep")
152+
153+
from columnflow.plotting.plot_all import plot_all
154+
from columnflow.plotting.plot_util import (
155+
prepare_plot_config,
156+
prepare_style_config,
157+
)
158+
159+
160+
def plot_postfit_shapes(
    hists: OrderedDict,
    config_inst: od.Config,
    category_inst: od.Category,
    variable_insts: list[od.Variable],
    style_config: dict | None = None,
    density: bool | None = False,
    shape_norm: bool | None = False,
    yscale: str | None = "",
    hide_errors: bool | None = None,
    process_settings: dict | None = None,
    variable_settings: dict | None = None,
    **kwargs,
) -> tuple[plt.Figure, tuple[plt.Axes]]:
    """
    Plot function for pre-/postfit shapes.

    Builds a plot config from *hists* and a default style config from *config_inst*,
    *category_inst* and the first entry of *variable_insts*, removes the default x-axis limits,
    merges the user-provided *style_config* on top and forwards everything to ``plot_all``.

    :param hists: Mapping of processes to histograms, passed to ``prepare_plot_config``.
    :param config_inst: Config instance used for the default style.
    :param category_inst: Category instance used for the default style.
    :param variable_insts: Variable instance(s); only the first one is used.
    :param style_config: Optional style settings that are merged into the defaults.
    :param density: Forwarded to ``prepare_style_config``.
    :param shape_norm: Forwarded to both prepare functions; when set, the y-axis label is
        replaced by a relative-yield label.
    :param yscale: Forwarded to ``prepare_style_config``.
    :param hide_errors: Forwarded to ``prepare_plot_config``.
    :param process_settings: Accepted for interface compatibility, not used here.
    :param variable_settings: Accepted for interface compatibility, not used here.
    :param kwargs: Forwarded to ``plot_all``.
    :return: Whatever ``plot_all`` returns (annotated as figure and tuple of axes).
    """
    variable_inst = law.util.make_tuple(variable_insts)[0]

    plot_config = prepare_plot_config(
        hists,
        shape_norm=shape_norm,
        hide_errors=hide_errors,
    )

    default_style_config = prepare_style_config(
        config_inst, category_inst, variable_inst, density, shape_norm, yscale,
    )
    # drop the default x limits; tolerate the key being absent
    default_style_config["ax_cfg"].pop("xlim", None)

    style_config = law.util.merge_dicts(default_style_config, style_config, deep=True)
    if shape_norm:
        style_config["ax_cfg"]["ylabel"] = r"$\Delta N/N$"

    return plot_all(plot_config, style_config, **kwargs)
192+
193+
194+
class PlotPostfitShapes(
    HBWTask,
    PlotBase1D,
    # all of the following mixins are needed to set up our InferenceModel correctly; hopefully,
    # their parameters are resolved automatically
    InferenceModelMixin,
    MLModelsMixin,
    ProducersMixin,
    SelectorStepsMixin,
    CalibratorsMixin,
):
    """
    Task that produces prefit or postfit shape plots from the histograms stored in a
    fit_diagnostics file, one plot per channel found in the file.

    Work in Progress!
    TODO:
    - include data
    - include correct uncertainty bands
    - pass correct binning information
    """

    sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox"))

    plot_function = PlotBase.plot_function.copy(
        default="hbw.tasks.plotting.plot_postfit_shapes",
        add_default_to_description=True,
    )

    fit_diagnostics_file = luigi.Parameter(
        default=law.NO_STR,
        description="fit_diagnostics file that is used to load histograms",
    )

    prefit = luigi.BoolParameter(
        default=False,
        description="Whether to do prefit or postfit plots; defaults to False",
    )

    def requires(self):
        # the input file is given via parameter, there are no upstream tasks
        return {}

    def output(self):
        return {"plots": self.target("plots", dir=True)}

    def run(self):
        logger.warning(
            f"Note! It is important that the requested inference_model {self.inference_model} "
            "is identical to the one that has been used to create the datacards",
        )
        loaded = load_hists_uproot(self.fit_diagnostics_file)

        outp = self.output()

        # name of the fit whose shapes are plotted
        fit_type = "prefit" if self.prefit else "fit_s"
        shapes = loaded[f"shapes_{fit_type}"]

        for channel, hists in shapes.items():
            known_channel = self.inference_model_inst.has_category(channel)
            if not known_channel:
                logger.warning(f"Category {channel} is not part of the inference model {self.inference_model}")

            # iterate over a snapshot of the keys since entries are removed and re-keyed
            for proc_key in list(hists.keys()):
                # remove unnecessary histograms
                if any(tag in proc_key for tag in ("data", "total")):
                    del hists[proc_key]
                    continue

                if known_channel:
                    # resolve the config process via the InferenceModel
                    # TODO: process customization based on inference process? e.g. scale
                    inference_process = self.inference_model_inst.get_process(proc_key, channel)
                    proc_inst = self.config_inst.get_process(inference_process.config_process)
                else:
                    # fall back to a direct lookup in the config
                    proc_inst = self.config_inst.get_process(proc_key, default=None)

                # replace string keys with process instances
                if proc_inst:
                    hists[proc_inst] = hists.pop(proc_key)

            if not known_channel:
                # default to dummy Category and Variable
                config_category = od.Category(channel, id=1)
                variable_inst = od.Variable("dummy")
            else:
                # resolve the config category and variable via the InferenceModel
                # TODO: category/variable customization based on inference model?
                inference_category = self.inference_model_inst.get_category(channel)
                config_category = self.config_inst.get_category(inference_category.config_category)
                variable_inst = self.config_inst.get_variable(inference_category.config_variable)

            # call the plot function
            fig, _ = self.call_plot_func(
                self.plot_function,
                hists=hists,
                config_inst=self.config_inst,
                category_inst=config_category,
                variable_insts=variable_inst,
                **self.get_plot_parameters(),
            )

            outp["plots"].child(f"{channel}_{fit_type}.pdf", type="f").dump(fig, formatter="mpl")

0 commit comments

Comments
 (0)