From 8226adba30e274db58b458181d228ebfc73c9412 Mon Sep 17 00:00:00 2001
From: Alexander Gutkin
Date: Fri, 31 Jan 2025 19:22:23 +0000
Subject: [PATCH] Discrete glyph prediction from log-mel spectral features.

PiperOrigin-RevId: 721840417
---
 .../configs/glyph_logmel-spectrum/README.md | 1 +
 .../arch_p1_t5_1_1_flaxformer.gin | 71 ++++++++++++++++
 .../configs/glyph_logmel-spectrum/dataset.gin | 85 +++++++++++++++++++
 .../configs/glyph_logmel-spectrum/infer.gin | 39 +++++++++
 .../glyph_logmel-spectrum/model_base.gin | 27 ++++++
 .../glyph_logmel-spectrum/model_common.gin | 73 ++++++++++++++++
 .../glyph_logmel-spectrum/model_tiny.gin | 27 ++++++
 .../models/pmmx/model_config_gin_test.py | 2 +
 8 files changed, 325 insertions(+)
 create mode 100644 protoscribe/models/pmmx/configs/glyph_logmel-spectrum/README.md
 create mode 100644 protoscribe/models/pmmx/configs/glyph_logmel-spectrum/arch_p1_t5_1_1_flaxformer.gin
 create mode 100644 protoscribe/models/pmmx/configs/glyph_logmel-spectrum/dataset.gin
 create mode 100644 protoscribe/models/pmmx/configs/glyph_logmel-spectrum/infer.gin
 create mode 100644 protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_base.gin
 create mode 100644 protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_common.gin
 create mode 100644 protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_tiny.gin

