Discrete glyph prediction from log-mel spectral features.
PiperOrigin-RevId: 721840417
agutkin committed Jan 31, 2025
1 parent bd0a8d4 commit 8226adb
Showing 8 changed files with 325 additions and 0 deletions.
@@ -0,0 +1 @@
# Discrete glyph prediction from log-mel spectral features only.
71 changes: 71 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_logmel-spectrum/arch_p1_t5_1_1_flaxformer.gin
@@ -0,0 +1,71 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Multimodal PMMX-One flaxformer architecture.

from __gin__ import dynamic_registration

import seqio
from protoscribe.pmmx import feature_converters
from protoscribe.pmmx import pmmx_architecture

from flaxformer.architectures.t5 import t5_architecture
from flaxformer.components import embedding

include "protoscribe/pmmx/configs/architectures/p1_t5_1_1_flaxformer.gin"

# Architecture (Flax Module).
ARCHITECTURE = @pmmx_architecture.MultimodalEncoderDecoder()

# Vocabulary for the encoder.
inputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
inputs/seqio.PassThroughVocabulary.size = 0

# Output vocabulary for the decoder. `GLYPH_TOKEN_VOCAB_SIZE` covers the real
# glyph token vocabulary (all the glyphs plus the special symbols, 314
# elements) and the non-administrative concepts (468 elements), with some
# provision for extra glyphs, rounded up to a multiple of 128 for TPU
# efficiency.

GLYPH_TOKEN_VOCAB_SIZE = 1024
END_OF_SKETCH = 2 # glyph_vocab.GLYPH_EOS
NUM_EMBEDDINGS = %GLYPH_TOKEN_VOCAB_SIZE
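
# Sanity check on the arithmetic above (our reading, not from the original
# comment): 314 glyph tokens + 468 concepts = 782; with headroom for extra
# glyphs, rounding up to a multiple of 128 gives 8 * 128 = 1024.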

outputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
outputs/seqio.PassThroughVocabulary.size = 0
outputs/seqio.PassThroughVocabulary.eos_id = %END_OF_SKETCH

# Actual multimodal encoder-decoder architecture.
pmmx_architecture.MultimodalEncoderDecoder:
  encoder_factory = @pmmx_architecture.MultimodalEncoder
  decoder_factory = @t5_architecture.Decoder
  shared_token_embedder_factory = @token_embedder/embedding.Embed
  dtype = %ACTIVATION_DTYPE

feature_converters.MultimodalEncDecFeatureConverterFactory:
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  feature_specs = (
      ("speech.log_mel_spectrum", "float32", 2),
  )
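
# Note (an assumption, not stated in this commit): each `feature_specs` entry
# is (feature name, dtype, rank); rank 2 presumably corresponds to a
# [num_frames, num_mel_bins] array per example.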

# Encoder
pmmx_architecture.MultimodalEncoder:
  feature_spec = [
      ("speech.log_mel_spectrum", "speech.log_mel_spectrum"),
  ]
  modality_spec = ["speech.log_mel_spectrum"]
  modality_embedders_spec = {
      "speech.log_mel_spectrum": [
          ("speech.log_mel_spectrum", @pmmx_architecture.DenseEmbed)
      ],
  }
85 changes: 85 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_logmel-spectrum/dataset.gin
@@ -0,0 +1,85 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Settings for the Protoscribe dataset reader using discrete glyph tokens.

from __gin__ import dynamic_registration

from t5x import utils

from protoscribe.corpus.reader import tasks

DATA_DIR = %gin.REQUIRED
TRAIN_DATA_DIR = %DATA_DIR
EVAL_DATA_DIR = %DATA_DIR
INFER_EVAL_DATA_DIR = %DATA_DIR

MAX_GLYPH_SEQUENCE_LENGTH = 20
MAX_SPEECH_FRAME_SEQUENCE_LENGTH = 200

tasks.register:
  concept_embedding_type = "bnc"
  glyph_only_targets = True
  max_glyph_sequence_length = %MAX_GLYPH_SEQUENCE_LENGTH
  max_speech_frame_sequence_length = %MAX_SPEECH_FRAME_SEQUENCE_LENGTH
  speech_frame_normalization = "none"
  speech_spectrum_augmentation = True
  speech_framework_type = "dmvr"
  speech_frame_length_ms = 25.0
  speech_frame_step_ms = 10.0
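
# Rough duration check (our arithmetic from the bindings above): with a 10 ms
# frame step and a 25 ms window, 200 frames span about
# 199 * 10 ms + 25 ms ~= 2 seconds of audio per example.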

train_task/tasks.register:
  task_name = "bnc_glyphs.train"
  dataset_dir = %TRAIN_DATA_DIR
  is_training = True
  noisify_embeddings = False

eval_task/tasks.register:
  task_name = "bnc_glyphs.eval"
  dataset_dir = %EVAL_DATA_DIR
  is_training = False

infer_eval_task/tasks.register:
  task_name = "bnc_glyphs.infer_eval"
  dataset_dir = %INFER_EVAL_DATA_DIR
  is_training = False

TRAIN_TASK = @train_task/tasks.register()
EVAL_TASK = @eval_task/tasks.register()
INFER_EVAL_TASK = @infer_eval_task/tasks.register()
MIXTURE_OR_TASK_NAME = %TRAIN_TASK
MIXTURE_OR_TASK_MODULE = "protoscribe.corpus.reader.tasks"
USE_CACHED_TASKS = False

TASK_FEATURE_LENGTHS = {
    "speech.log_mel_spectrum": %MAX_SPEECH_FRAME_SEQUENCE_LENGTH,
    "targets": %MAX_GLYPH_SEQUENCE_LENGTH
}

train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  pack = False
  use_custom_packing_ops = False

train_eval/utils.DatasetConfig:
  mixture_or_task_name = %EVAL_TASK
  pack = False
  shuffle = False
  use_custom_packing_ops = False

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = %INFER_EVAL_TASK
  pack = False
  shuffle = False
  use_custom_packing_ops = False
39 changes: 39 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_logmel-spectrum/infer.gin
@@ -0,0 +1,39 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __gin__ import dynamic_registration

import __main__ as infer_script
from protoscribe.sketches.inference import json_utils
from t5x import utils

include "protoscribe/pmmx/configs/runs/infer.gin"

utils.DatasetConfig:
  mixture_or_task_name = %INFER_EVAL_TASK

infer_script.infer:
  mode = "predict_with_aux"
  write_fn = @json_utils.write_inferences_to_file
  merge_fn = @infer_script.merge_chunks_to_file

json_utils.write_inferences_to_file:
  include_all_inputs = False
  input_fields_to_include = [
      "doc.id",
      "concept.name",
      "number.name",
      "text.sampa",
      "text.words",
  ]
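
# Presumably (not stated in this commit), with `include_all_inputs = False`
# only the input fields listed above are carried over into the JSON
# inference output.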
@@ -0,0 +1,27 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Base configuration.

from __gin__ import dynamic_registration

include "protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_common.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = 12
NUM_DECODER_LAYERS = 6
NUM_HEADS = 12
HEAD_DIM = 64
EMBED_DIM = 768
MLP_DIM = 2048
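
# Consistency note (our observation): NUM_HEADS * HEAD_DIM = 12 * 64 = 768
# matches EMBED_DIM, the usual T5 1.1 base-style encoder geometry.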
73 changes: 73 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_common.gin
@@ -0,0 +1,73 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Base configuration of the model.

from __gin__ import dynamic_registration

from protoscribe.pmmx import feature_converters
from protoscribe.pmmx import models
from t5x import adafactor
from t5x import utils

ARCHITECTURE = %gin.REQUIRED

include "protoscribe/models/pmmx/configs/glyph_logmel-spectrum/arch_p1_t5_1_1_flaxformer.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = %gin.REQUIRED
NUM_DECODER_LAYERS = %gin.REQUIRED
NUM_HEADS = %gin.REQUIRED
HEAD_DIM = %gin.REQUIRED
EMBED_DIM = %gin.REQUIRED
MLP_DIM = %gin.REQUIRED

# Optimizer
# `learning_rate` is set by `Trainer.learning_rate_fn`.
OPTIMIZER = @adafactor.Adafactor()
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0

# Loss defaults.
Z_LOSS = 0.0001
LABEL_SMOOTHING = 0.0
LOSS_NORMALIZING_FACTOR = None

# Model
MODEL = @models.MultimodalEncoderDecoderModel()
models.MultimodalEncoderDecoderModel:
  feature_converter_cls = @feature_converters.MultimodalEncDecFeatureConverterFactory()
  module = %ARCHITECTURE  # provided by t5_flaxformer
  input_vocabulary = %inputs/PASSTHROUGH_VOCABULARY
  output_vocabulary = %outputs/PASSTHROUGH_VOCABULARY
  optimizer_def = %OPTIMIZER
  z_loss = %Z_LOSS
  label_smoothing = %LABEL_SMOOTHING
  loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR

# Decoding
NUM_DECODES = 5
RETURN_ALL_DECODES = True
models.MultimodalEncoderDecoderModel.predict_batch_with_aux:
  num_decodes = %NUM_DECODES
  return_all_decodes = %RETURN_ALL_DECODES
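
# Assumption (t5x encoder-decoder models default to beam search; not stated
# here): each input yields the NUM_DECODES = 5 highest-scoring glyph
# sequences, all of which are returned since RETURN_ALL_DECODES is True.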

# Checkpoints
CHECKPOINT_PERIOD = 10_000
EVAL_PERIOD = %CHECKPOINT_PERIOD
utils.SaveCheckpointConfig:
  period = %CHECKPOINT_PERIOD
  keep = None  # Keep all checkpoints.
  save_dataset = False
@@ -0,0 +1,27 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tiny configuration.

from __gin__ import dynamic_registration

include "protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_common.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2
NUM_HEADS = 6
HEAD_DIM = 16
EMBED_DIM = 16
MLP_DIM = 16
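
A minimal sketch (not part of this commit) of how such a configuration can be parsed programmatically, assuming the standard gin-config API; `model_tiny.gin` is a hypothetical file name, since the commit page does not show the tiny config's path:

import gin

# Hypothetical path: the commit page does not display the tiny config's
# file name.
CONFIG = "protoscribe/models/pmmx/configs/glyph_logmel-spectrum/model_tiny.gin"

gin.clear_config()
# DATA_DIR is declared %gin.REQUIRED in dataset.gin; bind it explicitly
# whenever that file is included alongside the model config.
gin.parse_config_files_and_bindings(
    config_files=[CONFIG],
    bindings=["DATA_DIR = '/tmp/protoscribe_data'"],
)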
2 changes: 2 additions & 0 deletions protoscribe/models/pmmx/model_config_gin_test.py
@@ -51,6 +51,7 @@ def tearDown(self):

  @parameterized.parameters(
      "glyph_concepts",
      "glyph_logmel-spectrum",
      "glyph_phonemes",
  )
  def test_model_train(self, model_dir: str) -> None:
@@ -75,6 +76,7 @@ def test_model_train(self, model_dir: str) -> None:

  @parameterized.parameters(
      "glyph_concepts",
      "glyph_logmel-spectrum",
      "glyph_phonemes",
  )
  def test_model_infer(self, model_dir: str) -> None:
