From 6c56a4630dc5ac144b8b048017d71aabc9337ff6 Mon Sep 17 00:00:00 2001 From: Alexander Gutkin Date: Mon, 3 Feb 2025 09:03:56 +0000 Subject: [PATCH] Sketch generation from semantic concepts. Also adding implementation of arithmetic sampling described in Luke Vilnis, Yury Zemlyanskiy, Patrick Murray, Alexandre Passos, Sumit Sanghai: "Arithmetic Sampling: Parallel Diverse Decoding for Large Language Models" https://arxiv.org/abs/2210.15458 PiperOrigin-RevId: 722549106 --- .../configs/sketch-token_concepts/README.md | 1 + .../arch_p1_t5_1_1_flaxformer.gin | 73 ++ .../arithmetic_sample.gin | 43 + .../configs/sketch-token_concepts/dataset.gin | 84 ++ .../configs/sketch-token_concepts/infer.gin | 39 + .../sketch-token_concepts/model_base.gin | 27 + .../sketch-token_concepts/model_common.gin | 73 ++ .../sketch-token_concepts/model_tiny.gin | 27 + .../temperature_sample.gin | 45 + .../models/pmmx/model_config_gin_test.py | 2 + protoscribe/pmmx/arithmetic_sampling.py | 619 ++++++++++++++ protoscribe/pmmx/arithmetic_sampling_test.py | 791 ++++++++++++++++++ 12 files changed, 1824 insertions(+) create mode 100644 protoscribe/models/pmmx/configs/sketch-token_concepts/README.md create mode 100644 protoscribe/models/pmmx/configs/sketch-token_concepts/arch_p1_t5_1_1_flaxformer.gin create mode 100644 protoscribe/models/pmmx/configs/sketch-token_concepts/arithmetic_sample.gin create mode 100644 protoscribe/models/pmmx/configs/sketch-token_concepts/dataset.gin create mode 100644 protoscribe/models/pmmx/configs/sketch-token_concepts/infer.gin create mode 100644 protoscribe/models/pmmx/configs/sketch-token_concepts/model_base.gin create mode 100644 protoscribe/models/pmmx/configs/sketch-token_concepts/model_common.gin create mode 100644 protoscribe/models/pmmx/configs/sketch-token_concepts/model_tiny.gin create mode 100644 protoscribe/models/pmmx/configs/sketch-token_concepts/temperature_sample.gin create mode 100644 protoscribe/pmmx/arithmetic_sampling.py create mode 100644 protoscribe/pmmx/arithmetic_sampling_test.py diff --git a/protoscribe/models/pmmx/configs/sketch-token_concepts/README.md b/protoscribe/models/pmmx/configs/sketch-token_concepts/README.md new file mode 100644 index 0000000..d5d9b43 --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_concepts/README.md @@ -0,0 +1 @@ +# Sketch generation using discrete tokens from concepts. diff --git a/protoscribe/models/pmmx/configs/sketch-token_concepts/arch_p1_t5_1_1_flaxformer.gin b/protoscribe/models/pmmx/configs/sketch-token_concepts/arch_p1_t5_1_1_flaxformer.gin new file mode 100644 index 0000000..411aa61 --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_concepts/arch_p1_t5_1_1_flaxformer.gin @@ -0,0 +1,73 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Multimodal PMMX-One flaxformer architecture. 
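+# Encodes dense concept embeddings (`text.concept_embedding`) with a
+# multimodal encoder and decodes discrete sketch tokens with a standard T5
+# decoder.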
+
+from __gin__ import dynamic_registration
+
+import seqio
+from protoscribe.pmmx import feature_converters
+from protoscribe.pmmx import pmmx_architecture
+
+from flaxformer.architectures.t5 import t5_architecture
+from flaxformer.components import embedding
+
+include "protoscribe/pmmx/configs/architectures/p1_t5_1_1_flaxformer.gin"
+
+# Architecture (Flax Module).
+ARCHITECTURE = @pmmx_architecture.MultimodalEncoderDecoder()
+
+# Vocabulary for the encoder.
+inputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
+inputs/seqio.PassThroughVocabulary.size = 0
+
+# Output vocabulary for the decoder. The `SKETCH_TOKEN_VOCAB_SIZE` is the
+# real sketch token vocabulary size (N + 3, where 3 is the number of special
+# symbols) rounded up to a multiple of the batch size B (16); in other words,
+# N + B. This is
+#
+# - 2064: for N=2048
+# - 4112: for N=4096
+
+SKETCH_TOKEN_VOCAB_SIZE = 2064
+END_OF_SKETCH = 3 # sketch_tokenizer.END_OF_SKETCH
+NUM_EMBEDDINGS = %SKETCH_TOKEN_VOCAB_SIZE
+
+outputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
+outputs/seqio.PassThroughVocabulary.size = 0
+outputs/seqio.PassThroughVocabulary.eos_id = %END_OF_SKETCH
+
+# Actual multimodal encoder-decoder architecture.
+pmmx_architecture.MultimodalEncoderDecoder:
+  encoder_factory = @pmmx_architecture.MultimodalEncoder
+  decoder_factory = @t5_architecture.Decoder
+  shared_token_embedder_factory = @token_embedder/embedding.Embed
+  dtype = %ACTIVATION_DTYPE
+
+feature_converters.MultimodalEncDecFeatureConverterFactory:
+  task_feature_lengths = %TASK_FEATURE_LENGTHS
+  feature_specs = (
+      ("text.concept_embedding", "float32", 2),
+  )
+
+# Encoder
+pmmx_architecture.MultimodalEncoder:
+  feature_spec = [
+      ("text.concept_embedding", "text.concept_embedding"),
+  ]
+  modality_spec = ["text.concept_embedding"]
+  modality_embedders_spec = {
+      "text.concept_embedding": [
+          ("text.concept_embedding", @pmmx_architecture.DenseEmbed)
+      ],
+  }
diff --git a/protoscribe/models/pmmx/configs/sketch-token_concepts/arithmetic_sample.gin b/protoscribe/models/pmmx/configs/sketch-token_concepts/arithmetic_sample.gin
new file mode 100644
index 0000000..30a0550
--- /dev/null
+++ b/protoscribe/models/pmmx/configs/sketch-token_concepts/arithmetic_sample.gin
@@ -0,0 +1,43 @@
+# Copyright 2024 The Protoscribe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Arithmetic sampling: "Arithmetic Sampling: Parallel Diverse Decoding for Large
+# Language Models" (2023).
+# Paper: https://proceedings.mlr.press/v202/vilnis23a/vilnis23a.pdf
+# Implementation: https://github.com/google-research/google-research/tree/master/arithmetic_sampling/
+
+from __gin__ import dynamic_registration
+
+from protoscribe.pmmx import arithmetic_sampling as decoding
+from protoscribe.pmmx import models
+
+# If non-zero, only use the top-k logits to sample the next token; otherwise
+# apply no cutoff and sample from the full logits over the vocabulary. At most
+# one of TOPK and TOPP defined below may be non-zero.
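+# For example, with the settings below each decoding step keeps only the 40
+# highest-probability logits (TOPK = 40, TOPP disabled) before the arithmetic
+# codes pick a token among them.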
+SAMPLING_TOPK = 40 + +# If non-zero, only use the smallest number of logits whose cumulative sum of +# probs adds up to (at least) TOPP. +SAMPLING_TOPP = 0. + +# Sampling temperature. +TEMPERATURE = 0.6 + +models.MultimodalEncoderDecoderModel: + decode_fn = @decoding.arithmetic_sample + +decoding.arithmetic_sample: + topk = %SAMPLING_TOPK + topp = %SAMPLING_TOPP + temperature = %TEMPERATURE diff --git a/protoscribe/models/pmmx/configs/sketch-token_concepts/dataset.gin b/protoscribe/models/pmmx/configs/sketch-token_concepts/dataset.gin new file mode 100644 index 0000000..c3e81d2 --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_concepts/dataset.gin @@ -0,0 +1,84 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Settings for Protoscribe dataset reader using discrete sketch tokens. + +from __gin__ import dynamic_registration + +from t5x import utils + +from protoscribe.corpus.reader import tasks + +DATA_DIR = %gin.REQUIRED +TRAIN_DATA_DIR = %DATA_DIR +EVAL_DATA_DIR = %DATA_DIR +INFER_EVAL_DATA_DIR = %DATA_DIR + +MAX_STROKE_SEQUENCE_LENGTH = 250 + +tasks.register: + concept_embedding_type = "bnc" + max_stroke_sequence_length = %MAX_STROKE_SEQUENCE_LENGTH + max_glyph_sequence_length = 20 + stroke_random_scale_factor = 0.0 + stroke_normalization_type = "sketch-rnn" + stroke_token_vocab_filename = "vocab2048_normalized_sketchrnn.npy" + +train_task/tasks.register: + task_name = "bnc_tokens.train" + dataset_dir = %TRAIN_DATA_DIR + is_training = True + noisify_embeddings = True + noisify_neftune_alphas = { + %tasks.EMBEDDING_SEMANTICS: 5.0, + } + +eval_task/tasks.register: + task_name = "bnc_tokens.eval" + dataset_dir = %EVAL_DATA_DIR + is_training = False + +infer_eval_task/tasks.register: + task_name = "bnc_tokens.infer_eval" + dataset_dir = %INFER_EVAL_DATA_DIR + is_training = False + +TRAIN_TASK = @train_task/tasks.register() +EVAL_TASK = @eval_task/tasks.register() +INFER_EVAL_TASK = @infer_eval_task/tasks.register() +MIXTURE_OR_TASK_NAME = %TRAIN_TASK +MIXTURE_OR_TASK_MODULE = "protoscribe.corpus.reader.tasks" +USE_CACHED_TASKS = False + +TASK_FEATURE_LENGTHS = { + "text.concept_embedding": 2, + "targets": %MAX_STROKE_SEQUENCE_LENGTH +} + +train/utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + pack = False + use_custom_packing_ops = False + +train_eval/utils.DatasetConfig: + mixture_or_task_name = %EVAL_TASK + pack = False + shuffle = False + use_custom_packing_ops = False + +infer_eval/utils.DatasetConfig: + mixture_or_task_name = %INFER_EVAL_TASK + pack = False + shuffle = False + use_custom_packing_ops = False \ No newline at end of file diff --git a/protoscribe/models/pmmx/configs/sketch-token_concepts/infer.gin b/protoscribe/models/pmmx/configs/sketch-token_concepts/infer.gin new file mode 100644 index 0000000..68337af --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_concepts/infer.gin @@ -0,0 +1,39 @@ +# Copyright 2024 The Protoscribe Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __gin__ import dynamic_registration + +import __main__ as infer_script +from protoscribe.sketches.inference import json_utils +from t5x import utils + +include "protoscribe/pmmx/configs/runs/infer.gin" + +utils.DatasetConfig: + mixture_or_task_name = %INFER_EVAL_TASK + +infer_script.infer: + mode = "predict_with_aux" + write_fn = @json_utils.write_inferences_to_file + merge_fn = @infer_script.merge_chunks_to_file + +json_utils.write_inferences_to_file: + include_all_inputs = False + input_fields_to_include = [ + "doc.id", + "concept.name", + "number.name", + "text.sampa", + "text.words", + ] diff --git a/protoscribe/models/pmmx/configs/sketch-token_concepts/model_base.gin b/protoscribe/models/pmmx/configs/sketch-token_concepts/model_base.gin new file mode 100644 index 0000000..a37936b --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_concepts/model_base.gin @@ -0,0 +1,27 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Base configuration. + +from __gin__ import dynamic_registration + +include "protoscribe/models/pmmx/configs/sketch-token_concepts/model_common.gin" + +# Architecture overrides. +NUM_ENCODER_LAYERS = 12 +NUM_DECODER_LAYERS = 12 +NUM_HEADS = 12 +HEAD_DIM = 64 +EMBED_DIM = 768 +MLP_DIM = 2048 diff --git a/protoscribe/models/pmmx/configs/sketch-token_concepts/model_common.gin b/protoscribe/models/pmmx/configs/sketch-token_concepts/model_common.gin new file mode 100644 index 0000000..1925dd2 --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_concepts/model_common.gin @@ -0,0 +1,73 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Base configuration of the model. 
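+#
+# Shared by the `model_base.gin` and `model_tiny.gin` configurations, which
+# supply the architecture dimensions marked as required below.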
+ +from __gin__ import dynamic_registration + +from protoscribe.pmmx import feature_converters +from protoscribe.pmmx import models +from t5x import adafactor +from t5x import utils + +ARCHITECTURE = %gin.REQUIRED + +include "protoscribe/models/pmmx/configs/sketch-token_concepts/arch_p1_t5_1_1_flaxformer.gin" + +# Architecture overrides. +NUM_ENCODER_LAYERS = %gin.REQUIRED +NUM_DECODER_LAYERS = %gin.REQUIRED +NUM_HEADS = %gin.REQUIRED +HEAD_DIM = %gin.REQUIRED +EMBED_DIM = %gin.REQUIRED +MLP_DIM = %gin.REQUIRED + +# Optimizer +# `learning_rate` is set by `Trainer.learning_rate_fn`. +OPTIMIZER = @adafactor.Adafactor() +adafactor.Adafactor: + decay_rate = 0.8 + step_offset = 0 + +# Loss defaults. +Z_LOSS = 0.0001 +LABEL_SMOOTHING = 0.0 +LOSS_NORMALIZING_FACTOR = None + +# Model +MODEL = @models.MultimodalEncoderDecoderModel() +models.MultimodalEncoderDecoderModel: + feature_converter_cls = @feature_converters.MultimodalEncDecFeatureConverterFactory() + module = %ARCHITECTURE # provided by t5_flaxformer + input_vocabulary = %inputs/PASSTHROUGH_VOCABULARY + output_vocabulary = %outputs/PASSTHROUGH_VOCABULARY + optimizer_def = %OPTIMIZER + z_loss = %Z_LOSS + label_smoothing = %LABEL_SMOOTHING + loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR + +# Decoding +NUM_DECODES = 8 +RETURN_ALL_DECODES = True +models.MultimodalEncoderDecoderModel.predict_batch_with_aux: + num_decodes = %NUM_DECODES + return_all_decodes = %RETURN_ALL_DECODES + +# Checkpoints +CHECKPOINT_PERIOD = 20_000 +EVAL_PERIOD = %CHECKPOINT_PERIOD +utils.SaveCheckpointConfig: + period = %CHECKPOINT_PERIOD + keep = None # Keep all checkpoints. + save_dataset = False diff --git a/protoscribe/models/pmmx/configs/sketch-token_concepts/model_tiny.gin b/protoscribe/models/pmmx/configs/sketch-token_concepts/model_tiny.gin new file mode 100644 index 0000000..45587d2 --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_concepts/model_tiny.gin @@ -0,0 +1,27 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tiny configuration. + +from __gin__ import dynamic_registration + +include "protoscribe/models/pmmx/configs/sketch-token_concepts/model_common.gin" + +# Architecture overrides. +NUM_ENCODER_LAYERS = 2 +NUM_DECODER_LAYERS = 2 +NUM_HEADS = 6 +HEAD_DIM = 16 +EMBED_DIM = 16 +MLP_DIM = 16 diff --git a/protoscribe/models/pmmx/configs/sketch-token_concepts/temperature_sample.gin b/protoscribe/models/pmmx/configs/sketch-token_concepts/temperature_sample.gin new file mode 100644 index 0000000..2ba520e --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_concepts/temperature_sample.gin @@ -0,0 +1,45 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard t5x temperature sampling.
+#
+# This can be either top-k or nucleus (top-p) sampling. See
+# Holtzman, A., Buys, J., Du, L., Forbes, M., & Choi, Y. (2020). ``The Curious
+# Case of Neural Text Degeneration.'' In International Conference on Learning
+# Representations. https://openreview.net/forum?id=rygGQyrFvH
+
+from __gin__ import dynamic_registration
+
+from protoscribe.pmmx import models
+from t5x import decoding
+
+# If non-zero, only use the top-k logits to sample the next token; otherwise
+# apply no cutoff and sample from the full logits over the vocabulary. At most
+# one of TOPK and TOPP defined below may be non-zero.
+SAMPLING_TOPK = 0
+
+# If non-zero, only use the smallest number of logits whose cumulative sum of
+# probs adds up to (at least) TOPP.
+SAMPLING_TOPP = 0.4
+
+# Sampling temperature.
+TEMPERATURE = 0.6
+
+models.MultimodalEncoderDecoderModel:
+  decode_fn = @decoding.temperature_sample
+
+decoding.temperature_sample:
+  topk = %SAMPLING_TOPK
+  topp = %SAMPLING_TOPP
+  temperature = %TEMPERATURE
\ No newline at end of file
diff --git a/protoscribe/models/pmmx/model_config_gin_test.py b/protoscribe/models/pmmx/model_config_gin_test.py
index 73690c7..b8de0f2 100644
--- a/protoscribe/models/pmmx/model_config_gin_test.py
+++ b/protoscribe/models/pmmx/model_config_gin_test.py
@@ -53,6 +53,7 @@ def tearDown(self):
       "glyph_concepts",
       "glyph_logmel-spectrum",
       "glyph_phonemes",
+      "sketch-token_concepts",
   )
   def test_model_train(self, model_dir: str) -> None:
     """Tests tiny model configuration for training."""
@@ -78,6 +79,7 @@ def test_model_train(self, model_dir: str) -> None:
       "glyph_concepts",
       "glyph_logmel-spectrum",
       "glyph_phonemes",
+      "sketch-token_concepts",
   )
   def test_model_infer(self, model_dir: str) -> None:
     """Tests tiny model configuration in inference mode."""
diff --git a/protoscribe/pmmx/arithmetic_sampling.py b/protoscribe/pmmx/arithmetic_sampling.py
new file mode 100644
index 0000000..1d42701
--- /dev/null
+++ b/protoscribe/pmmx/arithmetic_sampling.py
@@ -0,0 +1,619 @@
+# Copyright 2024 The Protoscribe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""T5X decoding routine for arithmetic sampling.
+
+Luke Vilnis, Yury Zemlyanskiy, Patrick Murray, Alexandre Passos,
+Sumit Sanghai: "Arithmetic Sampling: Parallel Diverse Decoding for Large
+Language Models", https://arxiv.org/abs/2210.15458.
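+
+At a high level, each decode is assigned a code point in the unit interval;
+the codes are evenly spaced, with a shared random offset. At every step, a
+code is mapped through the cumulative distribution of the next-token
+probabilities to pick a token and is then rescaled to the chosen bucket.
+This makes the parallel decodes systematically diverse while each decode
+remains an unbiased sample from the model.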
+
+Cloned from https://github.com/google-research/google-research/tree/master/arithmetic_sampling. # pylint: disable=line-too-long
+"""
+
+import functools
+
+from typing import Any, Callable, Mapping, Optional, Tuple, Union
+import flax
+import jax
+from jax import lax
+from jax import random
+import jax.numpy as jnp
+import numpy as np
+from t5x import decoding
+
+# Constants
+# "Effective negative infinity" constant for masking in beam search.
+NEG_INF = np.array(-1.0e7)
+
+# Temperatures lower than this are considered 0.0, which is handled specially
+# with a conditional. This is to avoid numeric issues from exponentiating on
+# 1.0/temperature when temperature is close to 0.0.
+MIN_TEMPERATURE = np.array(1e-4)
+
+#------------------------------------------------------------------------------
+# Arithmetic Sampling
+#------------------------------------------------------------------------------
+
+
+@flax.struct.dataclass
+class ArithmeticSamplingLoopState:
+  """Holds sampling state data.
+
+  Attributes:
+    cur_index: [batch_size] array position of the sampling loop in the length
+      dimension.
+    sequences: [batch_size * num_decodes, max_decode_len] array of current
+      sampled sequence prefixes.
+    cache: any mapping of arrays, e.g. flax attention cache.
+    cur_token: [batch_size, num_decodes] single timestep slice containing
+      current tokens.
+    ended: [batch_size, num_decodes] binary array marking completed sequences.
+    rng: Jax PRNGKey.
+    log_prob: [batch_size, num_decodes] array of log probs for each sequence.
+    codes: [batch_size, num_decodes] array containing the arithmetic codes for
+      the remainder of the sequence at the current time step for each sample.
+  """
+  cur_index: jnp.ndarray
+  sequences: jnp.ndarray
+  cache: Mapping[str, jnp.ndarray]
+  cur_token: jnp.ndarray
+  ended: jnp.ndarray
+  rng: jnp.ndarray
+  log_prob: jnp.ndarray
+  codes: jnp.ndarray
+
+
+_dynamic_update_vector_slice_in_dim = jax.vmap(
+    lax.dynamic_update_slice_in_dim, in_axes=(0, 0, 0, None))
+
+
+def _is_tracer(value: Any):
+  return isinstance(value, jax.core.Tracer)
+
+
+def _sequential_cumsum(arr: jnp.ndarray, axis: int) -> jnp.ndarray:
+  """Sequential scan-based implementation of cumulative sum for Jax.
+
+  The Jax implementation of cumulative sum does not guarantee that the output
+  array is nondecreasing when applied to nonnegative inputs. This breaks
+  the use of cumulative sum for bucketing. Using scan forces the
+  sum to happen sequentially, which avoids the floating point issues that
+  cause the normal Jax cumsum to exhibit bad behavior.
+
+  Args:
+    arr: Jax array to sum.
+    axis: axis to sum over.
+
+  Returns:
+    Jax array of partial cumulative sums.
+  """
+
+  # Swap axes so that the axis to be scanned over is the leading axis.
+  xs = jnp.swapaxes(arr, 0, axis)
+  init_carry = jnp.zeros(xs.shape[1:], xs.dtype)
+  _, res = jax.lax.scan(lambda c, x: (c + x, c + x), init_carry, xs)
+  return jnp.swapaxes(res, 0, axis)
+
+
+def _arithmetic_categorical(
+    rng: jnp.ndarray, logits: jnp.ndarray,
+    codes: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
+  """Sample from a categorical using arithmetic sampling.
+
+  Returns samples from an arithmetic codebook based on provided codes. This
+  gives an unbiased sample for each code randomly picked from the unit interval.
+
+  Args:
+    rng: JAX PRNGKey.
+    logits: array: [batch_size, vocab_size] float32 sequence of logits.
+    codes: array: [batch_size] float32 codes for each batch element.
+
+  Returns:
+    A tuple (samples, new_codes) where `samples` are sampled indices with shape
+    [batch_size], and `new_codes` are shape [batch_size] containing codes for
+    the remaining suffix if doing ancestral sampling.
+  """
+  # We randomly permute the logits here at each timestep to avoid depending on
+  # the default order of the vocabulary. This isn't strictly necessary.
+  # We need to invert this permutation at the end because it changes the
+  # identities of the sampled indices.
+  _, vocab_size = logits.shape
+  perm = jax.random.permutation(rng, vocab_size)
+  invperm = jnp.argsort(perm)
+
+  logits = logits[:, perm]
+
+  # Now we want to, for each element in the batch, get the normalized
+  # probabilities, stack them in the unit interval into buckets, and figure
+  # out what bucket the code falls into.
+  probs = jax.nn.softmax(logits, axis=1)
+
+  # Use the sequential scan-based cumsum to guarantee a nondecreasing array
+  # of partial sums.
+  cumprobs = _sequential_cumsum(probs, axis=1)
+
+  # Because of precision, make sure the max value (and everything with that
+  # value, to not change bucket widths) is at least 1.0.
+  max_probs = jnp.expand_dims(jnp.max(cumprobs, axis=1), 1)
+  all_bucket_maxes = jnp.where((cumprobs == max_probs) & (cumprobs < 1.0), 1.0,
+                               cumprobs)
+
+  # Now the cumulative probabilities represent the max value of each of the
+  # buckets. So let's make a mask of all the buckets whose maxes are less
+  # than and greater than the given codes.
+  expanded_codes = jnp.expand_dims(codes, axis=1)
+  bucket_maxes_lte_codes = all_bucket_maxes <= expanded_codes
+  bucket_maxes_gt_codes = all_bucket_maxes > expanded_codes
+
+  # Pick the minimum value for the bucket for the code. Note this will be
+  # 0.0 if the code falls into the zero'th bucket, as desired.
+  code_bucket_mins = jnp.max(all_bucket_maxes * bucket_maxes_lte_codes, axis=1)
+
+  # We have to do some masking here, and for probabilities, anything > 1.0
+  # is as good as infinity.
+  prob_infty = 1.1
+  # Pick the maximum value for the bucket, the first bucket whose max is
+  # greater than the code.
+  code_bucket_maxes = jnp.min(
+      all_bucket_maxes * bucket_maxes_gt_codes +
+      bucket_maxes_lte_codes * prob_infty,
+      axis=1)
+  # We have to take the argmin before inverting the permutation,
+  # otherwise it messes up the default tie breaking behavior for size zero
+  # buckets (take lowest index).
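+  # For example, with permuted bucket maxes [0.2, 0.5, 1.0] and code 0.35,
+  # the code lands in the second bucket: bucket min 0.2 and bucket max 0.5,
+  # so the remainder code below becomes (0.35 - 0.2) / (0.5 - 0.2) = 0.5.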
+ sampled_indices_permed = jnp.argmin( + (all_bucket_maxes * bucket_maxes_gt_codes + + bucket_maxes_lte_codes * prob_infty), + axis=1) + sampled_indices = jnp.argmax( + jax.nn.one_hot(sampled_indices_permed, vocab_size)[:, invperm], axis=1) + + remainder_codes = (codes - code_bucket_mins) / ( + code_bucket_maxes - code_bucket_mins) + + samples = sampled_indices + new_codes = remainder_codes + + return samples, new_codes + + +def arithmetic_sample( + inputs: jnp.ndarray, + cache: Mapping[str, jnp.ndarray], + tokens_to_logits: Callable[[decoding.DecodingState], + Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]], + eos_id: int, + decode_rng: Optional[jnp.ndarray] = None, + num_decodes: int = 1, + temperature: Union[float, jnp.ndarray] = 1.0, + topk: int = 1, + topp: float = 0.0, + cache_offset: int = 0, + initial_index: Optional[jnp.ndarray] = None, + max_decode_steps: Optional[Union[int, jnp.ndarray]] = None, + max_decode_steps_hard_limit: Optional[int] = None, + rescale_log_probs: bool = True, + state_callback_fn: Optional[Callable[[ArithmeticSamplingLoopState], + ArithmeticSamplingLoopState]] = None, + logit_callback_fn: Optional[Callable[ + [jnp.ndarray, ArithmeticSamplingLoopState], jnp.ndarray]] = None +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Arithmetic sampling for language model generation. + + The sampling is performed `num_decodes` times in a vectorized + manner by expanding the batch dimension. This is similar to how beam search + expands the batch dimension to process each batch element with multiple beams. + + Args: + inputs: array: [batch_size, max_decode_len] int32 sequence of tokens. + cache: flax attention cache. + tokens_to_logits: fast autoregressive decoder function taking single token + slices and cache and returning next-token logits and updated cache. + eos_id: int: end-of-sentence token for target vocabulary. + decode_rng: JAX PRNGKey. + num_decodes: number of decoded sequences to be returned. + temperature: float: sampling temperature factor. As it approaches zero this + becomes equivalent to greedy sampling. + topk: integer: if nonzero only use the top-k logits to sample next token, if + zero don't use any cutoff and sample from full logits over vocabulary. + topp: float: if nonzero only use the smallest number of logits whose + cumulative sum of probs adds up to (at least) topp. Will raise ValueError + if it's nonzero when topk is nonzero. + cache_offset: axis offset for cache, arising from scanned layers. + initial_index: Optional[array]: [batch_size] int32 a vector of loop indexes + to start decoding at. + max_decode_steps: int: an optional maximum number of decoding steps. If + None, it will decode until the full input shape `inputs.shape[1]` is + filled. max_decode_steps begins counting after the prompt, so it will + decode at most len(prompt) + max_decode_steps tokens. + max_decode_steps_hard_limit: int: an optional fixed hard limit on + max_decode_steps. If this is set (not None and > 0), and max_decode_steps + is also set, then max_decode_steps will be clipped to this limit. The + value max_decode_steps can be an ndarray, but max_decode_steps_hard_limit + must be a Python integer or None. + rescale_log_probs: bool: whether to apply temperature, topp, and topk + rescaling to the log probs which are returned. If True, the log_probs will + include these transformations (for example, with topk=1, all log_probs + will be identically 0.0). If False, the log_probs will not be affected, + and topk/topp/temperature will not affect sequence probabilities. 
+ state_callback_fn: Function that modifies the sampling loop state before + each step. This can be used to manipulate any part of the state either on + the accelerator or on the host using host callback. The function should + take a SamplingLoopState as argument, and it returns the updated state. + See `decoding_test.py` for an example usage. + logit_callback_fn: Function that modifies the logits before each temperature + sampling step. The function should take arguments (logits, state) and it + should return the modified logits. See `decoding_test.py` for an example + usage. + + Returns: + A tuple (decodes, log_prob) where `decodes` is sampled sequences with shape + [batch_size, num_decodes, max_decode_len] sorted by `log_prob`, which is log + probability of each of the sampled sequences. + """ + if decode_rng is None: + decode_rng = jax.random.PRNGKey(0) + + if (max_decode_steps_hard_limit is not None and + max_decode_steps_hard_limit > 0 and max_decode_steps is not None): + max_decode_steps = jnp.minimum(max_decode_steps, + max_decode_steps_hard_limit) + + initial_codes = _make_default_codes(inputs.shape[0], num_decodes, decode_rng) + flattened_codes = decoding.flatten_beam_dim(initial_codes) + + # [batch, len] -> [batch * num_decodes, len] + expanded_inputs = decoding.flat_batch_beam_expand(inputs, num_decodes) + expanded_cache = decoding.cache_map( + functools.partial( + decoding.flat_batch_beam_expand, + beam_size=num_decodes, + offset=cache_offset), + cache, + # When we start with a prefilled cache, the cache index is no longer a + # scalar that will broadcast across multiple decodes, it is a vector and + # needs to be updated to handle the multiple decodes. + apply_to_index=initial_index is not None) + if initial_index is not None: + initial_index = decoding.flat_batch_beam_expand(initial_index, num_decodes) + + # expanded_decodes: [batch * num_decodes, len] + # expanded_log_prob: [batch * num_decodes] + expanded_decodes, expanded_log_prob = _arithmetic_sample_single_trial( + expanded_inputs, + flattened_codes, + expanded_cache, + tokens_to_logits, + eos_id, + decode_rng, + temperature, + topk, + topp, + initial_index=initial_index, + max_decode_steps=max_decode_steps, + rescale_log_probs=rescale_log_probs, + state_callback_fn=state_callback_fn, + logit_callback_fn=logit_callback_fn) + + batch_size = inputs.shape[0] + # [batch * num_decodes, len] -> [batch, num_decodes, len] + decodes = decoding.unflatten_beam_dim(expanded_decodes, batch_size, + num_decodes) + # [batch * num_decodes] -> [batch, num_decodes] + log_prob = decoding.unflatten_beam_dim(expanded_log_prob, batch_size, + num_decodes) + + # Sort `decodes` and `log_prob` by increasing log probabilities of the sampled + # sequence. + # [batch, num_decodes, 1] + idxs = jnp.expand_dims(jnp.argsort(log_prob, axis=-1), axis=-1) + + # returns [batch, num_decodes, len], [batch, num_decodes] in sorted order. + return jnp.take_along_axis( + decodes, idxs, axis=1), jnp.take_along_axis( + log_prob, jnp.squeeze(idxs, axis=-1), axis=-1) + + +def _make_default_codes(batch_size: int, num_decodes: int, + rng: jnp.ndarray) -> jnp.ndarray: + """Make default codebook for a batch of `num_decodes` samples. + + The codes are initialized evenly spaced in the unit interval, with a random + offset applied. This lets them evenly cover the sample space while also + providing an unbiased estimate of any sample average. + + Args: + batch_size: size of input batch. + num_decodes: number of samples per batch element. + rng: random seed. 
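+
+  For example, with batch_size=1 and num_decodes=3, the evenly spaced codes
+  are [0.25, 0.5, 0.75]; a sampled offset of 0.1 shifts them to
+  [0.35, 0.6, 0.85] (everything is taken mod 1.0).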
+
+  Returns:
+    [batch_size, num_decodes] array of codes.
+  """
+  offset = jax.random.uniform(rng, (batch_size, 1))
+  codes = jnp.tile(
+      jnp.expand_dims(
+          jnp.arange(1, num_decodes + 1, dtype=jnp.float32) / (num_decodes + 1),
+          axis=0), (batch_size, 1))
+  return jnp.mod(codes + offset, 1.0)
+
+
+def _arithmetic_sample_single_trial(
+    inputs: jnp.ndarray,
+    initial_codes: jnp.ndarray,
+    cache: Mapping[str, jnp.ndarray],
+    tokens_to_logits: Callable[[decoding.DecodingState],
+                               Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]],
+    eos_id: int,
+    prng_key: jnp.ndarray,
+    temperature: Union[float, jnp.ndarray] = 1.0,
+    topk: int = 20,
+    topp: Union[float, jnp.ndarray] = 0.0,
+    initial_index: Optional[jnp.ndarray] = None,
+    max_decode_steps: Optional[Union[int, jnp.ndarray]] = None,
+    rescale_log_probs: bool = True,
+    state_callback_fn: Optional[Callable[[ArithmeticSamplingLoopState],
+                                         ArithmeticSamplingLoopState]] = None,
+    logit_callback_fn: Optional[Callable[
+        [jnp.ndarray, ArithmeticSamplingLoopState], jnp.ndarray]] = None
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+  """A helper function for `arithmetic_sample`."""
+
+  # We can check the values of topp and topk only if they are not dynamic.
+  if not _is_tracer(topp) and topp and topk:
+    raise ValueError('At most one of `topp` or `topk` may be non-zero.')
+
+  batch_size, max_decode_len = inputs.shape
+
+  if max_decode_steps is not None:
+    # We can check the max_decode_steps bounds only if it is not dynamic.
+    if not _is_tracer(max_decode_steps) and max_decode_steps > inputs.shape[1]:
+      raise ValueError('Cannot decode more steps than the sequence length.')
+
+    # The number of decode steps required to process the prefix is the number
+    # of non-zero tokens, since inputs[0] == 0 is the BOS token.
+    # `max_decode_len[j]` is the number of non-padding tokens in the jth
+    # element of the returned sequences capped at `len(inputs)`, assuming that
+    # the early stop doesn't occur. This is true with or without
+    # `max_decode_steps`.
+    # When the while loop index `i` for the `j`th element reaches
+    # `i[j] = max_decode_len[j] - 1`, the generated token populates
+    # `sequences[i[j] + 1]`. Since `sequences[:, 0]` is the BOS token, the
+    # generated token is the `max_decode_len[j]`th non-padding token and hence
+    # the `j`th element has ended.
+    max_decode_len = jnp.sum(inputs != 0, axis=1) + max_decode_steps
+    max_decode_len = jnp.minimum(inputs.shape[1], max_decode_len)
+
+  # In the case of starting generation from a non-zero index, it is possible
+  # for one batch element to reach `max_decode_len` number of decoding steps
+  # before another. In order to let the last element decode all the way to
+  # `max_decode_len` number of steps, we add a final garbage token to the end
+  # of the sequences. Any element that has reached `max_decode_len` before the
+  # rest of the elements will continually overwrite this token until all
+  # elements finish.
+  # [batch, length+1] -> [batch, length+2]
+  extra_input_tokens = 2
+  expanded_prompt_inputs = jnp.append(
+      inputs,
+      jnp.zeros((batch_size, extra_input_tokens), dtype=inputs.dtype),
+      axis=1)
+  end_marker = jnp.array(eos_id)
+
+  temperature = jnp.asarray(temperature)
+
+  # Initialize sampling loop state.
+  # Initial loop PRNGKey.
+  rng0 = prng_key
+
+  if initial_index is None:
+    # The per batch-item loop position counter.
+    i0 = jnp.zeros((batch_size), dtype=jnp.int32)
+    # The per batch-item holding current token in loop.
token0 = jnp.zeros((batch_size, 1), dtype=jnp.int32)
+  else:
+    # The per batch-item loop position counter.
+    i0 = initial_index
+    # The per batch-item holding current token in loop.
+    # Select the token that the initial index is pointing to.
+    token0 = jnp.take_along_axis(
+        expanded_prompt_inputs, jnp.expand_dims(i0, axis=1), axis=1)
+  # per batch-item state bit indicating if sentence has finished.
+  ended0 = jnp.zeros((batch_size, 1), dtype=jnp.bool_)
+  # (batch, length+2) array containing prefix prompt tokens for sampling loop
+  # as well as the generated output of newly sampled tokens.
+  sequences0 = expanded_prompt_inputs
+  log_prob0 = jnp.zeros((batch_size,), dtype=jnp.float32)
+
+  sampling_loop_init_state = ArithmeticSamplingLoopState(
+      i0, sequences0, cache, token0, ended0, rng0, log_prob0, initial_codes)
+  # Initial eos count to be used to determine whether eos is "generated". Many
+  # inputs follow the format bos, inputs..., eos, targets..., eos. By counting
+  # the number of eos tokens we can detect when a new one is added, instead of
+  # just finding the one that probably ends the inputs.
+  # [batch, 1]
+  initial_eos_count = jnp.sum(sequences0 == end_marker, axis=-1, keepdims=True)
+
+  def sampling_loop_cond_fn(state: ArithmeticSamplingLoopState) -> bool:
+    """Sampling loop termination condition."""
+    # Have all sampled sequences reached an end marker?
+    # Different elements in the batch can be at different loop indices; if any
+    # of our examples are not at the end, keep going.
+    all_sequences_ended = jnp.all(state.ended)
+    return ~all_sequences_ended # pytype: disable=bad-return-type # jnp-type
+
+  def sampling_loop_body_fn(
+      state: ArithmeticSamplingLoopState) -> ArithmeticSamplingLoopState:
+    """Sampling loop state update."""
+
+    if state_callback_fn is not None:
+      state = state_callback_fn(state)
+
+    # Split RNG for sampling.
+    rng1, rng2 = random.split(state.rng)
+    # Call fast-decoder model on current tokens to get next-position logits.
+    decoding_state = decoding.DecodingState(
+        cur_index=state.cur_index,
+        sequences=state.sequences[:, :-extra_input_tokens],
+        cur_token=state.cur_token,
+        cache=state.cache)
+    logits, new_cache = tokens_to_logits(decoding_state)
+    # Sample next token from logits.
+
+    if logit_callback_fn is not None:
+      logits = logit_callback_fn(logits, state)
+
+    def sample_logits_with_nonzero_temperature(logits):
+
+      # Before setting up the arithmetic sampling, we preprocess the logits
+      # into their final form.
+      scaled_logits = logits / jnp.maximum(temperature, MIN_TEMPERATURE)
+      if topk:
+        # Get top-k logits and their indices, sample within these top-k tokens.
+        topk_logits, _ = lax.top_k(scaled_logits, topk)
+        cutoff_logit = topk_logits[:, -1, None]
+        scaled_logits = jnp.where(scaled_logits < cutoff_logit,
+                                  jnp.full_like(scaled_logits, NEG_INF),
+                                  scaled_logits)
+
+      # When topp is dynamic, we always use it since we cannot check
+      # non-zeroness (but it will have no effect if topp is 0.0).
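+      # For example, for sorted probs [0.5, 0.3, 0.15, 0.05] and topp = 0.4,
+      # the cumulative sums are [0.5, 0.8, 0.95, 1.0]; no sum is below 0.4,
+      # so the cutoff is the highest-probability token and only it survives.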
if _is_tracer(topp) or topp:
+        logits_sorted = jnp.sort(
+            scaled_logits, axis=-1)[:, ::-1] # sort descending
+        sorted_cum_probs = jnp.cumsum(
+            jax.nn.softmax(logits_sorted, axis=-1), axis=-1)
+        cutoff_index = jnp.sum(sorted_cum_probs < topp, axis=-1, keepdims=True)
+        cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1)
+        scaled_logits = jnp.where(scaled_logits < cutoff_logit,
+                                  jnp.full_like(scaled_logits, NEG_INF),
+                                  scaled_logits)
+
+      next_token, next_code = _arithmetic_categorical(rng1, scaled_logits,
+                                                      state.codes)
+
+      # log probability of the current token conditioned on the previously
+      # sampled and prefix tokens.
+      # [batch, vocab] -> [batch, vocab]
+      if rescale_log_probs:
+        log_probs = jax.nn.log_softmax(scaled_logits)
+      else:
+        log_probs = jax.nn.log_softmax(logits)
+      # [batch, vocab] -> [batch]
+      next_log_prob = jnp.squeeze(
+          jnp.take_along_axis(
+              log_probs, jnp.expand_dims(next_token, axis=1), axis=-1),
+          axis=-1)
+
+      return (next_token, next_log_prob, next_code)
+
+    def sample_logits_with_zero_temperature(logits):
+      # For zero temperature, we always want the greedy output, regardless
+      # of the values of topk and topp.
+
+      next_token = jnp.argmax(logits, -1).astype(jnp.int32)
+
+      if rescale_log_probs:
+        next_log_prob = jnp.zeros_like(next_token, dtype=jnp.float32)
+      else:
+        log_probs = jax.nn.log_softmax(logits)
+        next_log_prob = jnp.squeeze(
+            jnp.take_along_axis(
+                log_probs, jnp.expand_dims(next_token, axis=1), axis=-1),
+            axis=-1)
+
+      return (next_token, next_log_prob, state.codes)
+
+    # Perform sampling with temperature
+    (next_token, next_log_prob,
+     next_code) = lax.cond(temperature > MIN_TEMPERATURE,
+                           sample_logits_with_nonzero_temperature,
+                           sample_logits_with_zero_temperature, logits)
+
+    # When different batch elements are at different points in the loop
+    # counter, it is possible that an element that started at a higher index
+    # will reach `max_decode_len` before other elements. When this happens we
+    # need to make sure this element continuously overwrites our new garbage
+    # collection index. Here we clamp `i` to `max_decode_len`. This will cause
+    # a write to `max_decode_len + 1`, which is the final index in
+    # `sequences`. Subsequent loop body executions will also get their value
+    # clamped, causing continual overwriting of the final garbage position
+    # until all examples are finished.
+    i = jnp.minimum(state.cur_index, max_decode_len)
+
+    # Only use sampled tokens if we're past provided prefix tokens.
+    # Select the next token from sequences.
+    # [batch]
+    next_input_token = jnp.squeeze(
+        jnp.take_along_axis(
+            state.sequences, jnp.expand_dims(i + 1, axis=1), axis=1),
+        axis=1)
+    # Check if the next token is padding (a target) or non-padding (an input).
+    # Mask will have `1` for targets and `0` for inputs.
+    out_of_prompt = (next_input_token == 0)
+    # Select the sampled next token for targets and the actual next token for
+    # inputs (teacher forcing).
+    # [batch]
+    next_token = (
+        next_token * out_of_prompt + next_input_token * ~out_of_prompt)
+
+    # Only add probability if outside the prefix region.
+    # [batch] -> [batch]
+    next_log_prob = state.log_prob + (
+        next_log_prob * out_of_prompt) * jnp.squeeze(
+            ~state.ended, axis=-1).astype(jnp.int32)
+
+    # [batch] -> [batch, 1]
+    next_token = jnp.expand_dims(next_token, axis=-1)
+
+    # If end-marker reached for batch item, only emit padding tokens.
+ # [batch, 1] * [batch, 1] -> [batch, 1] + next_token_or_endpad = next_token * ~state.ended + # Add current sampled tokens to recorded sequences. + one_hot = jax.nn.one_hot( + i + 1, state.sequences.shape[1], dtype=state.sequences.dtype) + new_sequences = state.sequences * (1 - + one_hot) + next_token_or_endpad * one_hot + # new_sequences = dynamic_update_vector_slice_in_dim(sequences, + # next_token_or_endpad, + # i + 1, + # 0) + # Count eos tokens in the sequences and compare to the initial count + # [batch, 1] + cur_eos_count = jnp.sum(new_sequences == end_marker, axis=-1, keepdims=True) + # [batch, 1] + + # Have we reached max decoding length? + # We generally index into sequences[:, i + 1], and sequences.shape[1] = + # max_decode_len + 2, therefore i == max_decode_len - 1 will write to + # sequences[-2] which is our last valid location. i == max_decode_len will + # write to sequences[-1] which is our garbage collection token. Thus `i` + # should be strictly less than max_decode_len. + has_additional_eos = cur_eos_count > initial_eos_count + ended = state.ended | has_additional_eos | jnp.expand_dims( + i >= max_decode_len - 1, axis=1) + + return ArithmeticSamplingLoopState(i + 1, new_sequences, new_cache, + next_token_or_endpad, ended, rng2, + next_log_prob, next_code) + + # Run sampling loop and collect final state. + final_state = lax.while_loop(sampling_loop_cond_fn, sampling_loop_body_fn, + sampling_loop_init_state) + + # Pick part of the state corresponding to the sampled sequences. + final_sequences = final_state.sequences + log_prob = final_state.log_prob + # Drop the first position because they are dummy bos tokens. Drop the new + # garbage collection token at the end too. + return final_sequences[:, 1:-1], log_prob # pytype: disable=bad-return-type # jax-ndarray diff --git a/protoscribe/pmmx/arithmetic_sampling_test.py b/protoscribe/pmmx/arithmetic_sampling_test.py new file mode 100644 index 0000000..90da8b8 --- /dev/null +++ b/protoscribe/pmmx/arithmetic_sampling_test.py @@ -0,0 +1,791 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for decoding module. + +Cloned from https://github.com/google-research/google-research/tree/master/arithmetic_sampling/. # pylint: disable=line-too-long +""" + +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax.experimental import io_callback +import jax.numpy as jnp +import numpy as np +from protoscribe.pmmx import arithmetic_sampling as sampling +from t5x import decoding + +EOS_ID = 1 +NEG_INF = decoding.NEG_INF + + +class DecodingTest(parameterized.TestCase): + + def test_arithmetic_sample_uneven_prefix(self): + def token_to_logits(decoding_state: decoding.DecodingState): + del decoding_state + # Always sample id 2 for batch element 0 and id 3 for element 1. 
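+      # A logit of -1e7 acts as an effective negative infinity here, so
+      # softmax places essentially all probability mass on the zero logit.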
+ logits = np.array( + [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32 + ) + return logits, {} + + inputs = np.array([[0, 5, 7, 1, 0, 0], [0, 6, 1, 0, 0, 0]]) + rng = jax.random.PRNGKey(0) + codes = decoding.flatten_beam_dim(sampling._make_default_codes(2, 1, rng)) + sampled_sequences, _ = sampling._arithmetic_sample_single_trial( + inputs, + codes, + {}, + token_to_logits, + EOS_ID, + rng, + topk=0, + initial_index=np.array([3, 2]), + ) + expected = np.array([[5, 7, 1, 2, 2, 2], [6, 1, 3, 3, 3, 3]]) + np.testing.assert_array_equal(expected, sampled_sequences) + + def test_arithmetic_sample_no_prefix(self): + batch, max_decode_len = 2, 3 + + def token_to_logits(decoding_state: decoding.DecodingState): + del decoding_state + # Always sample id 2 for batch element 0 and id 3 for element 1. + logits = np.array( + [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32 + ) + return logits, {} + + inputs = np.zeros((batch, max_decode_len), dtype=np.int32) + rng = jax.random.PRNGKey(0) + codes = decoding.flatten_beam_dim(sampling._make_default_codes(2, 1, rng)) + sampled_sequences, _ = sampling._arithmetic_sample_single_trial( + inputs, codes, {}, token_to_logits, EOS_ID, rng, topk=0 + ) + + expected = [[2, 2, 2], [3, 3, 3]] + np.testing.assert_array_equal(expected, sampled_sequences) + + def test_arithmetic_sample_prefix(self): + def token_to_logits(decoding_state: decoding.DecodingState): + del decoding_state + # Always sample id 2 for batch element 0 and id 3 for element 1. + logits = np.array( + [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32 + ) + return logits, {} + + # Batch element 0 has length 3 prefix and element 1 has length 2. + inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32) + rng = jax.random.PRNGKey(0) + codes = decoding.flatten_beam_dim(sampling._make_default_codes(2, 1, rng)) + sampled_sequences, _ = sampling._arithmetic_sample_single_trial( + inputs, codes, {}, token_to_logits, EOS_ID, rng, topk=0 + ) + + expected = [[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]] + np.testing.assert_array_equal(expected, sampled_sequences) + + def test_arithmetic_sample_with_zero_temperature(self): + batch, max_decode_len = 2, 3 + + def token_to_logits(decoding_state: decoding.DecodingState): + del decoding_state + # Use very large logits that are close to one another. + logits = np.array( + [[1700.47, 1700.48, 1700.51, 1700.45], [3.2, 4.8, -5.3, 5.6]], + dtype=np.float32, + ) + return logits, {} + + inputs = np.zeros((batch, max_decode_len), dtype=np.int32) + rng = jax.random.PRNGKey(0) + codes = decoding.flatten_beam_dim(sampling._make_default_codes(2, 1, rng)) + sampled_sequences, _ = sampling._arithmetic_sample_single_trial( + inputs, codes, {}, token_to_logits, EOS_ID, rng, topk=4, temperature=0.0 + ) + + expected = [[2, 2, 2], [3, 3, 3]] + np.testing.assert_array_equal(expected, sampled_sequences) + + def test_arithmetic_sample_prefix_ending_with_eos(self): + def token_to_logits(decoding_state: decoding.DecodingState): + del decoding_state + # Always sample id 2 for batch element 0 and id 3 for element 1. + logits = np.array( + [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32 + ) + return logits, {} + + # Batch element 0 has length 4 prefix (including the initial dummy token and + # the last eos) and element 1 has length 3. 
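+    # The eos inside the prefix does not stop decoding: only an eos generated
+    # beyond the initial eos count ends a sequence.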
inputs = np.array([[0, 5, 6, 1, 0], [0, 8, 1, 0, 0]], dtype=np.int32)
+    rng = jax.random.PRNGKey(0)
+    codes = decoding.flatten_beam_dim(sampling._make_default_codes(2, 1, rng))
+    sampled_sequences, _ = sampling._arithmetic_sample_single_trial(
+        inputs, codes, {}, token_to_logits, EOS_ID, rng, topk=1
+    )
+
+    expected = [[5, 6, 1, 2, 2], [8, 1, 3, 3, 3]]
+    np.testing.assert_array_equal(expected, sampled_sequences)
+
+  def test_arithmetic_sample_with_state_callback(self):
+    def token_to_logits(decoding_state: decoding.DecodingState):
+      del decoding_state
+      # A distribution with roughly all probability mass in sample id 3.
+      logits = np.array(
+          [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32
+      )
+      return logits, {}
+
+    def state_callback_fn(state):
+      def callback_fn(current_index_and_sequences):
+        """Adds an EOS token after token id 3 is first sampled."""
+        current_index, sequences = current_index_and_sequences
+        sequences = np.array(sequences)
+        for i in range(len(current_index)):
+          if sequences[i, current_index[i]] == 3:
+            sequences[i, current_index[i] + 1] = EOS_ID
+        return sequences
+
+      sequences = io_callback(
+          callback_fn,
+          jax.ShapeDtypeStruct(state.sequences.shape, state.sequences.dtype),
+          (state.cur_index, state.sequences),
+      )
+      return state.replace(sequences=sequences)
+
+    inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
+    rng = jax.random.PRNGKey(0)
+    codes = decoding.flatten_beam_dim(sampling._make_default_codes(2, 1, rng))
+    sampled_sequences, _ = sampling._arithmetic_sample_single_trial(
+        inputs,
+        codes,
+        {},
+        token_to_logits,
+        EOS_ID,
+        rng,
+        topk=0,
+        temperature=0.0,
+        state_callback_fn=state_callback_fn,
+    )
+
+    expected = [[5, 6, 7, 3, EOS_ID], [8, 9, 3, EOS_ID, 0]]
+    np.testing.assert_array_equal(expected, sampled_sequences)
+
+  def test_arithmetic_sample_with_logit_callback(self):
+    def token_to_logits(decoding_state: decoding.DecodingState):
+      del decoding_state
+      # Uniform distribution over targets from model.
+      logits = np.array(
+          [[-1e7, -1e7, -1e7, -1e7], [-1e7, -1e7, -1e7, -1e7]], dtype=np.float32
+      )
+      return logits, {}
+
+    def logit_callback_fn(logits, state):
+      del state  # Unused.
+      # Rewrite logits to always sample id 2 for batch element 0 and
+      # id 3 for element 1.
+      logits[0, 2] = 0
+      logits[1, 3] = 0
+      return logits
+
+    # Batch element 0 has length 3 prefix and element 1 has length 2.
+    inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
+    rng = jax.random.PRNGKey(0)
+    codes = decoding.flatten_beam_dim(sampling._make_default_codes(2, 1, rng))
+    sampled_sequences, _ = sampling._arithmetic_sample_single_trial(
+        inputs,
+        codes,
+        {},
+        token_to_logits,
+        EOS_ID,
+        rng,
+        topk=0,
+        temperature=0.0,
+        logit_callback_fn=logit_callback_fn,
+    )
+
+    expected = [[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]]
+    np.testing.assert_array_equal(expected, sampled_sequences)
+
+  def test_arithmetic_sample_prefix_ending_with_eos_early_stop(self):
+    batch, max_decode_len = 2, 7
+    rng0 = jax.random.PRNGKey(0)
+
+    ret = [np.array([2, 3]) for _ in range(max_decode_len)]
+    # Sequence 1 outputs EOS=1 when i = 3 where `i` is the while loop counter
+    # of `sampling._arithmetic_sample_single_trial`.
+    ret[3] = np.array([2, 1])
+    # Sequence 0 outputs EOS=1 when i = 4.
ret[4] = np.array([1, 3])
+    ret = jax.numpy.array(ret)
+
+    def mocked_categorical(rng_input, logits, codes): # pylint: disable=unused-argument
+      """Ignores logits and codes and returns only based on the rng_input."""
+      rng = rng0
+      k = 0
+      # Mimic the rng split done in `sampling_loop_body_fn`.
+      for j in range(max_decode_len):
+        rng1, rng = jax.random.split(rng)
+        # We want to sift out `j` for which rng1 == rng_input.
+        # rngs are a pair of ints, so sum the bool and divide by 2.
+        k += j * (rng1 == rng_input).sum() // 2
+      # `k` at this point is equal to the while loop variable `i` of the
+      # caller.
+      return ret[k], codes
+
+    def token_to_logits(decoding_state: decoding.DecodingState):
+      del decoding_state
+      # These values are not used in this test because random.categorical is
+      # directly mocked.
+      dummy_logits = np.zeros((batch, 4), dtype=np.float32)
+      return dummy_logits, {}
+
+    inputs = np.array(
+        [[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]], dtype=np.int32
+    )
+    codes = decoding.flatten_beam_dim(sampling._make_default_codes(2, 1, rng0))
+    with mock.patch.object(
+        sampling, '_arithmetic_categorical', new=mocked_categorical
+    ):
+      sampled_sequences, _ = sampling._arithmetic_sample_single_trial(
+          inputs, codes, {}, token_to_logits, EOS_ID, rng0, topk=0
+      )
+
+    expected = [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]]
+    np.testing.assert_array_equal(expected, sampled_sequences)
+
+  def test_greedy_decoding_topk_sample_log_probs(self):
+    def token_to_logits(decoding_state: decoding.DecodingState):
+      del decoding_state
+      # Sample [2, 3] with probability [0.6, 0.4].
+      logits = np.array(
+          [[-1e7, -1e7, -0.510825624, -0.916290732]], dtype=np.float32
+      )
+      return logits, {}
+
+    inputs = np.array([[0, 2, 2, 2, 0]], dtype=np.int32)
+    rng = jax.random.PRNGKey(0)
+    input_codes = decoding.flatten_beam_dim(
+        sampling._make_default_codes(1, 1, rng)
+    )
+    sampled_sequences, sampled_log_probs = (
+        sampling._arithmetic_sample_single_trial(
+            inputs,
+            input_codes,
+            {},
+            token_to_logits,
+            EOS_ID,
+            rng,
+            topk=1,
+            rescale_log_probs=True,
+        )
+    )
+
+    expected_sequence = [[2, 2, 2, 2, 2]]
+    expected_log_probs = [0.0]
+    np.testing.assert_array_equal(expected_sequence, sampled_sequences)
+    np.testing.assert_array_almost_equal(expected_log_probs, sampled_log_probs)
+
+    inputs = np.array([[0, 2, 2, 3, 0]], dtype=np.int32)
+    rng = jax.random.PRNGKey(0)
+    input_codes = decoding.flatten_beam_dim(
+        sampling._make_default_codes(1, 1, rng)
+    )
+    sampled_sequences, sampled_log_probs = (
+        sampling._arithmetic_sample_single_trial(
+            inputs,
+            input_codes,
+            {},
+            token_to_logits,
+            EOS_ID,
+            rng,
+            topk=1,
+            rescale_log_probs=False,
+        )
+    )
+
+    expected_sequence = [[2, 2, 3, 2, 2]]
+    expected_log_probs = [-1.02165125]
+    np.testing.assert_array_equal(expected_sequence, sampled_sequences)
+    np.testing.assert_array_almost_equal(expected_log_probs, sampled_log_probs)
+
+  def test_arithmetic_sample_log_prob(self):
+    batch, max_decode_len = 2, 7
+    rng0 = jax.random.PRNGKey(0)
+
+    ret = [np.array([2, 3]) for _ in range(max_decode_len)]
+    # Sequence 1 outputs EOS=1 when i = 3 where `i` is the while loop counter
+    # of `sampling._arithmetic_sample_single_trial`.
+    ret[3] = np.array([2, 1])
+    # Sequence 0 outputs EOS=1 when i = 4.
ret[4] = np.array([1, 3])
+    ret = jax.numpy.array(ret)
+
+    def mocked_categorical(rng_input, logits, codes): # pylint: disable=unused-argument
+      """Ignores logits and codes and returns only based on the rng_input."""
+      rng = rng0
+      k = 0
+      # Mimic the rng split done in `sampling_loop_body_fn`.
+      for j in range(max_decode_len):
+        rng1, rng = jax.random.split(rng)
+        # We want to sift out `j` for which rng1 == rng_input.
+        # rngs are a pair of ints, so sum the bool and divide by 2.
+        k += j * (rng1 == rng_input).sum() // 2
+      # `k` at this point is equal to the while loop variable `i` of the
+      # caller.
+      return ret[k], codes
+
+    logits = np.random.randn(batch, 4)
+    token_to_logits = lambda decoding_state: (logits, {})
+    inputs = np.array(
+        [[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]], dtype=np.int32
+    )
+    codes = decoding.flatten_beam_dim(sampling._make_default_codes(2, 1, rng0))
+    with mock.patch.object(
+        sampling, '_arithmetic_categorical', new=mocked_categorical
+    ):
+      sampled_sequences, log_prob = sampling._arithmetic_sample_single_trial(
+          inputs, codes, {}, token_to_logits, EOS_ID, rng0, topk=0
+      )
+
+    log_probs = jax.nn.log_softmax(logits)
+    expected = [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]]
+    expected_log_prob = [
+        log_probs[0, 2] + log_probs[0, 2] + log_probs[0, 1],
+        log_probs[1, 3] + log_probs[1, 3] + log_probs[1, 1],
+    ]
+    expected_log_prob = np.array(expected_log_prob)
+    np.testing.assert_array_equal(expected, sampled_sequences)
+    np.testing.assert_allclose(expected_log_prob, log_prob, atol=1e-5)
+
+  def test_arithmetic_sample_num_decodes(self):
+    num_decodes = 3
+    rng0 = jax.random.PRNGKey(0)
+    inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32)
+
+    with mock.patch.object(
+        sampling, '_arithmetic_sample_single_trial'
+    ) as mocked:
+      # Expanded_decodes: [batch * num_decodes, max_decode_len]
+      expanded_decodes = np.array([
+          [5, 1, 4, 4],
+          [5, 1, 5, 5],
+          [5, 1, 3, 3],
+          [8, 7, 5, 5],
+          [8, 7, 3, 3],
+          [8, 7, 4, 4],
+      ])
+      # Expanded_log_prob: [batch * num_decodes]
+      expanded_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9])
+      mocked.return_value = expanded_decodes, expanded_log_prob
+
+      decodes, scores = sampling.arithmetic_sample(
+          inputs, {}, mock.Mock(), EOS_ID, rng0, num_decodes=num_decodes
+      )
+
+      expanded_inputs = jnp.array([
+          [0, 5, 1, 0],
+          [0, 5, 1, 0],
+          [0, 5, 1, 0],
+          [0, 8, 7, 0],
+          [0, 8, 7, 0],
+          [0, 8, 7, 0],
+      ])
+      # Test that the actual decode function is called with the expanded values.
+      np.testing.assert_array_equal(mocked.call_args[0][0], expanded_inputs)
+
+    np.testing.assert_array_equal(
+        decodes,
+        [
+            [[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]],
+            [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]],
+        ],
+    )
+    np.testing.assert_allclose(scores, [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]])
+
+  def test_arithmetic_sample_num_decodes_with_initial_index(self):
+    num_decodes = 3
+    rng0 = jax.random.PRNGKey(0)
+    inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32)
+    initial_index = np.array([1, 2], dtype=np.int32)
+
+    with mock.patch.object(
+        sampling, '_arithmetic_sample_single_trial'
+    ) as mocked:
+      with mock.patch.object(decoding, 'cache_map') as mocked_cache_map:
+        # `expanded_decodes`: [batch * num_decodes, max_decode_len]
+        expanded_decodes = np.array([
+            [5, 1, 4, 4],
+            [5, 1, 5, 5],
+            [5, 1, 3, 3],
+            [8, 7, 5, 5],
+            [8, 7, 3, 3],
+            [8, 7, 4, 4],
+        ])
+        # `expanded_log_prob`: [batch * num_decodes]
+        expanded_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9])
+        mocked.return_value = expanded_decodes, expanded_log_prob
+
+        decodes, scores = sampling.arithmetic_sample(
+            inputs,
+            {},
+            mock.Mock(),
+            EOS_ID,
+            rng0,
+            num_decodes=num_decodes,
+            initial_index=initial_index,
+        )
+
+        expanded_inputs = jnp.array([
+            [0, 5, 1, 0],
+            [0, 5, 1, 0],
+            [0, 5, 1, 0],
+            [0, 8, 7, 0],
+            [0, 8, 7, 0],
+            [0, 8, 7, 0],
+        ])
+        expanded_initial_index = np.array([1, 1, 1, 2, 2, 2], dtype=np.int32)
+        # Test that the actual decode function is called with the expanded
+        # values.
+        np.testing.assert_array_equal(mocked.call_args[0][0], expanded_inputs)
+        np.testing.assert_array_equal(
+            mocked.call_args[1]['initial_index'], expanded_initial_index
+        )
+        # Test that the function was applied to the index in the cache map.
+        self.assertTrue(mocked_cache_map.call_args[1]['apply_to_index'])
+
+    np.testing.assert_array_equal(
+        decodes,
+        [
+            [[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]],
+            [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]],
+        ],
+    )
+    np.testing.assert_allclose(scores, [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]])
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='no_initial_index',
+          initial_index=None,
+          expected_calls=6,
+      ),
+      dict(
+          testcase_name='initial_index',
+          initial_index=np.array([1, 2], dtype=np.int32),
+          expected_calls=4,
+      ),
+      dict(
+          testcase_name='lower_initial_index',
+          initial_index=np.array([1, 1], dtype=np.int32),
+          # 4 decode steps outside the prompt, plus one call to step through
+          # the remaining prompt token.
+          expected_calls=5,
+      ),
+  )
+  def test_arithmetic_sample_max_decode_steps_with_initial_index(
+      self, initial_index, expected_calls
+  ):
+    max_decode_steps = 4
+    rng0 = jax.random.PRNGKey(0)
+    inputs = np.array(
+        [[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]], dtype=np.int32
+    )
+
+    token_to_logits = mock.Mock()
+    token_to_logits.return_value = (
+        np.array(
+            [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32
+        ),
+        {},
+    )
+
+    # Unroll while loop.
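+    # Under `jax.disable_jit()` the sampler's while loop runs as a plain
+    # Python loop, so the number of `token_to_logits` invocations can be
+    # counted on the mock.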
+    with jax.disable_jit():
+      decodes, scores = sampling.arithmetic_sample(
+          inputs,
+          {},
+          token_to_logits,
+          EOS_ID,
+          rng0,
+          initial_index=initial_index,
+          topk=4,
+          max_decode_steps=max_decode_steps,
+      )
+
+    self.assertLen(token_to_logits.call_args_list, expected_calls)
+
+    expected_output = np.array(
+        [[2, 3, 3, 3, 3, 0, 0, 0], [2, 2, 3, 3, 3, 3, 0, 0]]
+    )
+    expected_output = jnp.expand_dims(expected_output, 1)
+
+    np.testing.assert_array_equal(decodes, expected_output)
+    np.testing.assert_array_equal(scores, [[0.0], [0.0]])
+
+  def test_arithmetic_sample_max_decode_steps_endpad(self):
+    max_decode_steps = 4
+    rng0 = jax.random.PRNGKey(0)
+    inputs = np.array(
+        [
+            [0, 2, 0, 0, 0, 0, 0, 0],
+            [0, 2, 2, 2, 2, 2, 2, 0],
+            [0, 2, 2, 2, 0, 0, 0, 0],
+        ],
+        dtype=np.int32,
+    )
+    initial_index = np.array([1, 6, 0])
+
+    token_to_logits = mock.Mock()
+    token_to_logits.return_value = (
+        np.array(
+            [
+                [-1e7, -1e7, -1e7, 0],
+                [-1e7, -1e7, -1e7, 0],
+                [-1e7, -1e7, -1e7, 0],
+            ],
+            dtype=np.float32,
+        ),
+        {},
+    )
+
+    # Unroll while loop.
+    with jax.disable_jit():
+      decodes, scores = sampling.arithmetic_sample(
+          inputs,
+          {},
+          token_to_logits,
+          EOS_ID,
+          rng0,
+          initial_index=initial_index,
+          topk=4,
+          max_decode_steps=max_decode_steps,
+      )
+
+    # `inputs[2]` starts from index 0, so it takes 3 calls to `token_to_logits`
+    # to exit the prompt (those generated tokens are overridden by the prompt)
+    # and 4 more calls to fill the rest. `inputs[0]` only needs 4 calls. During
+    # the last 3 calls it still generates tokens, but these MUST NOT be written
+    # into its sequence, which has already ended.
+    self.assertLen(token_to_logits.call_args_list, 7)
+    expected_output = np.array(
+        [
+            [2, 3, 3, 3, 3, 0, 0, 0],
+            [2, 2, 2, 2, 2, 2, 3, 3],
+            [2, 2, 2, 3, 3, 3, 3, 0],
+        ],
+        dtype=np.int32,
+    )
+    expected_output = jnp.expand_dims(expected_output, 1)
+
+    np.testing.assert_array_equal(decodes, expected_output)
+    np.testing.assert_allclose(scores, [[0.0], [0.0], [0.0]])
+
+  def test_arithmetic_sample_max_decode_steps_docstring_ex4(self):
+    max_decode_steps = 2
+    rng0 = jax.random.PRNGKey(0)
+    inputs = np.array(
+        [[0, 2, 0, 0, 0, 0, 0, 0], [0, 3, 4, 0, 0, 0, 0, 0]], dtype=np.int32
+    )
+    initial_index = np.array([1, 2])
+
+    token_to_logits = mock.Mock()
+    token_to_logits.return_value = (
+        np.array(
+            [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32
+        ),
+        {},
+    )
+
+    # Unroll while loop.
+    with jax.disable_jit():
+      decodes, _ = sampling.arithmetic_sample(
+          inputs,
+          {},
+          token_to_logits,
+          EOS_ID,
+          rng0,
+          initial_index=initial_index,
+          topk=4,
+          max_decode_steps=max_decode_steps,
+      )
+    self.assertLen(token_to_logits.call_args_list, 2)
+    expected_output = np.array(
+        [[2, 2, 2, 0, 0, 0, 0, 0], [3, 4, 3, 3, 0, 0, 0, 0]], dtype=np.int32
+    )
+    expected_output = jnp.expand_dims(expected_output, 1)
+
+    np.testing.assert_array_equal(decodes, expected_output)
+
+  def test_arithmetic_sample_max_decode_steps_hard_limit(self):
+    max_decode_steps = 10
+    max_decode_steps_hard_limit = 4
+    rng0 = jax.random.PRNGKey(0)
+    inputs = np.array(
+        [[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]], dtype=np.int32
+    )
+
+    token_to_logits = mock.Mock()
+    token_to_logits.return_value = (
+        np.array(
+            [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32
+        ),
+        {},
+    )
+
+    # Unroll while loop.
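+    # The hard limit (4) should clamp `max_decode_steps` (10), so the expected
+    # output below matches the plain `max_decode_steps=4` case above.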
+    with jax.disable_jit():
+      decodes, scores = sampling.arithmetic_sample(
+          inputs,
+          {},
+          token_to_logits,
+          EOS_ID,
+          rng0,
+          topk=4,
+          max_decode_steps=max_decode_steps,
+          max_decode_steps_hard_limit=max_decode_steps_hard_limit,
+      )
+
+    expected_output = np.array(
+        [[2, 3, 3, 3, 3, 0, 0, 0], [2, 2, 3, 3, 3, 3, 0, 0]]
+    )
+    expected_output = jnp.expand_dims(expected_output, 1)
+
+    np.testing.assert_array_equal(decodes, expected_output)
+    np.testing.assert_array_equal(scores, [[0.0], [0.0]])
+
+  def test_arithmetic_sample_topp(self):
+    with jax.disable_jit():
+      rng0 = jax.random.PRNGKey(0)
+      inputs = np.zeros((1, 20), dtype=np.int32)
+
+      token_to_logits = mock.Mock()
+
+      # Logits correspond to the probabilities (0.3, 0, 0.1, 0.6).
+      token_to_logits.return_value = (
+          np.array([[-1.2, -1e7, -2.3, -0.51]], dtype=np.float32),
+          {},
+      )
+
+      # Any topp value under 0.6 keeps only token 3 in the nucleus, which
+      # makes decoding deterministic.
+      decodes, scores = sampling.arithmetic_sample(
+          inputs, {}, token_to_logits, EOS_ID, rng0, topp=0.55, topk=0
+      )
+
+      expected_output = np.array([[3] * 20])
+      expected_output = jnp.expand_dims(expected_output, 1)
+
+      np.testing.assert_array_equal(decodes, expected_output)
+      np.testing.assert_array_equal(scores, [[0.0]])
+
+      # Temperature is applied first, turning the distribution into
+      # (0.27, 0, 0.069, 0.65); with topp=0.63 decoding therefore becomes
+      # greedy again.
+      decodes, scores = sampling.arithmetic_sample(
+          inputs,
+          {},
+          token_to_logits,
+          EOS_ID,
+          rng0,
+          temperature=0.8,
+          topp=0.63,
+          topk=0,
+      )
+
+      expected_output = np.array([[3] * 20])
+      expected_output = jnp.expand_dims(expected_output, 1)
+
+      np.testing.assert_array_equal(decodes, expected_output)
+      np.testing.assert_array_equal(scores, [[0.0]])
+
+  def test_dynamic_topp_max_decode_steps(self):
+    rng0 = jax.random.PRNGKey(0)
+    inputs = np.zeros((1, 20), dtype=np.int32)
+
+    token_to_logits = mock.Mock()
+
+    # Logits correspond to the probabilities (0.3, 0, 0.1, 0.6).
+    token_to_logits.return_value = (
+        np.array([[-1.2, -1e7, -2.3, -0.51]], dtype=np.float32),
+        {},
+    )
+
+    def dynamic_decode_fn(inputs, temperature, topp, max_decode_steps):
+      return sampling.arithmetic_sample(
+          inputs,
+          {},
+          token_to_logits,
+          EOS_ID,
+          rng0,
+          temperature=temperature,
+          topp=topp,
+          topk=0,
+          max_decode_steps=max_decode_steps,
+      )
+
+    dynamic_decode_fn_jit = jax.jit(dynamic_decode_fn)
+
+    decodes, scores = dynamic_decode_fn_jit(inputs, 0.8, 0.63, 10)
+
+    expected_output = np.array([[3] * 10 + [0] * 10])
+    expected_output = jnp.expand_dims(expected_output, 1)
+
+    np.testing.assert_array_equal(decodes, expected_output)
+    np.testing.assert_array_equal(scores, [[0.0]])
+
+  def test_topp_log_probs(self):
+    rng0 = jax.random.PRNGKey(0)
+    inputs = np.zeros((1, 1), dtype=np.int32)
+
+    token_to_logits = mock.Mock()
+
+    # Logits correspond to the probabilities (0.3, 0, 0.1, 0.6).
+    token_to_logits.return_value = (
+        np.array([[-1.2, NEG_INF, -2.3, -0.51]], dtype=np.float32),
+        {},
+    )
+
+    with jax.disable_jit():
+      # This lets us see the logits after topp and topk are applied.
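+      # `_arithmetic_categorical` is mocked below so we can capture the logits
+      # it receives: temperature (1.4) is applied before top-p, and tokens
+      # outside the top-p nucleus are masked to NEG_INF.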
+
+      with mock.patch.object(sampling, '_arithmetic_categorical') as mocked:
+        mocked.return_value = jnp.array([0], dtype=jnp.int32), jnp.array(
+            [0], dtype=jnp.int32
+        )
+        decodes, _ = sampling.arithmetic_sample(
+            inputs,
+            {},
+            token_to_logits,
+            EOS_ID,
+            rng0,
+            temperature=1.4,
+            topp=0.7,
+            topk=0,
+        )
+
+      self.assertLen(token_to_logits.call_args_list, 1)
+      np.testing.assert_array_equal(decodes, jnp.asarray([[[0]]]))
+
+      np.testing.assert_array_almost_equal(
+          mocked.call_args_list[0][0][1],
+          jnp.asarray([[-0.85714293, NEG_INF, NEG_INF, -0.36428571]]),
+      )
+
+  def test_sequential_cumsum(self):
+    test_arr = jnp.arange(0, 1000)
+    test_cumsum = sampling._sequential_cumsum(test_arr, 0)
+    target_cumsum = np.cumsum(jnp.asarray(test_arr))
+    np.testing.assert_array_equal(jnp.asarray(test_cumsum), target_cumsum)
+
+
+if __name__ == '__main__':
+  absltest.main()
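
A note on the `codes` threaded through the tests above: arithmetic sampling
assigns each decode a code in [0, 1) (the tests build defaults with
`_make_default_codes`), and at every step the token whose slice of the
cumulative distribution contains the code is selected. The standalone NumPy
sketch below illustrates that selection step under assumptions from the
Vilnis et al. paper; the helper name and the code-renormalization detail are
illustrative, not the patched API.

    import numpy as np

    def arithmetic_categorical_sketch(probs, code):
      """Picks the token whose CDF bucket contains `code` in [0, 1)."""
      cdf = np.cumsum(probs)
      # `side='right'` skips zero-width buckets of zero-probability tokens.
      token = int(np.searchsorted(cdf, code, side='right'))
      lower = cdf[token - 1] if token > 0 else 0.0
      # Rescale the code into the chosen bucket so that successive steps of
      # the same decode remain coordinated, as in arithmetic coding.
      new_code = (code - lower) / probs[token]
      return token, new_code

    # With the probabilities (0.3, 0, 0.1, 0.6) used by the top-p tests above,
    # evenly spaced codes yield diverse but probability-weighted choices:
    # codes 0.625 and 0.875 both fall in token 3's bucket [0.4, 1.0).
    probs = np.array([0.3, 0.0, 0.1, 0.6])
    for code in (0.125, 0.375, 0.625, 0.875):
      print(code, arithmetic_categorical_sketch(probs, code))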