diff --git a/protoscribe/models/pmmx/model_config_gin_test.py b/protoscribe/models/pmmx/model_config_gin_test.py
new file mode 100644
index 0000000..9dde4db
--- /dev/null
+++ b/protoscribe/models/pmmx/model_config_gin_test.py
@@ -0,0 +1,101 @@
+# Copyright 2024 The Protoscribe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Sanity check for miscellaneous model configurations in gin."""
+
+import os
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import gin
+
+# Core PMMX configurations. These are copied from
+# protoscribe/pmmx/config/runs
+# and modified to work in the test environment.
+_CORE_CONFIG_BASE_DIR = (
+    "protoscribe/models/pmmx/testdata"
+)
+
+# Configurations for individual models.
+_MODEL_CONFIG_BASE_DIR = (
+    "protoscribe/models/pmmx/configs"
+)
+
+
+def _config_path(
+    filename: str,
+    config_dir: str = _MODEL_CONFIG_BASE_DIR
+) -> str:
+  """Returns the full path of the specified file name."""
+  return os.path.join(
+      absltest.get_default_test_srcdir(), config_dir, filename
+  )
+
+
+class ModelConfigGinTest(parameterized.TestCase):
+
+  def tearDown(self):
+    super().tearDown()
+    gin.clear_config()
+
+  @parameterized.parameters(
+      "glyph_concepts",
+      "glyph_phonemes",
+  )
+  def test_model_train(self, model_dir: str) -> None:
+    """Tests tiny model configuration for training."""
+    tmp_dir = absltest.get_default_test_tmpdir()
+    gin.parse_config_files_and_bindings(
+        config_files=[
+            _config_path("pretrain_test.gin", config_dir=_CORE_CONFIG_BASE_DIR),
+            _config_path(os.path.join(model_dir, "model_tiny.gin")),
+            _config_path(os.path.join(model_dir, "dataset.gin")),
+        ],
+        bindings=[
+            f"DATA_DIR=\"{tmp_dir}\"",
+            f"MODEL_DIR=\"{tmp_dir}\"",
+            "TRAIN_STEPS=1",
+            "BATCH_SIZE=8",
+            "EVAL_BATCH_SIZE=8",
+        ],
+        finalize_config=True,
+        skip_unknown=False
+    )
+
+  @parameterized.parameters(
+      "glyph_concepts",
+      "glyph_phonemes",
+  )
+  def test_model_infer(self, model_dir: str) -> None:
+    """Tests tiny model configuration in inference mode."""
+    tmp_dir = absltest.get_default_test_tmpdir()
+    gin.parse_config_files_and_bindings(
+        config_files=[
+            _config_path("infer_test.gin", config_dir=_CORE_CONFIG_BASE_DIR),
+            _config_path(os.path.join(model_dir, "model_tiny.gin")),
+            _config_path(os.path.join(model_dir, "dataset.gin")),
+        ],
+        bindings=[
+            f"DATA_DIR=\"{tmp_dir}\"",
+            f"CHECKPOINT_PATH=\"{tmp_dir}\"",
+            f"INFER_OUTPUT_DIR=\"{tmp_dir}\"",
+            "BATCH_SIZE=8",
+        ],
+        finalize_config=True,
+        skip_unknown=False
+    )
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/protoscribe/models/pmmx/testdata/README.md b/protoscribe/models/pmmx/testdata/README.md
new file mode 100644
index 0000000..20dc110
--- /dev/null
+++ b/protoscribe/models/pmmx/testdata/README.md
@@ -0,0 +1,4 @@
+# Testdata for PMMX models.
+
+* `pretrain_test.gin`, `infer_test.gin`: Custom tiny model gin configurations
+  for pretraining and inference.
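Reviewer note: the test above leans on a property of gin worth spelling out. Every `%gin.REQUIRED` macro in the parsed config files must be covered by one of the `bindings`; with `finalize_config=True`, an unbound required value surfaces as an error (at finalization or on first use, depending on where it is referenced) rather than letting an incomplete configuration pass silently. A minimal standalone sketch of that pattern — the path and bindings here are hypothetical, not part of this change:

```python
import gin

# pretrain_test.gin declares macros such as TRAIN_STEPS = %gin.REQUIRED.
# Omitting a binding for any of them makes the parse/finalize/use cycle
# fail loudly, which is what this sanity test relies on.
gin.parse_config_files_and_bindings(
    config_files=["/tmp/testdata/pretrain_test.gin"],  # hypothetical path
    bindings=[
        'MODEL_DIR="/tmp/model"',
        "TRAIN_STEPS=1",  # dropping this binding would surface an error
    ],
    finalize_config=True,  # lock the config once parsing succeeds
    skip_unknown=False,    # unknown configurables are errors, not warnings
)
```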
diff --git a/protoscribe/models/pmmx/testdata/infer_test.gin b/protoscribe/models/pmmx/testdata/infer_test.gin
new file mode 100644
index 0000000..9b2db70
--- /dev/null
+++ b/protoscribe/models/pmmx/testdata/infer_test.gin
@@ -0,0 +1,90 @@
+# Copyright 2024 The Protoscribe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Defaults for infer.py.
+#
+# You must also include a binding for MODEL.
+#
+# Required to be set:
+#
+# - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to use for inference.
+# - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim
+#   features to.
+# - CHECKPOINT_PATH: The model checkpoint to use for inference.
+# - INFER_OUTPUT_DIR: The directory to write results to. When launching using
+#   XManager, this is set automatically.
+#
+
+from __gin__ import dynamic_registration
+
+from protoscribe.pmmx.utils import partitioning_utils
+from t5x import partitioning
+from t5x import utils
+
+# --------------------------------------------------
+# From t5x/configs/runs/infer.gin:
+# --------------------------------------------------
+
+# Must be overridden.
+MIXTURE_OR_TASK_NAME = %gin.REQUIRED
+TASK_FEATURE_LENGTHS = %gin.REQUIRED
+CHECKPOINT_PATH = %gin.REQUIRED
+INFER_OUTPUT_DIR = %gin.REQUIRED
+
+# DEPRECATED: Import this module in your gin file.
+MIXTURE_OR_TASK_MODULE = None
+
+partitioning.PjitPartitioner:
+  num_partitions = 1
+  logical_axis_rules = @partitioning.standard_logical_axis_rules()
+
+utils.DatasetConfig:
+  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+  module = %MIXTURE_OR_TASK_MODULE
+  task_feature_lengths = %TASK_FEATURE_LENGTHS
+  use_cached = False
+  split = 'test'
+  batch_size = 32
+  shuffle = False
+  seed = 0
+  pack = False
+
+utils.RestoreCheckpointConfig:
+  path = %CHECKPOINT_PATH
+  mode = 'specific'
+  dtype = 'bfloat16'
+
+# --------------------------------------------------
+# From PMMX:
+# --------------------------------------------------
+
+partitioning.PjitPartitioner:
+  num_partitions = 1
+  logical_axis_rules = @partitioning.standard_logical_axis_rules()
+
+partitioning.standard_logical_axis_rules:
+  additional_rules = @partitioning_utils.additional_axis_rules()
+
+# Must be overridden.
+MIXTURE_OR_TASK_NAME = %gin.REQUIRED
+TASK_FEATURE_LENGTHS = %gin.REQUIRED
+CHECKPOINT_PATH = %gin.REQUIRED
+INFER_OUTPUT_DIR = %gin.REQUIRED
+BATCH_SIZE = %gin.REQUIRED
+
+utils.DatasetConfig:
+  batch_size = %BATCH_SIZE
+
+# No falling back to scratch for inference.
+utils.RestoreCheckpointConfig.fallback_to_scratch = False
diff --git a/protoscribe/models/pmmx/testdata/pretrain_test.gin b/protoscribe/models/pmmx/testdata/pretrain_test.gin
new file mode 100644
index 0000000..9a68b88
--- /dev/null
+++ b/protoscribe/models/pmmx/testdata/pretrain_test.gin
@@ -0,0 +1,214 @@
+# Copyright 2024 The Protoscribe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Defaults for pretraining with train.py.
+#
+# You must also include a binding for MODEL.
+#
+# Required to be set:
+#
+# - MIXTURE_OR_TASK_NAME
+# - MIXTURE_OR_TASK_MODULE
+# - TASK_FEATURE_LENGTHS
+# - MODEL_DIR
+#
+# Commonly overridden options:
+#
+# - DatasetConfig.batch_size
+# - PjitPartitioner.num_partitions
+# - Trainer.num_microbatches
+from __gin__ import dynamic_registration
+import seqio
+from t5x import adafactor
+from t5x import checkpoints
+from t5x import gin_utils
+from t5x import partitioning
+from t5x import trainer
+from t5x import utils
+
+from protoscribe.pmmx.utils import adafactor_utils
+from protoscribe.pmmx.utils import gin_str_utils
+from protoscribe.pmmx.utils import partitioning_utils
+from protoscribe.pmmx.utils import seqio_utils
+
+# --------------------------------------------------
+# From t5x/configs/runs/pretrain.gin:
+# --------------------------------------------------
+
+MIXTURE_OR_TASK_NAME = %gin.REQUIRED
+TASK_FEATURE_LENGTHS = %gin.REQUIRED
+TRAIN_STEPS = %gin.REQUIRED
+MODEL_DIR = %gin.REQUIRED
+BATCH_SIZE = 128
+USE_CACHED_TASKS = True
+
+# DEPRECATED: Import this module in your gin file.
+MIXTURE_OR_TASK_MODULE = None
+SHUFFLE_TRAIN_EXAMPLES = True
+
+# HW RNG is faster than SW, but has limited determinism.
+# Most notably it is not deterministic across different
+# submeshes.
+USE_HARDWARE_RNG = False
+# None always uses the faster, hardware RNG.
+RANDOM_SEED = None
+TRAIN_STEPS_RELATIVE = None
+
+partitioning.PjitPartitioner:
+  num_partitions = 1
+  model_parallel_submesh = None
+  logical_axis_rules = @partitioning.standard_logical_axis_rules()
+
+train/utils.DatasetConfig:
+  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+  task_feature_lengths = %TASK_FEATURE_LENGTHS
+  split = 'train'
+  batch_size = %BATCH_SIZE
+  shuffle = %SHUFFLE_TRAIN_EXAMPLES
+  seed = None  # use a new seed each run/restart
+  use_cached = %USE_CACHED_TASKS
+  pack = True
+  module = %MIXTURE_OR_TASK_MODULE
+
+train_eval/utils.DatasetConfig:
+  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+  task_feature_lengths = %TASK_FEATURE_LENGTHS
+  split = 'validation'
+  batch_size = %BATCH_SIZE
+  shuffle = False
+  seed = 42
+  use_cached = %USE_CACHED_TASKS
+  pack = True
+  module = %MIXTURE_OR_TASK_MODULE
+
+utils.CheckpointConfig:
+  restore = @utils.RestoreCheckpointConfig()
+  save = @utils.SaveCheckpointConfig()
+utils.RestoreCheckpointConfig:
+  path = []  # initialize from scratch
+utils.SaveCheckpointConfig:
+  period = 1000
+  dtype = 'float32'
+  keep = None  # keep all checkpoints
+  save_dataset = False  # don't checkpoint dataset state
+
+trainer.Trainer:
+  num_microbatches = None
+  learning_rate_fn = @utils.create_learning_rate_scheduler()
+
+utils.create_learning_rate_scheduler:
+  factors = 'constant * rsqrt_decay'
+  base_learning_rate = 1.0
+  warmup_steps = 10000  # 10k to keep consistent with T5/MTF defaults.
+
+# --------------------------------------------------
+# Overrides for PMMX:
+# --------------------------------------------------
+
+# Use None by default, which infers the max length by iterating over
+# the evaluation set, for efficiency reasons.
+# Some features, such as LP Summarizer input text, have very long lengths and
+# need to be truncated to a max length to avoid HBM OOMs.
+# You may specify only some of the feature lengths; others will be inferred.
+INFER_TASK_FEATURE_LENGTHS = None
+
+NUM_PARTITIONS = 1
+MODEL_DIR = %gin.REQUIRED
+LEARNING_RATE = 1.0
+WARMUP_STEPS = 1000
+EVAL_PERIOD = 1000
+EVAL_SPLIT = 'validation'
+# Adding this alias for clarity that EVAL_SPLIT (and the alias INFER_EVAL_SPLIT)
+# applies to infer_eval, not training_eval.
+INFER_EVAL_SPLIT = %EVAL_SPLIT
+
+EVAL_MIXTURE_OR_TASK_NAME = %MIXTURE_OR_TASK_NAME
+EVALUATOR_USE_MEMORY_CACHE = True
+EVALUATOR_NUM_EXAMPLES = None  # Use all examples in the infer_eval dataset.
+JSON_WRITE_N_RESULTS = None  # Write all inferences.
+
+seqio_utils.get_dataset:
+  num_seeds = 60  # O(num. of epochs)
+
+train/utils.DatasetConfig:
+  use_custom_packing_ops = False
+
+train_eval/utils.DatasetConfig:
+  use_custom_packing_ops = False
+
+# Unlike P5X, we do infer-eval for text metrics during pretraining.
+infer_eval/utils.DatasetConfig:
+  use_custom_packing_ops = False
+  mixture_or_task_name = %EVAL_MIXTURE_OR_TASK_NAME
+  task_feature_lengths = %INFER_TASK_FEATURE_LENGTHS
+  split = %INFER_EVAL_SPLIT
+  batch_size = 128
+  shuffle = False
+  seed = 42
+  use_cached = False
+  pack = False
+  module = %MIXTURE_OR_TASK_MODULE
+
+# Parameters for utils.SaveCheckpointConfig:
+# To keep all checkpoints, override:
+#   utils.SaveCheckpointConfig.keep = None
+#   utils.SaveCheckpointConfig.checkpointer_cls = None
+# ==============================================================================
+utils.SaveCheckpointConfig.save_dataset = True
+utils.SaveCheckpointConfig.keep = 20  # Remove all but the best $keep ckpts.
+utils.SaveCheckpointConfig.keep_dataset_checkpoints = 1  # Keep only 1 dataset iterator (b/230682911).
+utils.SaveCheckpointConfig.checkpointer_cls = @checkpoints.SaveBestCheckpointer
+
+# Parameters for checkpoints.SaveBestCheckpointer:
+# ==============================================================================
+# Example metrics:
+# * train accuracy:
+#   values = ['train', 'accuracy']
+# * training_eval accuracy:
+#   values = ["training_eval", %MIXTURE_OR_TASK_NAME, "accuracy"]
+metric_name_builder/gin_str_utils.join:
+  values = ["training_eval", %MIXTURE_OR_TASK_NAME, "perplexity"]
+  delimiter = "/"
+
+checkpoints.SaveBestCheckpointer:
+  force_keep_period = 100000
+  keep_checkpoints_without_metrics = False
+  metric_mode = 'min'  # 'min' for perplexity / 'max' for accuracy.
+  metric_name_to_monitor = @metric_name_builder/gin_str_utils.join()
+
+
+utils.create_learning_rate_scheduler:
+  factors = 'constant * linear_warmup * rsqrt_decay'
+  base_learning_rate = %LEARNING_RATE
+  warmup_steps = %WARMUP_STEPS
+
+seqio.Evaluator:
+  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
+  num_examples = %EVALUATOR_NUM_EXAMPLES
+  use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE
+
+seqio.JSONLogger:
+  write_n_results = %JSON_WRITE_N_RESULTS
+
+partitioning.PjitPartitioner:
+  num_partitions = %NUM_PARTITIONS
+  logical_axis_rules = @partitioning.standard_logical_axis_rules()
+
+partitioning.standard_logical_axis_rules:
+  additional_rules = @partitioning_utils.additional_axis_rules()
+  activation_partitioning_dims = 2  # See b/223425357#comment42
+
+adafactor.Adafactor:
+  logical_factor_rules = @adafactor_utils.logical_factor_rules()
diff --git a/test.sh b/test.sh
index a8e880b..e81e4a1 100755
--- a/test.sh
+++ b/test.sh
@@ -47,6 +47,7 @@ INDIVIDUAL_COMPONENTS=(
   "protoscribe/language/morphology/numbers_test.py"
   "protoscribe/language/phonology"
   "protoscribe/language/syntax"
+  "protoscribe/models"
   "protoscribe/pmmx"
   "protoscribe/scoring"
   "protoscribe/semantics"
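A note on the scoped `metric_name_builder/gin_str_utils.join` binding in `pretrain_test.gin` above: `gin_str_utils.join` lives in `protoscribe/pmmx/utils/gin_str_utils.py` and is not shown in this diff, so the sketch below only assumes it joins `values` with `delimiter`. Under that assumption, `SaveBestCheckpointer` ends up monitoring a slash-separated metric path built from the task name:

```python
# Hypothetical stand-in for gin_str_utils.join, assuming a plain delimiter
# join; the real helper is defined in protoscribe and not part of this diff.
def join(values: list[str], delimiter: str) -> str:
  """Builds the metric name the way the gin binding configures it."""
  return delimiter.join(values)

# With MIXTURE_OR_TASK_NAME bound to a hypothetical task name, the "best"
# checkpoint is the one minimizing this metric (metric_mode = 'min'):
name = join(["training_eval", "glyph_concepts_task", "perplexity"], "/")
print(name)  # -> training_eval/glyph_concepts_task/perplexity
```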