Skip to content

Commit

Permalink
Add utilities for instantiating JAX models from config files.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662080465
  • Loading branch information
mjanusz authored and copybara-github committed Aug 12, 2024
1 parent 62407d4 commit e8e6278
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 0 deletions.
54 changes: 54 additions & 0 deletions connectomics/common/import_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for dynamically importing symbols from modules."""

import importlib
from absl import logging


def import_symbol(
specifier: str, default_packages: str = 'connectomics.jax.models'
):
"""Imports a symbol from a python module.
The calling module must have the target module for the import as dependency.
Args:
specifier: full path specifier in format
[<packages>.]<module_name>.<model_class>, if packages is missing
``default_packages`` is used.
default_packages: chain of packages before module in format
<top_pack>.<sub_pack>.<subsub_pack> etc.
Returns:
symbol: object from module
"""
module_path, symbol_name = specifier.rsplit('.', 1)
try:
logging.info(
'Importing symbol %s from %s.%s',
symbol_name,
default_packages,
module_path,
)
module = importlib.import_module(default_packages + '.' + module_path)
except ImportError as e:
logging.info(e)
logging.info('Importing symbol %s from %s', symbol_name, module_path)
module = importlib.import_module(module_path)

symbol = getattr(module, symbol_name)
return symbol
178 changes: 178 additions & 0 deletions connectomics/jax/models/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for instantiating models."""

import collections.abc
import inspect
import re
from typing import Any, Type

from absl import logging
from connectomics.common import import_util
# pylint:disable=unused-import
from connectomics.jax.models import convstack

import flax.linen as nn
import ml_collections

DEFAULT_PKG = 'connectomics.jax.models'


def class_from_name(
model_class: str, default_packages: str = DEFAULT_PKG
) -> tuple[Type, Type]: # pylint:disable=g-bare-generic
model_cls = import_util.import_symbol(
model_class, default_packages=default_packages
)
cfg_cls = (
inspect.signature(model_cls.__init__).parameters['config'].annotation
)
return model_cls, cfg_cls


def get_config_name(config_cls_name: str) -> str:
"""Returns the default ConfigDict field name for a given model class name."""
# The model is configured by a field, the name of which is the snake
# case version of the config class.
return re.sub(r'(?<!^)(?=[A-Z]([^A-Z]|$))', '_', config_cls_name).lower()


def model_from_config(
config: ml_collections.ConfigDict,
default_packages: str = DEFAULT_PKG,
) -> nn.Module:
"""Initializes a JAX model from settings in a ConfigDict.
A typical use case is to instantiate a model for training based on
settings that can be overridden from the command line.
Args:
config: ConfigDict containing a field with the settings for the model; the
model is expected to be configured with a single dataclass stored in its
'.config' attribute
default_packages: module from which to import the model class
Returns:
flax model object
"""
model_cls, cfg_cls = class_from_name(config.model_class, default_packages)
cfg_field = get_config_name(cfg_cls.__name__)

logging.info('Using config settings from "%r"', cfg_field)
cfg = cfg_cls(**getattr(config, cfg_field))
return model_cls(config=cfg, name=getattr(config, 'model_name', None))


def model_from_name(
model_class: str,
model_name: str | None = None,
default_packages: str = DEFAULT_PKG,
**kwargs
) -> nn.Module:
"""Initializes a JAX model given a name and its config settings.
A typical use case is to instantiate a model for inference based on
settings recorded in a JSON object teogether with the experiment
that was used to train the model.
Args:
model_class: name of the Python class implementing the model.
model_name: name of the model parameters (passed to the constructor of the
model class as `name` parameter)
default_packages: module from which to import 'model_class'
**kwargs: arguments passed to the configuration object for the model
Returns:
flax model object
"""
model_cls = import_util.import_symbol(
model_class, default_packages=default_packages
)
cfg_cls = (
inspect.signature(model_cls.__init__).parameters['config'].annotation
)

# TODO(mjanusz): Figure out how to make this compatible with callable config
# values.
def _skip_arg(name, value, cls):
"""Detects settings which currently cannot be restored."""

if isinstance(value, str) and (
value.startswith('function ') or 'unserializable' in value
):
return True

if (
getattr(
inspect.signature(cls).parameters[name].annotation,
'__origin__',
None,
)
is collections.abc.Callable
):
return True

return False

def _value(key, value, cls):
val_type = inspect.signature(cls).parameters[key].annotation
if hasattr(val_type, '__dataclass_fields__'):
value = {
k: _value(k, v, val_type)
for k, v in value.items()
if not _skip_arg(k, v, val_type)
}
return val_type(**value)
else:
return value

kwargs = {
k: _value(k, v, cfg_cls)
for k, v in kwargs.items()
if not _skip_arg(k, v, cfg_cls)
}

logging.info(
'Initializing model %r with config %r(%r)', model_cls, cfg_cls, kwargs
)
return model_cls(config=cfg_cls(**kwargs), name=model_name)


def model_from_dict_config(
config: dict[str, Any],
default_packages: str = DEFAULT_PKG,
) -> nn.Module:
"""Initializes a JAX model from settings in a python dictionary.
Like model_from_config, but uses a dictionary as configuration.
Args:
config: dictionary containing a field with the settings for the model; the
model is expected to be configured with a single dataclass stored in its
'.config' attribute
default_packages: module from which to import the model class
Returns:
flax model object
"""

_, cfg_cls = class_from_name(config['model_class'], default_packages)
cfg_field = get_config_name(cfg_cls.__name__)
return model_from_name(
config['model_class'],
config.get('model_name'),
default_packages,
**config[cfg_field],
)

0 comments on commit e8e6278

Please sign in to comment.