Skip to content

Commit

Permalink
Discrete glyph sequence prediction from phonetic embeddings.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720621793
  • Loading branch information
agutkin committed Jan 28, 2025
1 parent 84a2346 commit c4a1058
Show file tree
Hide file tree
Showing 14 changed files with 330 additions and 9 deletions.
2 changes: 1 addition & 1 deletion protoscribe/corpus/builder/build_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _prepare_language_components() -> None:
language_dir = os.path.join(output_dir, "language")
logging.info("Language directory: %s", language_dir)
if not os.path.exists(language_dir):
os.makedirs(language_dir, exist_ok=True)
os.makedirs(language_dir)

# Copy these files that are copyable verbatim.
file_utils.copy_src_file("texts/configs", _NUMBER_CONFIG.value, language_dir)
Expand Down
2 changes: 1 addition & 1 deletion protoscribe/evolution/make_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def make_html() -> None:
output_svg_dir = os.path.join(_OUTPUT_HTML_DIR.value, "svgs")
if not os.path.exists(output_svg_dir):
logging.info("Making directory %s ...", output_svg_dir)
os.makedirs(output_svg_dir, exist_ok=True)
os.makedirs(output_svg_dir)

# Create index page.
if not _EXTENSIONS_FILE.value:
Expand Down
2 changes: 1 addition & 1 deletion protoscribe/evolution/new_spellings_basic_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def main(unused_argv):
# Create a summary of all the extensions in this round.
output_dir = os.path.join(_DATA_LOCATION.value, "inference_extensions")
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
os.makedirs(output_dir)
extensions_path = os.path.join(output_dir, "extensions.tsv")
logging.info("Writing extensions to %s ...", extensions_path)
with open(extensions_path, "w") as s:
Expand Down
4 changes: 2 additions & 2 deletions protoscribe/evolution/stages/build_dataset_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _setup_builder(round_data_dir: str) -> list[tuple[str, Any]]:
# language definitions from the previous round.
logging.info("Making %s ...", round_data_dir)
language_dir = os.path.join(round_data_dir, "language")
os.makedirs(language_dir, exist_ok=True)
os.makedirs(language_dir)
file_utils.copy_dir(
os.path.join(previous_data_dir, "language"), language_dir
)
Expand Down Expand Up @@ -116,7 +116,7 @@ def _setup_builder(round_data_dir: str) -> list[tuple[str, Any]]:

# At this stage it is safe to do this again.
if not os.path.isdir(round_data_dir):
os.makedirs(round_data_dir, exist_ok=True)
os.makedirs(round_data_dir)
logging.info(
"Created `%s` for outputs for round %d.", round_data_dir, round_id
)
Expand Down
2 changes: 1 addition & 1 deletion protoscribe/evolution/stages/new_spellings_basic_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def main(argv: Sequence[str]) -> None:
round_data_dir, "glyph_extensions_svg"
)
if not os.path.exists(output_glyph_graphics_dir):
os.makedirs(output_glyph_graphics_dir, exist_ok=True)
os.makedirs(output_glyph_graphics_dir)
args.extend([
"--output_glyph_graphics_dir", svg_temp_dir.name
])
Expand Down
2 changes: 1 addition & 1 deletion protoscribe/evolution/stages/sketches_from_jsonl_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _sketches_and_glyphs_for_model_type(
images_dir = os.path.join(output_dir, "images")
logging.info("Copying sketches to %s ...", images_dir)
if not os.path.exists(images_dir):
os.makedirs(images_dir, exist_ok=True)
os.makedirs(images_dir)
file_utils.copy_dir(temp_dir_name, images_dir)


Expand Down
1 change: 1 addition & 0 deletions protoscribe/models/pmmx/configs/glyph_phonemes/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Discrete glyph-only models with phonemes.
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Multimodal PMMX-One flaxformer architecture.

from __gin__ import dynamic_registration

import seqio
from protoscribe.pmmx import feature_converters
from protoscribe.pmmx import pmmx_architecture

from flaxformer.architectures.t5 import t5_architecture
from flaxformer.components import embedding

include "protoscribe/pmmx/configs/architectures/p1_t5_1_1_flaxformer.gin"

# Architecture (Flax Module).
ARCHITECTURE = @pmmx_architecture.MultimodalEncoderDecoder()

# Vocabulary for the encoder.
inputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
inputs/seqio.PassThroughVocabulary.size = 0

# Output vocabulary for the decoder. The `GLYPH_TOKEN_VOCAB_SIZE` corresponds
# to the real glyph token vocabulary size (all the glyphs + the special
# symbols, 314 elements) + non-administrative concepts (468 elements) + some
# provision for extra glyphs, rounded to 128 for TPU efficiency.

GLYPH_TOKEN_VOCAB_SIZE = 1024
END_OF_SKETCH = 2 # glyph_vocab.GLYPH_EOS
NUM_EMBEDDINGS = %GLYPH_TOKEN_VOCAB_SIZE

outputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
outputs/seqio.PassThroughVocabulary.size = 0
outputs/seqio.PassThroughVocabulary.eos_id = %END_OF_SKETCH

# Actual multimodal encoder-decoder architecture.
pmmx_architecture.MultimodalEncoderDecoder:
encoder_factory = @pmmx_architecture.MultimodalEncoder
decoder_factory = @t5_architecture.Decoder
shared_token_embedder_factory = @token_embedder/embedding.Embed
dtype = %ACTIVATION_DTYPE

feature_converters.MultimodalEncDecFeatureConverterFactory:
task_feature_lengths = %TASK_FEATURE_LENGTHS
feature_specs = (
("text.phonetic_embedding", "float32", 2),
)

# Encoder
pmmx_architecture.MultimodalEncoder:
feature_spec = [
("text.phonetic_embedding", "text.phonetic_embedding"),
]
modality_spec = ["text.phonetic_embedding"]
modality_embedders_spec = {
"text.phonetic_embedding": [
("text.phonetic_embedding", @pmmx_architecture.DenseEmbed)
],
}
83 changes: 83 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_phonemes/dataset.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Settings for Protoscribe dataset reader using discrete sketch tokens.

from __gin__ import dynamic_registration

from t5x import utils

from protoscribe.corpus.reader import tasks

DATA_DIR = %gin.REQUIRED
TRAIN_DATA_DIR = %DATA_DIR
EVAL_DATA_DIR = %DATA_DIR
INFER_EVAL_DATA_DIR = %DATA_DIR

MAX_GLYPH_SEQUENCE_LENGTH = 20
MAX_PHONETIC_SEQUENCE_LENGTH = 10

tasks.register:
concept_embedding_type = "bnc"
glyph_only_targets = True
max_glyph_sequence_length = %MAX_GLYPH_SEQUENCE_LENGTH
max_phonetic_sequence_length = %MAX_PHONETIC_SEQUENCE_LENGTH

train_task/tasks.register:
task_name = "bnc_glyphs.train"
dataset_dir = %TRAIN_DATA_DIR
is_training = True
noisify_embeddings = True
noisify_neftune_alphas = {
%tasks.EMBEDDING_PHONETICS: 0.01,
}

eval_task/tasks.register:
task_name = "bnc_glyphs.eval"
dataset_dir = %EVAL_DATA_DIR
is_training = False

infer_eval_task/tasks.register:
task_name = "bnc_glyphs.infer_eval"
dataset_dir = %INFER_EVAL_DATA_DIR
is_training = False

TRAIN_TASK = @train_task/tasks.register()
EVAL_TASK = @eval_task/tasks.register()
INFER_EVAL_TASK = @infer_eval_task/tasks.register()
MIXTURE_OR_TASK_NAME = %TRAIN_TASK
MIXTURE_OR_TASK_MODULE = "protoscribe.corpus.reader.tasks"
USE_CACHED_TASKS = False

TASK_FEATURE_LENGTHS = {
"text.phonetic_embedding": %MAX_PHONETIC_SEQUENCE_LENGTH,
"targets": %MAX_GLYPH_SEQUENCE_LENGTH
}

train/utils.DatasetConfig:
mixture_or_task_name = %MIXTURE_OR_TASK_NAME
pack = False
use_custom_packing_ops = False

train_eval/utils.DatasetConfig:
mixture_or_task_name = %EVAL_TASK
pack = False
shuffle = False
use_custom_packing_ops = False

infer_eval/utils.DatasetConfig:
mixture_or_task_name = %INFER_EVAL_TASK
pack = False
shuffle = False
use_custom_packing_ops = False
39 changes: 39 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_phonemes/infer.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __gin__ import dynamic_registration

import __main__ as infer_script
from protoscribe.sketches.inference import json_utils
from t5x import utils

include "protoscribe/pmmx/configs/runs/infer.gin"

utils.DatasetConfig:
mixture_or_task_name = %INFER_EVAL_TASK

infer_script.infer:
mode = "predict_with_aux"
write_fn = @json_utils.write_inferences_to_file
merge_fn = @infer_script.merge_chunks_to_file

json_utils.write_inferences_to_file:
include_all_inputs = False
input_fields_to_include = [
"doc.id",
"concept.name",
"number.name",
"text.sampa",
"text.words",
]
27 changes: 27 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_phonemes/model_base.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Base configuration.

from __gin__ import dynamic_registration

include "protoscribe/models/pmmx/configs/glyph_phonemes/model_common.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6
NUM_HEADS = 4
HEAD_DIM = 32
EMBED_DIM = 96
MLP_DIM = 512
73 changes: 73 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_phonemes/model_common.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Base configuration of the model.

from __gin__ import dynamic_registration

from protoscribe.pmmx import feature_converters
from protoscribe.pmmx import models
from t5x import adafactor
from t5x import utils

ARCHITECTURE = %gin.REQUIRED

include "protoscribe/models/pmmx/configs/glyph_phonemes/arch_p1_t5_1_1_flaxformer.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = %gin.REQUIRED
NUM_DECODER_LAYERS = %gin.REQUIRED
NUM_HEADS = %gin.REQUIRED
HEAD_DIM = %gin.REQUIRED
EMBED_DIM = %gin.REQUIRED
MLP_DIM = %gin.REQUIRED

# Optimizer
# `learning_rate` is set by `Trainer.learning_rate_fn`.
OPTIMIZER = @adafactor.Adafactor()
adafactor.Adafactor:
decay_rate = 0.8
step_offset = 0

# Loss defaults.
Z_LOSS = 0.0001
LABEL_SMOOTHING = 0.0
LOSS_NORMALIZING_FACTOR = None

# Model
MODEL = @models.MultimodalEncoderDecoderModel()
models.MultimodalEncoderDecoderModel:
feature_converter_cls = @feature_converters.MultimodalEncDecFeatureConverterFactory()
module = %ARCHITECTURE # provided by t5_flaxformer
input_vocabulary = %inputs/PASSTHROUGH_VOCABULARY
output_vocabulary = %outputs/PASSTHROUGH_VOCABULARY
optimizer_def = %OPTIMIZER
z_loss = %Z_LOSS
label_smoothing = %LABEL_SMOOTHING
loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR

# Decoding
NUM_DECODES = 5
RETURN_ALL_DECODES = True
models.MultimodalEncoderDecoderModel.predict_batch_with_aux:
num_decodes = %NUM_DECODES
return_all_decodes = %RETURN_ALL_DECODES

# Checkpoints
CHECKPOINT_PERIOD = 10_000
EVAL_PERIOD = %CHECKPOINT_PERIOD
utils.SaveCheckpointConfig:
period = %CHECKPOINT_PERIOD
keep = None # Keep all checkpoints.
save_dataset = False
27 changes: 27 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_phonemes/model_tiny.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tiny configuration.

from __gin__ import dynamic_registration

include "protoscribe/models/pmmx/configs/glyph_phonemes/model_common.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2
NUM_HEADS = 6
HEAD_DIM = 16
EMBED_DIM = 16
MLP_DIM = 16
Loading

0 comments on commit c4a1058

Please sign in to comment.