Sketch generation from semantic concepts.
Also adds an implementation of arithmetic sampling as described in
Luke Vilnis, Yury Zemlyanskiy, Patrick Murray, Alexandre Passos, Sumit Sanghai:
"Arithmetic Sampling: Parallel Diverse Decoding for Large Language Models"
https://arxiv.org/abs/2210.15458

PiperOrigin-RevId: 722549106
agutkin committed Feb 3, 2025
1 parent 8226adb commit 6c56a46
Showing 12 changed files with 1,824 additions and 0 deletions.
@@ -0,0 +1 @@
# Sketch generation using discrete tokens from concepts.
@@ -0,0 +1,73 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Multimodal PMMX-One flaxformer architecture.

from __gin__ import dynamic_registration

import seqio
from protoscribe.pmmx import feature_converters
from protoscribe.pmmx import pmmx_architecture

from flaxformer.architectures.t5 import t5_architecture
from flaxformer.components import embedding

include "protoscribe/pmmx/configs/architectures/p1_t5_1_1_flaxformer.gin"

# Architecture (Flax Module).
ARCHITECTURE = @pmmx_architecture.MultimodalEncoderDecoder()

# Vocabulary for the encoder.
inputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
inputs/seqio.PassThroughVocabulary.size = 0

# Output vocabulary for the decoder. `SKETCH_TOKEN_VOCAB_SIZE` is the real
# sketch token vocabulary size (N + 3, where 3 is the number of special
# symbols) rounded up to a multiple of the batch size B (16); in other words,
# approximately N + B. This gives
#
# - 2064: for N=2048 (2048 + 3 = 2051, rounded up to 2064)
# - 4112: for N=4096 (4096 + 3 = 4099, rounded up to 4112)

SKETCH_TOKEN_VOCAB_SIZE = 2064
END_OF_SKETCH = 3 # sketch_tokenizer.END_OF_SKETCH
NUM_EMBEDDINGS = %SKETCH_TOKEN_VOCAB_SIZE

outputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
outputs/seqio.PassThroughVocabulary.size = 0
outputs/seqio.PassThroughVocabulary.eos_id = %END_OF_SKETCH

# Actual multimodal encoder-decoder architecture.
pmmx_architecture.MultimodalEncoderDecoder:
  encoder_factory = @pmmx_architecture.MultimodalEncoder
  decoder_factory = @t5_architecture.Decoder
  shared_token_embedder_factory = @token_embedder/embedding.Embed
  dtype = %ACTIVATION_DTYPE

feature_converters.MultimodalEncDecFeatureConverterFactory:
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  feature_specs = (
      ("text.concept_embedding", "float32", 2),
  )

# Encoder
pmmx_architecture.MultimodalEncoder:
  feature_spec = [
      ("text.concept_embedding", "text.concept_embedding"),
  ]
  modality_spec = ["text.concept_embedding"]
  modality_embedders_spec = {
      "text.concept_embedding": [
          ("text.concept_embedding", @pmmx_architecture.DenseEmbed)
      ],
  }
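To make the vocabulary-size arithmetic above concrete, here is a small Python sketch (an editorial illustration of the rounding rule described in the comment, not code from this commit):

```python
import math

def padded_vocab_size(n_tokens: int, num_special: int = 3, multiple: int = 16) -> int:
    """Rounds (n_tokens + num_special) up to the next multiple of `multiple`."""
    return math.ceil((n_tokens + num_special) / multiple) * multiple

assert padded_vocab_size(2048) == 2064  # SKETCH_TOKEN_VOCAB_SIZE for N=2048.
assert padded_vocab_size(4096) == 4112  # SKETCH_TOKEN_VOCAB_SIZE for N=4096.
```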
@@ -0,0 +1,43 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Arithmetic sampling: "Arithmetic Sampling: Parallel Diverse Decoding for Large
# Language Models" (2023).
# Paper: https://proceedings.mlr.press/v202/vilnis23a/vilnis23a.pdf
# Implementation: https://github.com/google-research/google-research/tree/master/arithmetic_sampling/

from __gin__ import dynamic_registration

from protoscribe.pmmx import arithmetic_sampling as decoding
from protoscribe.pmmx import models

# If non-zero, sample the next token from only the top-k logits; otherwise
# apply no cutoff and sample from the full logits over the vocabulary. TOPK
# and TOPP defined below cannot both be non-zero.
SAMPLING_TOPK = 40

# If non-zero, only use the smallest number of logits whose cumulative sum of
# probs adds up to (at least) TOPP.
SAMPLING_TOPP = 0.

# Sampling temperature.
TEMPERATURE = 0.6

models.MultimodalEncoderDecoderModel:
  decode_fn = @decoding.arithmetic_sample

decoding.arithmetic_sample:
  topk = %SAMPLING_TOPK
  topp = %SAMPLING_TOPP
  temperature = %TEMPERATURE
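For intuition, the following is a minimal NumPy sketch of the core idea from the paper, not the implementation bound above (which lives in `protoscribe.pmmx.arithmetic_sampling`): each parallel decode is assigned an evenly spaced code in [0, 1), and each step picks the token whose slice of the cumulative distribution contains the code, then renormalizes the code within that slice.

```python
import numpy as np

def arithmetic_sample_step(logits: np.ndarray, code: float,
                           temperature: float = 0.6, topk: int = 40):
    """One decoding step: select the token whose probability interval
    contains `code`, then rescale `code` inside that interval."""
    logits = logits / temperature
    if topk:
        cutoff = np.sort(logits)[-topk]
        logits = np.where(logits < cutoff, -np.inf, logits)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    cdf = np.cumsum(probs)
    token = min(int(np.searchsorted(cdf, code)), len(probs) - 1)
    lower = cdf[token - 1] if token > 0 else 0.0
    code = (code - lower) / probs[token]  # Renormalize within the bucket.
    return token, code

# Evenly spaced codes, one per decode, drive the diversity.
codes = (np.arange(8) + 0.5) / 8.0
```

Because the codes are deterministic and evenly spaced, the parallel decodes cover the sample space more evenly than the same number of independent ancestral samples, which is the diversity benefit the paper reports.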
84 changes: 84 additions & 0 deletions protoscribe/models/pmmx/configs/sketch-token_concepts/dataset.gin
@@ -0,0 +1,84 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Settings for Protoscribe dataset reader using discrete sketch tokens.

from __gin__ import dynamic_registration

from t5x import utils

from protoscribe.corpus.reader import tasks

DATA_DIR = %gin.REQUIRED
TRAIN_DATA_DIR = %DATA_DIR
EVAL_DATA_DIR = %DATA_DIR
INFER_EVAL_DATA_DIR = %DATA_DIR

MAX_STROKE_SEQUENCE_LENGTH = 250

tasks.register:
  concept_embedding_type = "bnc"
  max_stroke_sequence_length = %MAX_STROKE_SEQUENCE_LENGTH
  max_glyph_sequence_length = 20
  stroke_random_scale_factor = 0.0
  stroke_normalization_type = "sketch-rnn"
  stroke_token_vocab_filename = "vocab2048_normalized_sketchrnn.npy"

train_task/tasks.register:
  task_name = "bnc_tokens.train"
  dataset_dir = %TRAIN_DATA_DIR
  is_training = True
  noisify_embeddings = True
  noisify_neftune_alphas = {
      %tasks.EMBEDDING_SEMANTICS: 5.0,
  }

eval_task/tasks.register:
  task_name = "bnc_tokens.eval"
  dataset_dir = %EVAL_DATA_DIR
  is_training = False

infer_eval_task/tasks.register:
  task_name = "bnc_tokens.infer_eval"
  dataset_dir = %INFER_EVAL_DATA_DIR
  is_training = False

TRAIN_TASK = @train_task/tasks.register()
EVAL_TASK = @eval_task/tasks.register()
INFER_EVAL_TASK = @infer_eval_task/tasks.register()
MIXTURE_OR_TASK_NAME = %TRAIN_TASK
MIXTURE_OR_TASK_MODULE = "protoscribe.corpus.reader.tasks"
USE_CACHED_TASKS = False

TASK_FEATURE_LENGTHS = {
    "text.concept_embedding": 2,
    "targets": %MAX_STROKE_SEQUENCE_LENGTH,
}

train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  pack = False
  use_custom_packing_ops = False

train_eval/utils.DatasetConfig:
  mixture_or_task_name = %EVAL_TASK
  pack = False
  shuffle = False
  use_custom_packing_ops = False

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = %INFER_EVAL_TASK
  pack = False
  shuffle = False
  use_custom_packing_ops = False
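The `noisify_neftune_alphas` binding above points at NEFTune-style embedding noise (Jain et al., 2023). Assuming the standard NEFTune scaling rule (the dataset reader's actual implementation may differ), the scheme looks like this:

```python
import numpy as np

def neftune_noisify(embeddings: np.ndarray, alpha: float = 5.0,
                    rng: np.random.Generator | None = None) -> np.ndarray:
    """Adds uniform noise scaled by alpha / sqrt(seq_len * dim), NEFTune-style.

    `embeddings` has shape [seq_len, dim]; noise is applied only at training time.
    """
    rng = rng or np.random.default_rng()
    seq_len, dim = embeddings.shape
    scale = alpha / np.sqrt(seq_len * dim)
    return embeddings + rng.uniform(-1.0, 1.0, size=embeddings.shape) * scale
```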
39 changes: 39 additions & 0 deletions protoscribe/models/pmmx/configs/sketch-token_concepts/infer.gin
@@ -0,0 +1,39 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __gin__ import dynamic_registration

import __main__ as infer_script
from protoscribe.sketches.inference import json_utils
from t5x import utils

include "protoscribe/pmmx/configs/runs/infer.gin"

utils.DatasetConfig:
  mixture_or_task_name = %INFER_EVAL_TASK

infer_script.infer:
  mode = "predict_with_aux"
  write_fn = @json_utils.write_inferences_to_file
  merge_fn = @infer_script.merge_chunks_to_file

json_utils.write_inferences_to_file:
  include_all_inputs = False
  input_fields_to_include = [
      "doc.id",
      "concept.name",
      "number.name",
      "text.sampa",
      "text.words",
  ]
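A minimal sketch of what a field-filtered JSON-lines writer of this kind might look like; the real signature and behaviour of `json_utils.write_inferences_to_file` are not shown in this diff, so the names and structure below are assumptions:

```python
import json
from typing import Iterable, Mapping, Sequence

def write_inferences(path: str,
                     examples: Iterable[tuple[Mapping, object]],
                     fields: Sequence[str]) -> None:
    """Writes one JSON object per prediction, keeping only selected input fields."""
    with open(path, "w") as f:
        for inputs, prediction in examples:
            record = {k: inputs[k] for k in fields if k in inputs}
            record["prediction"] = prediction
            f.write(json.dumps(record) + "\n")
```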
@@ -0,0 +1,27 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Base configuration.

from __gin__ import dynamic_registration

include "protoscribe/models/pmmx/configs/sketch-token_concepts/model_common.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = 12
NUM_DECODER_LAYERS = 12
NUM_HEADS = 12
HEAD_DIM = 64
EMBED_DIM = 768
MLP_DIM = 2048
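For scale, a back-of-the-envelope parameter estimate for these dimensions, assuming the encoder-decoder follows T5 1.1 conventions (gated-GELU MLP; attention and MLP weights only, embeddings and layer norms ignored):

```python
EMBED, MLP, HEADS, HEAD_DIM = 768, 2048, 12, 64
ATTN = 4 * EMBED * HEADS * HEAD_DIM      # Q, K, V, O projections.
GATED_MLP = 3 * EMBED * MLP              # wi_0, wi_1, wo in T5 1.1.
encoder = 12 * (ATTN + GATED_MLP)
decoder = 12 * (2 * ATTN + GATED_MLP)    # Self- plus cross-attention.
print(f"~{(encoder + decoder) / 1e6:.0f}M parameters")  # ~198M
```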
@@ -0,0 +1,73 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Base configuration of the model.

from __gin__ import dynamic_registration

from protoscribe.pmmx import feature_converters
from protoscribe.pmmx import models
from t5x import adafactor
from t5x import utils

ARCHITECTURE = %gin.REQUIRED

include "protoscribe/models/pmmx/configs/sketch-token_concepts/arch_p1_t5_1_1_flaxformer.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = %gin.REQUIRED
NUM_DECODER_LAYERS = %gin.REQUIRED
NUM_HEADS = %gin.REQUIRED
HEAD_DIM = %gin.REQUIRED
EMBED_DIM = %gin.REQUIRED
MLP_DIM = %gin.REQUIRED

# Optimizer
# `learning_rate` is set by `Trainer.learning_rate_fn`.
OPTIMIZER = @adafactor.Adafactor()
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0

# Loss defaults.
Z_LOSS = 0.0001
LABEL_SMOOTHING = 0.0
LOSS_NORMALIZING_FACTOR = None

# Model
MODEL = @models.MultimodalEncoderDecoderModel()
models.MultimodalEncoderDecoderModel:
  feature_converter_cls = @feature_converters.MultimodalEncDecFeatureConverterFactory()
  module = %ARCHITECTURE  # Provided by t5_flaxformer.
  input_vocabulary = %inputs/PASSTHROUGH_VOCABULARY
  output_vocabulary = %outputs/PASSTHROUGH_VOCABULARY
  optimizer_def = %OPTIMIZER
  z_loss = %Z_LOSS
  label_smoothing = %LABEL_SMOOTHING
  loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR

# Decoding
NUM_DECODES = 8
RETURN_ALL_DECODES = True
models.MultimodalEncoderDecoderModel.predict_batch_with_aux:
  num_decodes = %NUM_DECODES
  return_all_decodes = %RETURN_ALL_DECODES

# Checkpoints
CHECKPOINT_PERIOD = 20_000
EVAL_PERIOD = %CHECKPOINT_PERIOD
utils.SaveCheckpointConfig:
  period = %CHECKPOINT_PERIOD
  keep = None  # Keep all checkpoints.
  save_dataset = False
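The `Z_LOSS` binding above matches the T5X convention of penalizing the squared log-partition function so that logits do not drift; a minimal sketch under that assumption:

```python
import numpy as np

def cross_entropy_with_z_loss(logits: np.ndarray, target: int,
                              z_loss: float = 1e-4) -> float:
    """Cross-entropy plus z_loss * log(Z)**2, as in T5X-style training losses."""
    log_z = np.log(np.sum(np.exp(logits - logits.max()))) + logits.max()
    nll = log_z - logits[target]
    return float(nll + z_loss * log_z**2)
```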
@@ -0,0 +1,27 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tiny configuration.

from __gin__ import dynamic_registration

include "protoscribe/models/pmmx/configs/sketch-token_concepts/model_common.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2
NUM_HEADS = 6
HEAD_DIM = 16
EMBED_DIM = 16
MLP_DIM = 16