From c06f37f71253b3a9cfe555511cb8ff05b5fac547 Mon Sep 17 00:00:00 2001
From: Alexander Gutkin
Date: Tue, 17 Dec 2024 20:26:22 +0000
Subject: [PATCH] Support sorting the summary by either concept names
 (default) or confidence.

PiperOrigin-RevId: 707204409
---
 .../inference/glyphs_from_jsonl_main.py       | 23 ++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/protoscribe/sketches/inference/glyphs_from_jsonl_main.py b/protoscribe/sketches/inference/glyphs_from_jsonl_main.py
index 3c43889..0448a98 100644
--- a/protoscribe/sketches/inference/glyphs_from_jsonl_main.py
+++ b/protoscribe/sketches/inference/glyphs_from_jsonl_main.py
@@ -16,7 +16,7 @@
 
 import json
 import logging
-from typing import Sequence
+from typing import Any, Sequence
 
 from absl import app
 from absl import flags
@@ -48,6 +48,13 @@
     required=True
 )
 
+_SORT_BY = flags.DEFINE_enum(
+    "sort_by", "concepts",
+    ["concepts", "confidence"],
+    "Sort the resulting summary by the specified column. Supported values: "
+    "`concepts`: input concept names, `confidence`: glyph confidence."
+)
+
 
 def main(argv: Sequence[str]) -> None:
   if len(argv) > 1:
@@ -90,8 +97,18 @@ def main(argv: Sequence[str]) -> None:
   if num_errors:
     logging.warning("Encountered %d errors.", num_errors)
 
-  results = sorted(results, key=lambda x: x[0])
-  logging.info("Writing results %s ...", _OUTPUT_TSV_FILE.value)
+  def _sort_by(result: tuple[str, list[str], dict[str, Any]]) -> Any:
+    """Returns the sort key selected by --sort_by for one result tuple."""
+    if _SORT_BY.value == "concepts":
+      return result[0]
+    else:
+      return result[2]["glyph.confidence"]
+
+  results = sorted(results, key=_sort_by)
+  logging.info(
+      "Writing results %s (sorting by %s) ...",
+      _OUTPUT_TSV_FILE.value, _SORT_BY.value
+  )
   with open(_OUTPUT_TSV_FILE.value, mode="w") as f:
     f.write("Input concepts\tConcept Pron\tGlyphs\tGlyph Prons\tConfidence\n")
     for input_concepts, output_glyphs, scorer_dict in results: