From 3e47cd9ac3005b4d1fbda1bd3b9a44f8c513dae7 Mon Sep 17 00:00:00 2001 From: Alexander Gutkin Date: Wed, 8 Jan 2025 16:32:44 +0000 Subject: [PATCH] Discrete glyph prediction results post-processing tool. PiperOrigin-RevId: 713301809 --- .../stages/glyphs_from_jsonl_main.py | 151 ++++++++++++++++++ protoscribe/utils/file_utils.py | 23 +++ 2 files changed, 174 insertions(+) create mode 100644 protoscribe/evolution/stages/glyphs_from_jsonl_main.py diff --git a/protoscribe/evolution/stages/glyphs_from_jsonl_main.py b/protoscribe/evolution/stages/glyphs_from_jsonl_main.py new file mode 100644 index 0000000..30b4ab1 --- /dev/null +++ b/protoscribe/evolution/stages/glyphs_from_jsonl_main.py @@ -0,0 +1,151 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stage-specific helper for postprocessing discrete glyph inference results. + +This tool is intended to be used on outputs of discrete glyph predictor. +""" + +from collections.abc import Sequence +import logging +import os + +from absl import app +from absl import flags +from protoscribe.evolution.stages import common_flags +from protoscribe.utils import file_utils +from protoscribe.utils import subprocess_utils + +_EXPERIMENT_NAME = flags.DEFINE_string( + "experiment_name", None, + "An experiment name which will define the directory in which the " + "evolving system data is placed.", + required=True +) + +_SEMANTICS_XID = flags.DEFINE_string( + "semantics_xid", None, + "XManager job ID for the inference run with semantics model for this round." +) + +_PHONETICS_XID = flags.DEFINE_string( + "phonetics_xid", None, + "XManager job ID for the inference run with phonetics model for this round." +) + +_JSONL_FILE_NAME_GLYPHS = flags.DEFINE_string( + "jsonl_file_name_glyphs", None, + "File name used for storing the outputs of glyph inference.", + required=True +) + +# Actual inference post-processing tool. +_GLYPHS_TOOL = ( + "protoscribe/sketches/inference/glyphs_from_jsonl" +) + +# Discrete glyph prediction mode. +_MODE = "glyph" + + +def _setup_inference_directories( + round_data_dir: str, + experiment_name: str, + experiment_id: str | None +) -> str: + """Sets up the directory for storing the post-processed inference outputs. + + Args: + round_data_dir: Data directory for this round. + experiment_name: Symbol name for the experiment. + experiment_id: XManager Job ID (integer string). + + Returns: + Output directory where postprocessed results will be stored. + + Raises: + ValueError if output directory could not be determined. + """ + output_dir = os.path.join( + round_data_dir, f"{experiment_name}:inference_outputs" + ) + if experiment_id: + output_dir = os.path.join(output_dir, experiment_id) + else: + experiment_dirs = file_utils.list_subdirs(output_dir) + if not experiment_dirs: + raise ValueError( + f"No inference experiment directories found under {output_dir}!" + ) + output_dir = experiment_dirs[-1] + + logging.info("Reading and writing output data to %s ...", output_dir) + return output_dir + + +def _glyphs_for_model_type( + round_data_dir: str, model_type: str, experiment_id: str +) -> None: + """Run glyph extractions from the inference run for a given model type. + + Args: + round_data_dir: Data directory for this round. + model_type: Type of the model. + experiment_id: XManager job ID. + """ + round_id = common_flags.ROUND.value + experiment_name = ( + f"{_EXPERIMENT_NAME.value}:{round_id}:{_MODE}_{model_type}" + ) + output_dir = _setup_inference_directories( + round_data_dir=round_data_dir, + experiment_name=experiment_name, + experiment_id=experiment_id + ) + jsonl_file = os.path.join(output_dir, _JSONL_FILE_NAME_GLYPHS.value) + subprocess_utils.run_subprocess( + _GLYPHS_TOOL, + args=[ + "--dataset_dir", round_data_dir, + "--input_jsonl_file", jsonl_file, + "--output_tsv_file", f"{output_dir}/results.tsv", + "--output_file_for_scorer", f"{output_dir}/results.jsonl", + "--ignore_errors", True, + ] + ) + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + round_data_dir = common_flags.round_data_dir() + logging.info("Using data location: %s", round_data_dir) + + # Post-process inference results for the semantic stream. + _glyphs_for_model_type( + round_data_dir=round_data_dir, + model_type=common_flags.SEMANTIC_MODEL.value, + experiment_id=_SEMANTICS_XID.value + ) + # Post-process inference results for the phonetic stream. + _glyphs_for_model_type( + round_data_dir=round_data_dir, + model_type=common_flags.PHONETIC_MODEL.value, + experiment_id=_PHONETICS_XID.value + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/protoscribe/utils/file_utils.py b/protoscribe/utils/file_utils.py index 29b1f71..2a4d1ca 100644 --- a/protoscribe/utils/file_utils.py +++ b/protoscribe/utils/file_utils.py @@ -135,3 +135,26 @@ def copy_dir(source_dir: str, target_dir: str) -> None: if not os.path.isdir(path): source_paths.append(path) copy_files(source_paths, target_dir) + + +def list_subdirs(root_dir: str) -> list[str]: + """Retrieves all subdirectories under the specified root. + + Root directory must exist. + + Args: + root_dir: Root directory. + + Returns: + A list of subdirectories (may be empty). + """ + if not os.path.isdir(root_dir): + raise ValueError(f"Source directory {root_dir} does not exist!") + + logging.info("Searching for subdirs of `%s` ...", root_dir) + subdir_paths = [] + for path in glob.glob(f"{root_dir}/*"): + if os.path.isdir(path): + subdir_paths.append(path) + + return subdir_paths