|
5 | 5 | e.g. default sets of plots or datacards
|
6 | 6 | """
|
7 | 7 |
|
| 8 | +from __future__ import annotations |
| 9 | + |
| 10 | +from collections import OrderedDict |
| 11 | + |
8 | 12 | import law
|
9 | 13 | import luigi
|
| 14 | +import order as od |
10 | 15 |
|
11 | 16 | from columnflow.tasks.framework.base import Requirements
|
12 | 17 | from columnflow.tasks.framework.mixins import (
|
13 | 18 | InferenceModelMixin, MLModelsMixin, ProducersMixin, SelectorStepsMixin,
|
14 | 19 | CalibratorsMixin,
|
15 | 20 | )
|
16 | 21 | from columnflow.tasks.framework.plotting import (
|
17 |
| - PlotBase1D, VariablePlotSettingMixin, ProcessPlotSettingMixin, |
| 22 | + PlotBase, PlotBase1D, VariablePlotSettingMixin, ProcessPlotSettingMixin, |
18 | 23 | )
|
19 | 24 | from columnflow.tasks.plotting import PlotVariables1D
|
20 | 25 | # from columnflow.tasks.framework.remote import RemoteWorkflow
|
21 | 26 | from hbw.tasks.base import HBWTask
|
22 | 27 |
|
23 |
| -from columnflow.util import dev_sandbox |
| 28 | +from columnflow.util import dev_sandbox, DotDict, maybe_import |
| 29 | + |
| 30 | +uproot = maybe_import("uproot") |
| 31 | + |
| 32 | + |
| 33 | +logger = law.logger.get_logger(__name__) |
24 | 34 |
|
25 | 35 |
|
26 | 36 | class InferencePlots(
|
@@ -106,3 +116,185 @@ def output(self):
|
106 | 116 |
|
107 | 117 | def run(self):
|
108 | 118 | pass
|
| 119 | + |
| 120 | + |
def load_hists_uproot(fit_diagnostics_path):
    """
    Helper to load histograms from a fit_diagnostics file.

    The file at *fit_diagnostics_path* is opened with uproot and every object whose key consists
    of exactly three "/"-separated parts is collected. The three parts are interpreted as
    ``fit/channel/process``, with everything after the first ";" stripped from the process part.
    Objects whose process name does not contain "data" are converted with ``to_hist()``, all
    others are stored as returned by uproot.

    The result is a nested mapping, built with :py:func:`law.util.merge_dicts` starting from an
    empty :py:class:`DotDict`, that is addressed as ``hists[fit][channel][process]``.
    """
    collected = DotDict()

    with uproot.open(fit_diagnostics_path) as root_file:
        for full_key in root_file.keys():
            parts = full_key.split("/")

            # only objects located exactly three levels deep are of interest
            if len(parts) != 3:
                continue

            fit, channel, process_with_cycle = parts
            # strip the trailing ";<cycle>" part of the key
            process = process_with_cycle.split(";")[0]

            obj = root_file[full_key]
            if "data" not in process:
                # transform TH1F to hist
                obj = obj.to_hist()

            # insert the object at its position in the deep dictionary
            entry = DotDict.wrap({fit: {channel: {process: obj}}})
            collected = law.util.merge_dicts(collected, entry, deep=True)

    return collected
| 146 | + |
| 147 | + |
| 148 | +# imports regarding plot function |
| 149 | +mpl = maybe_import("matplotlib") |
| 150 | +plt = maybe_import("matplotlib.pyplot") |
| 151 | +mplhep = maybe_import("mplhep") |
| 152 | + |
| 153 | +from columnflow.plotting.plot_all import plot_all |
| 154 | +from columnflow.plotting.plot_util import ( |
| 155 | + prepare_plot_config, |
| 156 | + prepare_style_config, |
| 157 | +) |
| 158 | + |
| 159 | + |
def plot_postfit_shapes(
    hists: OrderedDict,
    config_inst: od.Config,
    category_inst: od.Category,
    variable_insts: list[od.Variable],
    style_config: dict | None = None,
    density: bool | None = False,
    shape_norm: bool | None = False,
    yscale: str | None = "",
    hide_errors: bool | None = None,
    process_settings: dict | None = None,
    variable_settings: dict | None = None,
    **kwargs,
) -> tuple[plt.Figure, tuple[plt.Axes, ...]]:
    """
    Plot function for shapes that were loaded from a fit_diagnostics file.

    :param hists: Mapping of processes to histograms, forwarded to ``prepare_plot_config``.
    :param config_inst: Config instance, forwarded to ``prepare_style_config``.
    :param category_inst: Category instance, forwarded to ``prepare_style_config``.
    :param variable_insts: One or more variable instances; only the first one is used.
    :param style_config: Optional style settings that are merged on top of the default style
        config.
    :param density: Forwarded to ``prepare_style_config``.
    :param shape_norm: Forwarded to ``prepare_plot_config`` and ``prepare_style_config``. When
        true, the y-axis label is set to "Delta N / N".
    :param yscale: Forwarded to ``prepare_style_config``.
    :param hide_errors: Forwarded to ``prepare_plot_config``.
    :param process_settings: Accepted for signature compatibility, currently unused here.
    :param variable_settings: Accepted for signature compatibility, currently unused here.
    :param kwargs: Additional arguments forwarded to ``plot_all``.
    :return: The return value of ``plot_all``.
    """
    variable_inst = law.util.make_tuple(variable_insts)[0]

    plot_config = prepare_plot_config(
        hists,
        shape_norm=shape_norm,
        hide_errors=hide_errors,
    )

    default_style_config = prepare_style_config(
        config_inst, category_inst, variable_inst, density, shape_norm, yscale,
    )
    # drop the x-axis limits of the default style; presumably they stem from the variable
    # definition and do not match the binning stored in the fit_diagnostics file (TODO confirm).
    # a missing "xlim" entry is tolerated instead of raising a KeyError
    default_style_config["ax_cfg"].pop("xlim", None)

    style_config = law.util.merge_dicts(default_style_config, style_config, deep=True)
    if shape_norm:
        style_config["ax_cfg"]["ylabel"] = r"$\Delta N/N$"

    return plot_all(plot_config, style_config, **kwargs)
| 192 | + |
| 193 | + |
class PlotPostfitShapes(
    HBWTask,
    PlotBase1D,
    # all of the mixins below are needed to correctly set up our InferenceModel; hopefully, their
    # parameters are resolved automatically
    InferenceModelMixin,
    MLModelsMixin,
    ProducersMixin,
    SelectorStepsMixin,
    CalibratorsMixin,
):
    """
    Task that creates prefit or postfit shape plots from the histograms stored in a
    fit_diagnostics file. One plot is written per channel found in the file.

    Work in Progress!
    TODO:
    - include data
    - include correct uncertainty bands
    - pass correct binning information
    """

    sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox"))

    # NOTE(review): confirm that this dotted path matches the module in which
    # plot_postfit_shapes is actually defined
    plot_function = PlotBase.plot_function.copy(
        default="hbw.tasks.plotting.plot_postfit_shapes",
        add_default_to_description=True,
    )

    fit_diagnostics_file = luigi.Parameter(
        default=law.NO_STR,
        description="fit_diagnostics file that is used to load histograms",
    )

    prefit = luigi.BoolParameter(
        default=False,
        description="Whether to do prefit or postfit plots; defaults to False",
    )

    def requires(self):
        # the input file is passed via parameter, so there are no task requirements
        return {}

    def output(self):
        # directory target that collects one plot per channel
        return {"plots": self.target("plots", dir=True)}

    def run(self):
        logger.warning(
            f"Note! It is important that the requested inference_model {self.inference_model} "
            "is identical to the one that has been used to create the datacards",
        )
        all_hists = load_hists_uproot(self.fit_diagnostics_file)

        outp = self.output()

        # select the set of shapes to plot
        fit_type = "prefit" if self.prefit else "fit_s"
        shapes_per_channel = all_hists[f"shapes_{fit_type}"]

        for channel, hists in shapes_per_channel.items():
            in_model = self.inference_model_inst.has_category(channel)
            if not in_model:
                logger.warning(f"Category {channel} is not part of the inference model {self.inference_model}")

            # iterate over a snapshot of the keys since the mapping is modified in the loop
            for proc_key in tuple(hists):
                # remove unnecessary histograms
                if any(tag in proc_key for tag in ("data", "total")):
                    hists.pop(proc_key)
                    continue

                if in_model:
                    # resolve the config process via the InferenceModel
                    # TODO: process customization based on inference process? e.g. scale
                    inference_process = self.inference_model_inst.get_process(proc_key, channel)
                    proc_inst = self.config_inst.get_process(inference_process.config_process)
                else:
                    # fall back to a direct lookup in the config
                    proc_inst = self.config_inst.get_process(proc_key, default=None)

                # keep the string key when no process instance was found
                if not proc_inst:
                    continue

                # replace the string key with the process instance
                hists[proc_inst] = hists[proc_key]
                hists.pop(proc_key)

            if not in_model:
                # default to dummy Category and Variable
                config_category = od.Category(channel, id=1)
                variable_inst = od.Variable("dummy")
            else:
                # resolve the config category and variable via the InferenceModel
                # TODO: category/variable customization based on inference model?
                inference_category = self.inference_model_inst.get_category(channel)
                config_category = self.config_inst.get_category(inference_category.config_category)
                variable_inst = self.config_inst.get_variable(inference_category.config_variable)

            # call the plot function
            fig, _ = self.call_plot_func(
                self.plot_function,
                hists=hists,
                config_inst=self.config_inst,
                category_inst=config_category,
                variable_insts=variable_inst,
                **self.get_plot_parameters(),
            )

            # store the figure as a pdf file named after channel and fit type
            plot_target = outp["plots"].child(f"{channel}_{fit_type}.pdf", type="f")
            plot_target.dump(fig, formatter="mpl")
0 commit comments