Rudimentary gin sanity checks for model configurations.
PiperOrigin-RevId: 721684050
agutkin committed Jan 31, 2025
1 parent d7c4146 commit bd0a8d4
Showing 5 changed files with 410 additions and 0 deletions.
101 changes: 101 additions & 0 deletions protoscribe/models/pmmx/model_config_gin_test.py
@@ -0,0 +1,101 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sanity check for miscellaneous model configurations in gin."""

import os

from absl.testing import absltest
from absl.testing import parameterized
import gin

# Core PMMX configurations. These are copied from
# protoscribe/pmmx/config/runs
# and modified to work in the test environment.
_CORE_CONFIG_BASE_DIR = (
"protoscribe/models/pmmx/testdata"
)

# Configurations for individual models.
_MODEL_CONFIG_BASE_DIR = (
"protoscribe/models/pmmx/configs"
)


def _config_path(
filename: str,
config_dir: str = _MODEL_CONFIG_BASE_DIR
) -> str:
"""Returns full path of the specified file name."""
return os.path.join(
absltest.get_default_test_srcdir(), config_dir, filename
)


class ModelConfigGinTest(parameterized.TestCase):

def tearDown(self):
super().tearDown()
gin.clear_config()

@parameterized.parameters(
"glyph_concepts",
"glyph_phonemes",
)
def test_model_train(self, model_dir: str) -> None:
"""Tests tiny model configuration for training."""
tmp_dir = absltest.get_default_test_tmpdir()
gin.parse_config_files_and_bindings(
config_files=[
_config_path("pretrain_test.gin", config_dir=_CORE_CONFIG_BASE_DIR),
_config_path(os.path.join(model_dir, "model_tiny.gin")),
_config_path(os.path.join(model_dir, "dataset.gin")),
],
bindings=[
f"DATA_DIR=\"{tmp_dir}\"",
f"MODEL_DIR=\"{tmp_dir}\"",
"TRAIN_STEPS=1",
"BATCH_SIZE=8",
"EVAL_BATCH_SIZE=8",
],
finalize_config=True,
skip_unknown=False
)

@parameterized.parameters(
"glyph_concepts",
"glyph_phonemes",
)
def test_model_infer(self, model_dir: str) -> None:
"""Tests tiny model configuration in inference mode."""
tmp_dir = absltest.get_default_test_tmpdir()
gin.parse_config_files_and_bindings(
config_files=[
_config_path("infer_test.gin", config_dir=_CORE_CONFIG_BASE_DIR),
_config_path(os.path.join(model_dir, "model_tiny.gin")),
_config_path(os.path.join(model_dir, "dataset.gin")),
],
bindings=[
f"DATA_DIR=\"{tmp_dir}\"",
f"CHECKPOINT_PATH=\"{tmp_dir}\"",
f"INFER_OUTPUT_DIR=\"{tmp_dir}\"",
"BATCH_SIZE=8",
],
finalize_config=True,
skip_unknown=False
)


if __name__ == "__main__":
absltest.main()
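
The test above exercises gin directly. For reference, a minimal standalone sketch (not part of this commit) of the same pretrain sanity check, assuming the config paths below are relative to a source checkout and that the protoscribe and t5x packages referenced by the gin files are importable:

import os
import tempfile

import gin

_CORE_DIR = "protoscribe/models/pmmx/testdata"
_MODEL_DIR = "protoscribe/models/pmmx/configs/glyph_concepts"

tmp_dir = tempfile.mkdtemp()
gin.parse_config_files_and_bindings(
    config_files=[
        os.path.join(_CORE_DIR, "pretrain_test.gin"),
        os.path.join(_MODEL_DIR, "model_tiny.gin"),
        os.path.join(_MODEL_DIR, "dataset.gin"),
    ],
    bindings=[
        f"DATA_DIR=\"{tmp_dir}\"",
        f"MODEL_DIR=\"{tmp_dir}\"",
        "TRAIN_STEPS=1",
        "BATCH_SIZE=8",
        "EVAL_BATCH_SIZE=8",
    ],
    finalize_config=True,  # Fail fast on unresolved or conflicting bindings.
    skip_unknown=False,    # Treat any unknown configurable as an error.
)
print(gin.config_str())    # Inspect the fully parsed configuration.
gin.clear_config()         # Reset global gin state, as tearDown() does.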
4 changes: 4 additions & 0 deletions protoscribe/models/pmmx/testdata/README.md
@@ -0,0 +1,4 @@
# Testdata for PMMX models.

* `pretrain_test.gin`, `infer_test.gin`: Custom tiny model gin configurations
for pretraining and inference.
90 changes: 90 additions & 0 deletions protoscribe/models/pmmx/testdata/infer_test.gin
@@ -0,0 +1,90 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Defaults for infer.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to use for inference
# - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features
#   to.
# - CHECKPOINT_PATH: The model checkpoint to use for inference
# - INFER_OUTPUT_DIR: The dir to write results to. When launching using
# XManager, this is set automatically.
#

from __gin__ import dynamic_registration

from protoscribe.pmmx.utils import partitioning_utils
from t5x import partitioning
from t5x import utils

# --------------------------------------------------
# From t5x/configs/runs/infer.gin:
# --------------------------------------------------

# Must be overridden
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
CHECKPOINT_PATH = %gin.REQUIRED
INFER_OUTPUT_DIR = %gin.REQUIRED

# DEPRECATED: Import this module in your gin file.
MIXTURE_OR_TASK_MODULE = None

partitioning.PjitPartitioner:
num_partitions = 1
logical_axis_rules = @partitioning.standard_logical_axis_rules()

utils.DatasetConfig:
mixture_or_task_name = %MIXTURE_OR_TASK_NAME
module = %MIXTURE_OR_TASK_MODULE
task_feature_lengths = %TASK_FEATURE_LENGTHS
use_cached = False
split = 'test'
batch_size = 32
shuffle = False
seed = 0
pack = False

utils.RestoreCheckpointConfig:
path = %CHECKPOINT_PATH
mode = 'specific'
dtype = 'bfloat16'

# --------------------------------------------------
# From PMMX:
# --------------------------------------------------

partitioning.PjitPartitioner:
num_partitions = 1
logical_axis_rules = @partitioning.standard_logical_axis_rules()

partitioning.standard_logical_axis_rules:
additional_rules = @partitioning_utils.additional_axis_rules()

# Must be overridden
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
CHECKPOINT_PATH = %gin.REQUIRED
INFER_OUTPUT_DIR = %gin.REQUIRED
BATCH_SIZE = %gin.REQUIRED

utils.DatasetConfig:
batch_size = %BATCH_SIZE

# No falling back to scratch for inference.
utils.RestoreCheckpointConfig.fallback_to_scratch = False
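
The required macros declared above (MIXTURE_OR_TASK_NAME, TASK_FEATURE_LENGTHS, CHECKPOINT_PATH, INFER_OUTPUT_DIR, BATCH_SIZE) are supplied by the per-model gin files plus explicit bindings, exactly as test_model_infer does. A minimal standalone sketch of that wiring (not part of this commit; the paths are illustrative and assume a source checkout with protoscribe and t5x importable):

import os
import tempfile

import gin

_CORE_DIR = "protoscribe/models/pmmx/testdata"
_MODEL_DIR = "protoscribe/models/pmmx/configs/glyph_phonemes"

tmp_dir = tempfile.mkdtemp()
gin.parse_config_files_and_bindings(
    config_files=[
        os.path.join(_CORE_DIR, "infer_test.gin"),
        os.path.join(_MODEL_DIR, "model_tiny.gin"),
        os.path.join(_MODEL_DIR, "dataset.gin"),
    ],
    bindings=[
        f"DATA_DIR=\"{tmp_dir}\"",
        f"CHECKPOINT_PATH=\"{tmp_dir}\"",
        f"INFER_OUTPUT_DIR=\"{tmp_dir}\"",
        "BATCH_SIZE=8",
    ],
    finalize_config=True,
    skip_unknown=False,
)
# Later bindings win in gin, so the %BATCH_SIZE macro in the PMMX block above
# overrides the batch_size = 32 default from the t5x-derived block.
gin.clear_config()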
