From c4a105877e98136392d4f9008cb903152b50d8c2 Mon Sep 17 00:00:00 2001 From: Alexander Gutkin Date: Tue, 28 Jan 2025 18:20:26 +0000 Subject: [PATCH] Discrete glyph sequence prediction from phonetic embeddings. PiperOrigin-RevId: 720621793 --- protoscribe/corpus/builder/build_dataset.py | 2 +- protoscribe/evolution/make_html.py | 2 +- .../evolution/new_spellings_basic_main.py | 2 +- .../evolution/stages/build_dataset_main.py | 4 +- .../stages/new_spellings_basic_main.py | 2 +- .../stages/sketches_from_jsonl_main.py | 2 +- .../pmmx/configs/glyph_phonemes/README.md | 1 + .../arch_p1_t5_1_1_flaxformer.gin | 71 ++++++++++++++++ .../pmmx/configs/glyph_phonemes/dataset.gin | 83 +++++++++++++++++++ .../pmmx/configs/glyph_phonemes/infer.gin | 39 +++++++++ .../configs/glyph_phonemes/model_base.gin | 27 ++++++ .../configs/glyph_phonemes/model_common.gin | 73 ++++++++++++++++ .../configs/glyph_phonemes/model_tiny.gin | 27 ++++++ .../texts/generate_simple_corpus_main.py | 4 +- 14 files changed, 330 insertions(+), 9 deletions(-) create mode 100644 protoscribe/models/pmmx/configs/glyph_phonemes/README.md create mode 100644 protoscribe/models/pmmx/configs/glyph_phonemes/arch_p1_t5_1_1_flaxformer.gin create mode 100644 protoscribe/models/pmmx/configs/glyph_phonemes/dataset.gin create mode 100644 protoscribe/models/pmmx/configs/glyph_phonemes/infer.gin create mode 100644 protoscribe/models/pmmx/configs/glyph_phonemes/model_base.gin create mode 100644 protoscribe/models/pmmx/configs/glyph_phonemes/model_common.gin create mode 100644 protoscribe/models/pmmx/configs/glyph_phonemes/model_tiny.gin diff --git a/protoscribe/corpus/builder/build_dataset.py b/protoscribe/corpus/builder/build_dataset.py index 4ca00ee..ce1822a 100644 --- a/protoscribe/corpus/builder/build_dataset.py +++ b/protoscribe/corpus/builder/build_dataset.py @@ -238,7 +238,7 @@ def _prepare_language_components() -> None: language_dir = os.path.join(output_dir, "language") logging.info("Language directory: %s", language_dir) if not os.path.exists(language_dir): - os.makedirs(language_dir, exist_ok=True) + os.makedirs(language_dir) # Copy these files that are copyable verbatim. file_utils.copy_src_file("texts/configs", _NUMBER_CONFIG.value, language_dir) diff --git a/protoscribe/evolution/make_html.py b/protoscribe/evolution/make_html.py index 61a0cf2..3a09fe2 100644 --- a/protoscribe/evolution/make_html.py +++ b/protoscribe/evolution/make_html.py @@ -145,7 +145,7 @@ def make_html() -> None: output_svg_dir = os.path.join(_OUTPUT_HTML_DIR.value, "svgs") if not os.path.exists(output_svg_dir): logging.info("Making directory %s ...", output_svg_dir) - os.makedirs(output_svg_dir, exist_ok=True) + os.makedirs(output_svg_dir) # Create index page. if not _EXTENSIONS_FILE.value: diff --git a/protoscribe/evolution/new_spellings_basic_main.py b/protoscribe/evolution/new_spellings_basic_main.py index b82fd1b..f307570 100644 --- a/protoscribe/evolution/new_spellings_basic_main.py +++ b/protoscribe/evolution/new_spellings_basic_main.py @@ -72,7 +72,7 @@ def main(unused_argv): # Create a summary of all the extensions in this round. output_dir = os.path.join(_DATA_LOCATION.value, "inference_extensions") if not os.path.exists(output_dir): - os.makedirs(output_dir, exist_ok=True) + os.makedirs(output_dir) extensions_path = os.path.join(output_dir, "extensions.tsv") logging.info("Writing extensions to %s ...", extensions_path) with open(extensions_path, "w") as s: diff --git a/protoscribe/evolution/stages/build_dataset_main.py b/protoscribe/evolution/stages/build_dataset_main.py index b42e3ff..792d7ec 100644 --- a/protoscribe/evolution/stages/build_dataset_main.py +++ b/protoscribe/evolution/stages/build_dataset_main.py @@ -86,7 +86,7 @@ def _setup_builder(round_data_dir: str) -> list[tuple[str, Any]]: # language definitions from the previous round. logging.info("Making %s ...", round_data_dir) language_dir = os.path.join(round_data_dir, "language") - os.makedirs(language_dir, exist_ok=True) + os.makedirs(language_dir) file_utils.copy_dir( os.path.join(previous_data_dir, "language"), language_dir ) @@ -116,7 +116,7 @@ def _setup_builder(round_data_dir: str) -> list[tuple[str, Any]]: # At this stage it is safe to do this again. if not os.path.isdir(round_data_dir): - os.makedirs(round_data_dir, exist_ok=True) + os.makedirs(round_data_dir) logging.info( "Created `%s` for outputs for round %d.", round_data_dir, round_id ) diff --git a/protoscribe/evolution/stages/new_spellings_basic_main.py b/protoscribe/evolution/stages/new_spellings_basic_main.py index 166ebb6..7b1e050 100644 --- a/protoscribe/evolution/stages/new_spellings_basic_main.py +++ b/protoscribe/evolution/stages/new_spellings_basic_main.py @@ -149,7 +149,7 @@ def main(argv: Sequence[str]) -> None: round_data_dir, "glyph_extensions_svg" ) if not os.path.exists(output_glyph_graphics_dir): - os.makedirs(output_glyph_graphics_dir, exist_ok=True) + os.makedirs(output_glyph_graphics_dir) args.extend([ "--output_glyph_graphics_dir", svg_temp_dir.name ]) diff --git a/protoscribe/evolution/stages/sketches_from_jsonl_main.py b/protoscribe/evolution/stages/sketches_from_jsonl_main.py index 101fe31..1b220f6 100644 --- a/protoscribe/evolution/stages/sketches_from_jsonl_main.py +++ b/protoscribe/evolution/stages/sketches_from_jsonl_main.py @@ -117,7 +117,7 @@ def _sketches_and_glyphs_for_model_type( images_dir = os.path.join(output_dir, "images") logging.info("Copying sketches to %s ...", images_dir) if not os.path.exists(images_dir): - os.makedirs(images_dir, exist_ok=True) + os.makedirs(images_dir) file_utils.copy_dir(temp_dir_name, images_dir) diff --git a/protoscribe/models/pmmx/configs/glyph_phonemes/README.md b/protoscribe/models/pmmx/configs/glyph_phonemes/README.md new file mode 100644 index 0000000..7e9732c --- /dev/null +++ b/protoscribe/models/pmmx/configs/glyph_phonemes/README.md @@ -0,0 +1 @@ +# Discrete glyph-only models with phonemes. diff --git a/protoscribe/models/pmmx/configs/glyph_phonemes/arch_p1_t5_1_1_flaxformer.gin b/protoscribe/models/pmmx/configs/glyph_phonemes/arch_p1_t5_1_1_flaxformer.gin new file mode 100644 index 0000000..3d5236f --- /dev/null +++ b/protoscribe/models/pmmx/configs/glyph_phonemes/arch_p1_t5_1_1_flaxformer.gin @@ -0,0 +1,71 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Multimodal PMMX-One flaxformer architecture. + +from __gin__ import dynamic_registration + +import seqio +from protoscribe.pmmx import feature_converters +from protoscribe.pmmx import pmmx_architecture + +from flaxformer.architectures.t5 import t5_architecture +from flaxformer.components import embedding + +include "protoscribe/pmmx/configs/architectures/p1_t5_1_1_flaxformer.gin" + +# Architecture (Flax Module). +ARCHITECTURE = @pmmx_architecture.MultimodalEncoderDecoder() + +# Vocabulary for the encoder. +inputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary() +inputs/seqio.PassThroughVocabulary.size = 0 + +# Output vocabulary for the decoder. The `GLYPH_TOKEN_VOCAB_SIZE` corresponds +# to the real glyph token vocabulary size (all the glyphs + the special +# symbols, 314 elements) + non-administrative concepts (468 elements) + some +# provision for extra glyphs, rounded to 128 for TPU efficiency. + +GLYPH_TOKEN_VOCAB_SIZE = 1024 +END_OF_SKETCH = 2 # glyph_vocab.GLYPH_EOS +NUM_EMBEDDINGS = %GLYPH_TOKEN_VOCAB_SIZE + +outputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary() +outputs/seqio.PassThroughVocabulary.size = 0 +outputs/seqio.PassThroughVocabulary.eos_id = %END_OF_SKETCH + +# Actual multimodal encoder-decoder architecture. +pmmx_architecture.MultimodalEncoderDecoder: + encoder_factory = @pmmx_architecture.MultimodalEncoder + decoder_factory = @t5_architecture.Decoder + shared_token_embedder_factory = @token_embedder/embedding.Embed + dtype = %ACTIVATION_DTYPE + +feature_converters.MultimodalEncDecFeatureConverterFactory: + task_feature_lengths = %TASK_FEATURE_LENGTHS + feature_specs = ( + ("text.phonetic_embedding", "float32", 2), + ) + +# Encoder +pmmx_architecture.MultimodalEncoder: + feature_spec = [ + ("text.phonetic_embedding", "text.phonetic_embedding"), + ] + modality_spec = ["text.phonetic_embedding"] + modality_embedders_spec = { + "text.phonetic_embedding": [ + ("text.phonetic_embedding", @pmmx_architecture.DenseEmbed) + ], + } diff --git a/protoscribe/models/pmmx/configs/glyph_phonemes/dataset.gin b/protoscribe/models/pmmx/configs/glyph_phonemes/dataset.gin new file mode 100644 index 0000000..5b40e12 --- /dev/null +++ b/protoscribe/models/pmmx/configs/glyph_phonemes/dataset.gin @@ -0,0 +1,83 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Settings for Protoscribe dataset reader using discrete sketch tokens. + +from __gin__ import dynamic_registration + +from t5x import utils + +from protoscribe.corpus.reader import tasks + +DATA_DIR = %gin.REQUIRED +TRAIN_DATA_DIR = %DATA_DIR +EVAL_DATA_DIR = %DATA_DIR +INFER_EVAL_DATA_DIR = %DATA_DIR + +MAX_GLYPH_SEQUENCE_LENGTH = 20 +MAX_PHONETIC_SEQUENCE_LENGTH = 10 + +tasks.register: + concept_embedding_type = "bnc" + glyph_only_targets = True + max_glyph_sequence_length = %MAX_GLYPH_SEQUENCE_LENGTH + max_phonetic_sequence_length = %MAX_PHONETIC_SEQUENCE_LENGTH + +train_task/tasks.register: + task_name = "bnc_glyphs.train" + dataset_dir = %TRAIN_DATA_DIR + is_training = True + noisify_embeddings = True + noisify_neftune_alphas = { + %tasks.EMBEDDING_PHONETICS: 0.01, + } + +eval_task/tasks.register: + task_name = "bnc_glyphs.eval" + dataset_dir = %EVAL_DATA_DIR + is_training = False + +infer_eval_task/tasks.register: + task_name = "bnc_glyphs.infer_eval" + dataset_dir = %INFER_EVAL_DATA_DIR + is_training = False + +TRAIN_TASK = @train_task/tasks.register() +EVAL_TASK = @eval_task/tasks.register() +INFER_EVAL_TASK = @infer_eval_task/tasks.register() +MIXTURE_OR_TASK_NAME = %TRAIN_TASK +MIXTURE_OR_TASK_MODULE = "protoscribe.corpus.reader.tasks" +USE_CACHED_TASKS = False + +TASK_FEATURE_LENGTHS = { + "text.phonetic_embedding": %MAX_PHONETIC_SEQUENCE_LENGTH, + "targets": %MAX_GLYPH_SEQUENCE_LENGTH +} + +train/utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + pack = False + use_custom_packing_ops = False + +train_eval/utils.DatasetConfig: + mixture_or_task_name = %EVAL_TASK + pack = False + shuffle = False + use_custom_packing_ops = False + +infer_eval/utils.DatasetConfig: + mixture_or_task_name = %INFER_EVAL_TASK + pack = False + shuffle = False + use_custom_packing_ops = False \ No newline at end of file diff --git a/protoscribe/models/pmmx/configs/glyph_phonemes/infer.gin b/protoscribe/models/pmmx/configs/glyph_phonemes/infer.gin new file mode 100644 index 0000000..68337af --- /dev/null +++ b/protoscribe/models/pmmx/configs/glyph_phonemes/infer.gin @@ -0,0 +1,39 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __gin__ import dynamic_registration + +import __main__ as infer_script +from protoscribe.sketches.inference import json_utils +from t5x import utils + +include "protoscribe/pmmx/configs/runs/infer.gin" + +utils.DatasetConfig: + mixture_or_task_name = %INFER_EVAL_TASK + +infer_script.infer: + mode = "predict_with_aux" + write_fn = @json_utils.write_inferences_to_file + merge_fn = @infer_script.merge_chunks_to_file + +json_utils.write_inferences_to_file: + include_all_inputs = False + input_fields_to_include = [ + "doc.id", + "concept.name", + "number.name", + "text.sampa", + "text.words", + ] diff --git a/protoscribe/models/pmmx/configs/glyph_phonemes/model_base.gin b/protoscribe/models/pmmx/configs/glyph_phonemes/model_base.gin new file mode 100644 index 0000000..5c66950 --- /dev/null +++ b/protoscribe/models/pmmx/configs/glyph_phonemes/model_base.gin @@ -0,0 +1,27 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Base configuration. + +from __gin__ import dynamic_registration + +include "protoscribe/models/pmmx/configs/glyph_phonemes/model_common.gin" + +# Architecture overrides. +NUM_ENCODER_LAYERS = 6 +NUM_DECODER_LAYERS = 6 +NUM_HEADS = 4 +HEAD_DIM = 32 +EMBED_DIM = 96 +MLP_DIM = 512 diff --git a/protoscribe/models/pmmx/configs/glyph_phonemes/model_common.gin b/protoscribe/models/pmmx/configs/glyph_phonemes/model_common.gin new file mode 100644 index 0000000..2350190 --- /dev/null +++ b/protoscribe/models/pmmx/configs/glyph_phonemes/model_common.gin @@ -0,0 +1,73 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Base configuration of the model. + +from __gin__ import dynamic_registration + +from protoscribe.pmmx import feature_converters +from protoscribe.pmmx import models +from t5x import adafactor +from t5x import utils + +ARCHITECTURE = %gin.REQUIRED + +include "protoscribe/models/pmmx/configs/glyph_phonemes/arch_p1_t5_1_1_flaxformer.gin" + +# Architecture overrides. +NUM_ENCODER_LAYERS = %gin.REQUIRED +NUM_DECODER_LAYERS = %gin.REQUIRED +NUM_HEADS = %gin.REQUIRED +HEAD_DIM = %gin.REQUIRED +EMBED_DIM = %gin.REQUIRED +MLP_DIM = %gin.REQUIRED + +# Optimizer +# `learning_rate` is set by `Trainer.learning_rate_fn`. +OPTIMIZER = @adafactor.Adafactor() +adafactor.Adafactor: + decay_rate = 0.8 + step_offset = 0 + +# Loss defaults. +Z_LOSS = 0.0001 +LABEL_SMOOTHING = 0.0 +LOSS_NORMALIZING_FACTOR = None + +# Model +MODEL = @models.MultimodalEncoderDecoderModel() +models.MultimodalEncoderDecoderModel: + feature_converter_cls = @feature_converters.MultimodalEncDecFeatureConverterFactory() + module = %ARCHITECTURE # provided by t5_flaxformer + input_vocabulary = %inputs/PASSTHROUGH_VOCABULARY + output_vocabulary = %outputs/PASSTHROUGH_VOCABULARY + optimizer_def = %OPTIMIZER + z_loss = %Z_LOSS + label_smoothing = %LABEL_SMOOTHING + loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR + +# Decoding +NUM_DECODES = 5 +RETURN_ALL_DECODES = True +models.MultimodalEncoderDecoderModel.predict_batch_with_aux: + num_decodes = %NUM_DECODES + return_all_decodes = %RETURN_ALL_DECODES + +# Checkpoints +CHECKPOINT_PERIOD = 10_000 +EVAL_PERIOD = %CHECKPOINT_PERIOD +utils.SaveCheckpointConfig: + period = %CHECKPOINT_PERIOD + keep = None # Keep all checkpoints. + save_dataset = False diff --git a/protoscribe/models/pmmx/configs/glyph_phonemes/model_tiny.gin b/protoscribe/models/pmmx/configs/glyph_phonemes/model_tiny.gin new file mode 100644 index 0000000..7ab2608 --- /dev/null +++ b/protoscribe/models/pmmx/configs/glyph_phonemes/model_tiny.gin @@ -0,0 +1,27 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tiny configuration. + +from __gin__ import dynamic_registration + +include "protoscribe/models/pmmx/configs/glyph_phonemes/model_common.gin" + +# Architecture overrides. +NUM_ENCODER_LAYERS = 2 +NUM_DECODER_LAYERS = 2 +NUM_HEADS = 6 +HEAD_DIM = 16 +EMBED_DIM = 16 +MLP_DIM = 16 diff --git a/protoscribe/texts/generate_simple_corpus_main.py b/protoscribe/texts/generate_simple_corpus_main.py index 5d3fbb8..f79fbc2 100644 --- a/protoscribe/texts/generate_simple_corpus_main.py +++ b/protoscribe/texts/generate_simple_corpus_main.py @@ -88,7 +88,7 @@ def main(argv: Sequence[str]) -> None: params_dir = f"{initial_dir}/params" concepts_dir = f"{_SRC_DIR}/data/concepts" if not os.path.exists(params_dir): - os.makedirs(params_dir, exist_ok=True) + os.makedirs(params_dir) # Generate lexicon resources. logging.info("Generating the lexicon with ALL concepts in %s ...", params_dir) @@ -116,7 +116,7 @@ def main(argv: Sequence[str]) -> None: # Now generate the accounting texts. output_dir = f"{initial_dir}/output" if not os.path.exists(output_dir): - os.makedirs(output_dir, exist_ok=True) + os.makedirs(output_dir) for set_idx in range(_NUM_SETS.value): logging.info("Generating accounting texts set %d ...", set_idx) subprocess_utils.run_subprocess(