# coding: utf-8

"""
Tasks to produce yield tables.
Taken from columnflow and customized.
TODO: after merging MultiConfigPlotting, we should propagate the changes back to columnflow.
"""

import math
from collections import defaultdict, OrderedDict

import law
import luigi
import order as od
from scinum import Number

from columnflow.tasks.framework.base import Requirements, ConfigTask
from columnflow.tasks.framework.mixins import (
    CalibratorsMixin, SelectorStepsMixin, ProducersMixin,
    DatasetsProcessesMixin, CategoriesMixin, WeightProducerMixin,
)
from columnflow.tasks.framework.remote import RemoteWorkflow
from columnflow.tasks.histograms import MergeHistograms
from columnflow.util import dev_sandbox, try_int

from hbw.tasks.base import HBWTask

logger = law.logger.get_logger(__name__)


class CustomCreateYieldTable(
    HBWTask,
    DatasetsProcessesMixin,
    CategoriesMixin,
    WeightProducerMixin,
    ProducersMixin,
    # MLModelsMixin,
    SelectorStepsMixin,
    CalibratorsMixin,
    ConfigTask,
    law.LocalWorkflow,
    RemoteWorkflow,
):
    sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox"))

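    # variable whose histograms are loaded to compute the yields; since all of its bins are
    # summed over below, any variable that exists for all requested datasets works here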
    yields_variable = "mll"

    table_format = luigi.Parameter(
        default="fancy_grid",
        significant=False,
        description="format of the yield table; accepts all formats of the tabulate package; "
        "default: fancy_grid",
    )
    number_format = luigi.Parameter(
        default="pdg",
        significant=False,
        description="rounding format of each number in the yield table; accepts all formats "
        "understood by scinum.Number.str(), e.g. 'pdg', 'publication', '%.1f' or an integer "
        "(number of significant digits); default: pdg",
    )
    skip_uncertainties = luigi.BoolParameter(
        default=False,
        significant=False,
        description="when True, uncertainties are not displayed in the table; default: False",
    )
    normalize_yields = luigi.ChoiceParameter(
        choices=(law.NO_STR, "per_process", "per_category", "all"),
        default=law.NO_STR,
        significant=False,
        description="string parameter to define the normalization of the yields; "
        "choices: '', per_process, per_category, all; empty default",
    )
    output_suffix = luigi.Parameter(
        default=law.NO_STR,
        description="adds a suffix to the output name of the yields table; empty default",
    )
    transpose = luigi.BoolParameter(
        default=False,
        significant=False,
        description="transpose the yield table, i.e., swap the process and category axes; "
        "default: False",
    )
    ratio = law.CSVParameter(
        default=("data", "background"),
        significant=False,
        description="ratio of two processes (numerator,denominator) to be computed and "
        "appended to the table; default: data,background",
    )
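
    # example invocation (process and category names are hypothetical and depend on the
    # analysis setup):
    #
    #   law run CustomCreateYieldTable \
    #       --categories incl \
    #       --processes tt,st,dy,data \
    #       --table-format latex_raw \
    #       --ratio data,background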

    # upstream requirements
    reqs = Requirements(
        RemoteWorkflow.reqs,
        MergeHistograms=MergeHistograms,
    )

    # dummy branch map: the full table is created in a single branch
    def create_branch_map(self):
        return [0]

    def requires(self):
        return {
            d: self.reqs.MergeHistograms.req(
                self,
                dataset=d,
                variables=(self.yields_variable,),
                _prefer_cli={"variables"},
            )
            for d in self.datasets
        }

    def workflow_requires(self):
        reqs = super().workflow_requires()

        reqs["merged_hists"] = [
            self.reqs.MergeHistograms.req(
                self,
                dataset=d,
                variables=(self.yields_variable,),
                _exclude={"branches"},
            )
            for d in self.datasets
        ]

        return reqs

    @classmethod
    def resolve_param_values(cls, params):
        params = super().resolve_param_values(params)

        if "number_format" in params and try_int(params["number_format"]):
            # convert 'number_format' to an integer if possible (e.g. "3" -> 3 significant digits)
            params["number_format"] = int(params["number_format"])

        return params

    def output(self):
        suffix = ""
        if self.output_suffix and self.output_suffix != law.NO_STR:
            suffix = f"__{self.output_suffix}"

        return {
            "table": self.target(f"table__proc_{self.processes_repr}__cat_{self.categories_repr}{suffix}.txt"),
            "csv": self.target(f"table__proc_{self.processes_repr}__cat_{self.categories_repr}{suffix}.csv"),
            "yields": self.target(f"yields__proc_{self.processes_repr}__cat_{self.categories_repr}{suffix}.json"),
        }

    @law.decorator.notify
    @law.decorator.log
    def run(self):
        import hist
        from tabulate import tabulate

        inputs = self.input()
        outputs = self.output()

        category_insts = list(map(self.config_inst.get_category, sorted(self.categories)))
        process_insts = list(map(self.config_inst.get_process, self.processes))
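        # mapping from each requested process to itself plus all of its sub-processes, used to
        # match datasets and to select the contributing process ids per histogram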
        sub_process_insts = {
            proc: [sub for sub, _, _ in proc.walk_processes(include_self=True)]
            for proc in process_insts
        }

        # histogram data per process
        hists = {}

        with self.publish_step(f"Creating yields for processes {self.processes}, categories {self.categories}"):
            for dataset, inp in inputs.items():
                dataset_inst = self.config_inst.get_dataset(dataset)

                # load the histogram of the variable named self.yields_variable
                h_in = inp["hists"][self.yields_variable].load(formatter="pickle")

                # loop and extract one histogram per process
                for process_inst in process_insts:
                    # skip when the dataset is already known not to contain any of the sub-processes
                    if not any(map(dataset_inst.has_process, sub_process_insts[process_inst])):
                        continue

                    # work on a copy
                    h = h_in.copy()

                    # axis selections
                    h = h[{
                        "process": [
                            hist.loc(p.id)
                            for p in sub_process_insts[process_inst]
                            if p.id in h.axes["process"]
                        ],
                    }]

                    # axis reductions
                    h = h[{"process": sum, "shift": sum, self.yields_variable: sum}]
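                    # only the category axis remains after these reductions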

                    # add the histogram
                    if process_inst in hists:
                        hists[process_inst] += h
                    else:
                        hists[process_inst] = h

        # there should be hists to compute yields from
        if not hists:
            raise Exception("no histograms found to compute yields from")

        # sort hists by process order
        hists = OrderedDict(
            (process_inst, hists[process_inst])
            for process_inst in sorted(hists, key=process_insts.index)
        )

        yields, processes = defaultdict(list), []

        # read out yields per category and per process
        for process_inst, h in hists.items():
            processes.append(process_inst)
            for category_inst in category_insts:
                leaf_category_insts = category_inst.get_leaf_categories() or [category_inst]

                h_cat = h[{"category": [
                    hist.loc(c.name)
                    for c in leaf_category_insts
                    if c.name in h.axes["category"]
                ]}]
                h_cat = h_cat[{"category": sum}]

                value = Number(h_cat.value)
                if not self.skip_uncertainties:
                    # set a unique uncertainty name for correct propagation below
                    value.set_uncertainty(
                        f"mcstat_{process_inst.name}_{category_inst.name}",
                        math.sqrt(h_cat.variance),
                    )
                yields[category_inst].append(value)

        if self.ratio:
            processes.append(od.Process("Ratio", id=-9871, label=f"{self.ratio[0]} / {self.ratio[1]}"))
            num_idx, den_idx = [processes.index(self.config_inst.get_process(p)) for p in self.ratio]
            for category_inst in category_insts:
                num = yields[category_inst][num_idx]
                den = yields[category_inst][den_idx]
                yields[category_inst].append(num / den)
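
        # note: since numerator and denominator carry uniquely named uncertainties (set above),
        # scinum propagates them through this division as uncorrelated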

        # obtain normalization factors
        norm_factors = 1
        if self.normalize_yields == "all":
            norm_factors = sum(
                sum(category_yields)
                for category_yields in yields.values()
            )
        elif self.normalize_yields == "per_process":
            norm_factors = [
                sum(yields[category][i] for category in yields.keys())
                for i in range(len(yields[category_insts[0]]))
            ]
        elif self.normalize_yields == "per_category":
            norm_factors = {
                category: sum(category_yields)
                for category, category_yields in yields.items()
            }
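
        # e.g. for normalize_yields == "per_category", each yield is divided below by the summed
        # yield of its category, so the normalized yields per category add up to 1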

        # initialize dicts
        main_label = "Category" if self.transpose else "Process"
        yields_str = defaultdict(list, {main_label: [proc.label for proc in processes]})
        raw_yields = defaultdict(dict)

        # apply normalization and format
        for category, category_yields in yields.items():
            for i, value in enumerate(category_yields):
                # get correct norm factor per category and process
                if self.normalize_yields == "per_process":
                    norm_factor = norm_factors[i]
                elif self.normalize_yields == "per_category":
                    norm_factor = norm_factors[category]
                else:
                    norm_factor = norm_factors

                raw_yield = (value / norm_factor).nominal
                raw_yields[category.name][processes[i].name] = raw_yield

                # format yields into strings
                yield_str = (value / norm_factor).str(
                    combine_uncs="all",
                    format=self.number_format,
                    style="latex" if "latex" in self.table_format else "plain",
                )
                if "latex" in self.table_format:
                    yield_str = f"${yield_str}$"
                cat_label = category.name.replace("__", " ")
                # cat_label = category.label
                yields_str[cat_label].append(yield_str)

        # assemble the table data and transpose it if requested
        data = [list(yields_str.keys())] + list(zip(*yields_str.values()))
        if self.transpose:
            data = list(zip(*data))

        headers = data[0]

        # create, print and save the yield table
        yield_table = tabulate(data[1:], headers=headers, tablefmt=self.table_format)

        with_grid = True
        if with_grid and self.table_format == "latex_raw":
            # identify line breaks and add hlines after every line break
            yield_table = yield_table.replace("\\\\", "\\\\ \\hline")
            # TODO: lll -> |l|l|l|, etc.
        self.publish_message(yield_table)

        outputs["table"].dump(yield_table, formatter="text")
        outputs["yields"].dump(raw_yields, formatter="json")

        # write the table data as csv; touch the target first so that its parent directory exists
        import csv
        outputs["csv"].touch()
        with open(outputs["csv"].abspath, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerows(data)