diff --git a/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/README.md b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/README.md
new file mode 100644
index 0000000..3a75320
--- /dev/null
+++ b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/README.md
@@ -0,0 +1 @@
+# Discrete glyph prediction from log-mel spectral features only.
diff --git a/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/arch_p1_t5_1_1_flaxformer.gin b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/arch_p1_t5_1_1_flaxformer.gin
new file mode 100644
index 0000000..9e215eb
--- /dev/null
+++ b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/arch_p1_t5_1_1_flaxformer.gin
@@ -0,0 +1,71 @@
+# Copyright 2024 The Protoscribe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Multimodal PMMX-One flaxformer architecture.
+
+from __gin__ import dynamic_registration
+
+import seqio
+from protoscribe.pmmx import feature_converters
+from protoscribe.pmmx import pmmx_architecture
+
+from flaxformer.architectures.t5 import t5_architecture
+from flaxformer.components import embedding
+
+include "protoscribe/pmmx/configs/architectures/p1_t5_1_1_flaxformer.gin"
+
+# Architecture (Flax Module).
+ARCHITECTURE = @pmmx_architecture.MultimodalEncoderDecoder()
+
+# Vocabulary for the encoder.
+inputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
+inputs/seqio.PassThroughVocabulary.size = 0
+
+# Output vocabulary for the decoder. The `GLYPH_TOKEN_VOCAB_SIZE` corresponds
+# to the real glyph token vocabulary size (all the glyphs + the special
+# symbols, 314 elements) + non-administrative concepts (468 elements) + some
+# provision for extra glyphs, rounded up to a multiple of 128 for TPU efficiency.
+
+GLYPH_TOKEN_VOCAB_SIZE = 1024
+END_OF_SKETCH = 2  # glyph_vocab.GLYPH_EOS
+NUM_EMBEDDINGS = %GLYPH_TOKEN_VOCAB_SIZE
+
+outputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
+outputs/seqio.PassThroughVocabulary.size = 0
+outputs/seqio.PassThroughVocabulary.eos_id = %END_OF_SKETCH
+
+# Actual multimodal encoder-decoder architecture.
+pmmx_architecture.MultimodalEncoderDecoder:
+  encoder_factory = @pmmx_architecture.MultimodalEncoder
+  decoder_factory = @t5_architecture.Decoder
+  shared_token_embedder_factory = @token_embedder/embedding.Embed
+  dtype = %ACTIVATION_DTYPE
+
+feature_converters.MultimodalEncDecFeatureConverterFactory:
+  task_feature_lengths = %TASK_FEATURE_LENGTHS
+  feature_specs = (
+      ("speech.log_mel_spectrum", "float32", 2),
+  )
+
+# Encoder
+pmmx_architecture.MultimodalEncoder:
+  feature_spec = [
+      ("speech.log_mel_spectrum", "speech.log_mel_spectrum"),
+  ]
+  modality_spec = ["speech.log_mel_spectrum"]
+  modality_embedders_spec = {
+      "speech.log_mel_spectrum": [
+          ("speech.log_mel_spectrum", @pmmx_architecture.DenseEmbed)
+      ],
+  }
diff --git a/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/dataset.gin b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/dataset.gin
new file mode 100644
index 0000000..2da147d
--- /dev/null
+++ b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/dataset.gin
@@ -0,0 +1,85 @@
+# Copyright 2024 The Protoscribe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Settings for Protoscribe dataset reader using discrete glyph tokens.
+
+from __gin__ import dynamic_registration
+
+from t5x import utils
+
+from protoscribe.corpus.reader import tasks
+
+DATA_DIR = %gin.REQUIRED
+TRAIN_DATA_DIR = %DATA_DIR
+EVAL_DATA_DIR = %DATA_DIR
+INFER_EVAL_DATA_DIR = %DATA_DIR
+
+MAX_GLYPH_SEQUENCE_LENGTH = 20
+MAX_SPEECH_FRAME_SEQUENCE_LENGTH = 200
+
+tasks.register:
+  concept_embedding_type = "bnc"
+  glyph_only_targets = True
+  max_glyph_sequence_length = %MAX_GLYPH_SEQUENCE_LENGTH
+  max_speech_frame_sequence_length = %MAX_SPEECH_FRAME_SEQUENCE_LENGTH
+  speech_frame_normalization = "none"
+  speech_spectrum_augmentation = True
+  speech_framework_type = "dmvr"
+  speech_frame_length_ms = 25.0
+  speech_frame_step_ms = 10.0
+
+train_task/tasks.register:
+  task_name = "bnc_glyphs.train"
+  dataset_dir = %TRAIN_DATA_DIR
+  is_training = True
+  noisify_embeddings = False
+
+eval_task/tasks.register:
+  task_name = "bnc_glyphs.eval"
+  dataset_dir = %EVAL_DATA_DIR
+  is_training = False
+
+infer_eval_task/tasks.register:
+  task_name = "bnc_glyphs.infer_eval"
+  dataset_dir = %INFER_EVAL_DATA_DIR
+  is_training = False
+
+TRAIN_TASK = @train_task/tasks.register()
+EVAL_TASK = @eval_task/tasks.register()
+INFER_EVAL_TASK = @infer_eval_task/tasks.register()
+MIXTURE_OR_TASK_NAME = %TRAIN_TASK
+MIXTURE_OR_TASK_MODULE = "protoscribe.corpus.reader.tasks"
+USE_CACHED_TASKS = False
+
+TASK_FEATURE_LENGTHS = {
+    "speech.log_mel_spectrum": %MAX_SPEECH_FRAME_SEQUENCE_LENGTH,
+    "targets": %MAX_GLYPH_SEQUENCE_LENGTH
+}
+
+train/utils.DatasetConfig:
+  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+  pack = False
+  use_custom_packing_ops = False
+
+train_eval/utils.DatasetConfig:
+  mixture_or_task_name = %EVAL_TASK
+  pack = False
+  shuffle = False
+  use_custom_packing_ops = False
+
+infer_eval/utils.DatasetConfig:
+  mixture_or_task_name = %INFER_EVAL_TASK
+  pack = False
+  shuffle = False
+  use_custom_packing_ops = False
\ No newline at end of file
diff --git a/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/infer.gin b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/infer.gin
new file mode 100644
index 0000000..68337af
--- /dev/null
+++ b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/infer.gin
@@ -0,0 +1,39 @@
+# Copyright 2024 The Protoscribe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __gin__ import dynamic_registration
+
+import __main__ as infer_script
+from protoscribe.sketches.inference import json_utils
+from t5x import utils
+
+include "protoscribe/pmmx/configs/runs/infer.gin"
+
+utils.DatasetConfig:
+  mixture_or_task_name = %INFER_EVAL_TASK
+
+infer_script.infer:
+  mode = "predict_with_aux"
+  write_fn = @json_utils.write_inferences_to_file
+  merge_fn = @infer_script.merge_chunks_to_file
+
+json_utils.write_inferences_to_file:
+  include_all_inputs = False
+  input_fields_to_include = [
+      "doc.id",
+      "concept.name",
+      "number.name",
+      "text.sampa",
+      "text.words",
+  ]
diff --git a/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_base.gin b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_base.gin
new file mode 100644
index 0000000..fe5ecd6
--- /dev/null
+++ b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_base.gin
@@ -0,0 +1,27 @@
+# Copyright 2024 The Protoscribe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Base configuration.
+
+from __gin__ import dynamic_registration
+
+include "protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_common.gin"
+
+# Architecture overrides.
+NUM_ENCODER_LAYERS = 12
+NUM_DECODER_LAYERS = 6
+NUM_HEADS = 12
+HEAD_DIM = 64
+EMBED_DIM = 768
+MLP_DIM = 2048
diff --git a/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_common.gin b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_common.gin
new file mode 100644
index 0000000..014409d
--- /dev/null
+++ b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_common.gin
@@ -0,0 +1,73 @@
+# Copyright 2024 The Protoscribe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Common configuration of the model.
+
+from __gin__ import dynamic_registration
+
+from protoscribe.pmmx import feature_converters
+from protoscribe.pmmx import models
+from t5x import adafactor
+from t5x import utils
+
+ARCHITECTURE = %gin.REQUIRED
+
+include "protoscribe/models/pmmx/configs/glyph_logmel-spectrum/arch_p1_t5_1_1_flaxformer.gin"
+
+# Architecture overrides.
+NUM_ENCODER_LAYERS = %gin.REQUIRED
+NUM_DECODER_LAYERS = %gin.REQUIRED
+NUM_HEADS = %gin.REQUIRED
+HEAD_DIM = %gin.REQUIRED
+EMBED_DIM = %gin.REQUIRED
+MLP_DIM = %gin.REQUIRED
+
+# Optimizer
+# `learning_rate` is set by `Trainer.learning_rate_fn`.
+OPTIMIZER = @adafactor.Adafactor()
+adafactor.Adafactor:
+  decay_rate = 0.8
+  step_offset = 0
+
+# Loss defaults.
+Z_LOSS = 0.0001
+LABEL_SMOOTHING = 0.0
+LOSS_NORMALIZING_FACTOR = None
+
+# Model
+MODEL = @models.MultimodalEncoderDecoderModel()
+models.MultimodalEncoderDecoderModel:
+  feature_converter_cls = @feature_converters.MultimodalEncDecFeatureConverterFactory()
+  module = %ARCHITECTURE  # provided by t5_flaxformer
+  input_vocabulary = %inputs/PASSTHROUGH_VOCABULARY
+  output_vocabulary = %outputs/PASSTHROUGH_VOCABULARY
+  optimizer_def = %OPTIMIZER
+  z_loss = %Z_LOSS
+  label_smoothing = %LABEL_SMOOTHING
+  loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
+
+# Decoding
+NUM_DECODES = 5
+RETURN_ALL_DECODES = True
+models.MultimodalEncoderDecoderModel.predict_batch_with_aux:
+  num_decodes = %NUM_DECODES
+  return_all_decodes = %RETURN_ALL_DECODES
+
+# Checkpoints
+CHECKPOINT_PERIOD = 10_000
+EVAL_PERIOD = %CHECKPOINT_PERIOD
+utils.SaveCheckpointConfig:
+  period = %CHECKPOINT_PERIOD
+  keep = None  # Keep all checkpoints.
+  save_dataset = False
diff --git a/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_tiny.gin b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_tiny.gin
new file mode 100644
index 0000000..da01c26
--- /dev/null
+++ b/protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_tiny.gin
@@ -0,0 +1,27 @@
+# Copyright 2024 The Protoscribe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Tiny configuration.
+
+from __gin__ import dynamic_registration
+
+include "protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_common.gin"
+
+# Architecture overrides.
+NUM_ENCODER_LAYERS = 2
+NUM_DECODER_LAYERS = 2
+NUM_HEADS = 6
+HEAD_DIM = 16
+EMBED_DIM = 16
+MLP_DIM = 16
diff --git a/protoscribe/models/pmmx/model_config_gin_test.py b/protoscribe/models/pmmx/model_config_gin_test.py
index 9dde4db..73690c7 100644
--- a/protoscribe/models/pmmx/model_config_gin_test.py
+++ b/protoscribe/models/pmmx/model_config_gin_test.py
@@ -51,6 +51,7 @@ def tearDown(self):
 
   @parameterized.parameters(
       "glyph_concepts",
+      "glyph_logmel-spectrum",
       "glyph_phonemes",
   )
   def test_model_train(self, model_dir: str) -> None:
@@ -75,6 +76,7 @@ def test_model_train(self, model_dir: str) -> None:
 
   @parameterized.parameters(
       "glyph_concepts",
+      "glyph_logmel-spectrum",
       "glyph_phonemes",
   )
   def test_model_infer(self, model_dir: str) -> None: