Skip to content

Commit

Permalink
Modify the way sources are registered from being public to an interna…
Browse files Browse the repository at this point in the history
…l registry.

The number of sources we will support in TORAX will be finite so we can make this registry internal.

We will need to do follow up work to enable more flexible model funcs etc. which is what will differ from user use cases.

PiperOrigin-RevId: 695828261
  • Loading branch information
Nush395 authored and Torax team committed Nov 12, 2024
1 parent 7155cba commit acecc2f
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 98 deletions.
178 changes: 80 additions & 98 deletions torax/sources/register_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for registering sources.
"""Source registry.
This module contains a set of utilities for registering sources and retrieving
This module contains a registry of sources and a utility for retrieving
registered sources.
In TORAX we flexibly support different user provided sources to be active at
runtime. To do so, we use a registration mechanism such that users can register
their own sources and TORAX can look them up at runtime in the registry.
To register a new source, use the `register_new_source` function. This function
takes in a source name, source class, default runtime params class, and an
optional source builder class. The source name is used to identify the source
in the registry. The source class is the class of the source itself. The default
runtime params class is the class of the default runtime params for the source.
And the source builder class is an optional class which inherits from
`SourceBuilderProtocol`. If not provided, then a default source builder is
created which uses the source class and default runtime params class.
Once a source is registered, it can be retrieved using the
`get_registered_source` function. This function takes in a source name and
returns a `RegisteredSource` dataclass containing the source class, source
builder class, and default runtime params class. TORAX uses this dataclass to
instantiate the source at runtime overriding any default runtime params with
user provided ones from a config file.
To register a new source, use the `_register_new_source` helper and add to the
`_REGISTERED_SOURCES` dict. We register a source by telling TORAX what
class to build, the runtime associated with that source and (optionally) the
builder used to make the source class. If a builder is not provided, the
`_register_new_source` helper will create a default builder for you. The source
builder is used to build the source at runtime, and can be used to override
default runtime params with user provided ones from a config file.
All registered sources can be retrieved using the `get_registered_source`
function. This function takes in a source name and returns a `RegisteredSource`
dataclass containing the source class, source builder class, and default runtime
params class. TORAX uses this dataclass to instantiate the source at runtime
overriding any default runtime params with user provided ones from a config
file.
This is an internal feature of TORAX and the number of registered sources is
expected to grow over time as TORAX becomes more feature rich but ultimately be
finite.
"""
import dataclasses
from typing import Type
Expand All @@ -52,8 +51,6 @@
from torax.sources import runtime_params
from torax.sources import source

_REGISTERED_SOURCES = {}


@dataclasses.dataclass(frozen=True)
class RegisteredSource:
Expand All @@ -62,28 +59,27 @@ class RegisteredSource:
default_runtime_params_class: Type[runtime_params.RuntimeParams]


def register_new_source(
source_name: str,
def _register_new_source(
source_class: Type[source.Source],
default_runtime_params_class: Type[runtime_params.RuntimeParams],
source_builder_class: source.SourceBuilderProtocol | None = None,
links_back: bool = False,
):
) -> RegisteredSource:
"""Register source class, default runtime params and (optional) builder for this source.
Args:
source_name: The name of the source.
source_class: The source class.
default_runtime_params_class: The default runtime params class.
source_builder_class: The source builder class. If None, a default builder
is created which uses the source class and default runtime params class to
construct a builder for that source.
links_back: Whether the source requires a reference to all the source
models.
"""
if source_name in _REGISTERED_SOURCES:
raise ValueError(f'Source:{source_name} has already been registered.')
Returns:
A `RegisteredSource` dataclass containing the source class, source
builder class, and default runtime params class.
"""
if source_builder_class is None:
builder_class = source.make_source_builder(
source_class,
Expand All @@ -93,84 +89,70 @@ def register_new_source(
else:
builder_class = source_builder_class

_REGISTERED_SOURCES[source_name] = RegisteredSource(
return RegisteredSource(
source_class=source_class,
source_builder_class=builder_class,
default_runtime_params_class=default_runtime_params_class,
)


_REGISTERED_SOURCES = {
bootstrap_current_source.SOURCE_NAME: _register_new_source(
source_class=bootstrap_current_source.BootstrapCurrentSource,
default_runtime_params_class=bootstrap_current_source.RuntimeParams,
),
generic_current_source.SOURCE_NAME: _register_new_source(
source_class=generic_current_source.GenericCurrentSource,
default_runtime_params_class=generic_current_source.RuntimeParams,
),
electron_cyclotron_source.SOURCE_NAME: _register_new_source(
source_class=electron_cyclotron_source.ElectronCyclotronSource,
default_runtime_params_class=electron_cyclotron_source.RuntimeParams,
),
electron_density_sources.GENERIC_PARTICLE_SOURCE_NAME: _register_new_source(
source_class=electron_density_sources.GenericParticleSource,
default_runtime_params_class=electron_density_sources.GenericParticleSourceRuntimeParams,
),
electron_density_sources.GAS_PUFF_SOURCE_NAME: _register_new_source(
source_class=electron_density_sources.GasPuffSource,
default_runtime_params_class=electron_density_sources.GasPuffRuntimeParams,
),
electron_density_sources.PELLET_SOURCE_NAME: _register_new_source(
source_class=electron_density_sources.PelletSource,
default_runtime_params_class=electron_density_sources.PelletRuntimeParams,
),
ion_el_heat.SOURCE_NAME: _register_new_source(
source_class=ion_el_heat.GenericIonElectronHeatSource,
default_runtime_params_class=ion_el_heat.RuntimeParams,
),
fusion_heat_source.SOURCE_NAME: _register_new_source(
source_class=fusion_heat_source.FusionHeatSource,
default_runtime_params_class=fusion_heat_source.FusionHeatSourceRuntimeParams,
),
qei_source.SOURCE_NAME: _register_new_source(
source_class=qei_source.QeiSource,
default_runtime_params_class=qei_source.RuntimeParams,
),
ohmic_heat_source.SOURCE_NAME: _register_new_source(
source_class=ohmic_heat_source.OhmicHeatSource,
default_runtime_params_class=ohmic_heat_source.OhmicRuntimeParams,
links_back=True,
),
bremsstrahlung_heat_sink.SOURCE_NAME: _register_new_source(
source_class=bremsstrahlung_heat_sink.BremsstrahlungHeatSink,
default_runtime_params_class=bremsstrahlung_heat_sink.RuntimeParams,
),
ion_cyclotron_source.SOURCE_NAME: _register_new_source(
source_class=ion_cyclotron_source.IonCyclotronSource,
default_runtime_params_class=ion_cyclotron_source.RuntimeParams,
),
}


def get_registered_source(source_name: str) -> RegisteredSource:
"""Used when building a simulation to get the registered source."""
if source_name in _REGISTERED_SOURCES:
return _REGISTERED_SOURCES[source_name]
else:
raise RuntimeError(f'Source:{source_name} has not been registered.')


def register_torax_sources():
"""Register a set of sources commonly used in TORAX."""
register_new_source(
bootstrap_current_source.SOURCE_NAME,
source_class=bootstrap_current_source.BootstrapCurrentSource,
default_runtime_params_class=bootstrap_current_source.RuntimeParams,
)
register_new_source(
generic_current_source.SOURCE_NAME,
generic_current_source.GenericCurrentSource,
default_runtime_params_class=generic_current_source.RuntimeParams,
)
register_new_source(
electron_cyclotron_source.SOURCE_NAME,
electron_cyclotron_source.ElectronCyclotronSource,
default_runtime_params_class=electron_cyclotron_source.RuntimeParams,
)
register_new_source(
electron_density_sources.GENERIC_PARTICLE_SOURCE_NAME,
electron_density_sources.GenericParticleSource,
default_runtime_params_class=electron_density_sources.GenericParticleSourceRuntimeParams,
)
register_new_source(
electron_density_sources.GAS_PUFF_SOURCE_NAME,
electron_density_sources.GasPuffSource,
default_runtime_params_class=electron_density_sources.GasPuffRuntimeParams,
)
register_new_source(
electron_density_sources.PELLET_SOURCE_NAME,
electron_density_sources.PelletSource,
default_runtime_params_class=electron_density_sources.PelletRuntimeParams,
)
register_new_source(
ion_el_heat.SOURCE_NAME,
ion_el_heat.GenericIonElectronHeatSource,
default_runtime_params_class=ion_el_heat.RuntimeParams,
)
register_new_source(
fusion_heat_source.SOURCE_NAME,
fusion_heat_source.FusionHeatSource,
default_runtime_params_class=fusion_heat_source.FusionHeatSourceRuntimeParams,
)
register_new_source(
qei_source.SOURCE_NAME,
qei_source.QeiSource,
default_runtime_params_class=qei_source.RuntimeParams,
)
register_new_source(
ohmic_heat_source.SOURCE_NAME,
ohmic_heat_source.OhmicHeatSource,
default_runtime_params_class=ohmic_heat_source.OhmicRuntimeParams,
links_back=True,
)
register_new_source(
bremsstrahlung_heat_sink.SOURCE_NAME,
bremsstrahlung_heat_sink.BremsstrahlungHeatSink,
default_runtime_params_class=bremsstrahlung_heat_sink.RuntimeParams,
)
register_new_source(
ion_cyclotron_source.SOURCE_NAME,
ion_cyclotron_source.IonCyclotronSource,
ion_cyclotron_source.RuntimeParams,
)


register_torax_sources()
75 changes: 75 additions & 0 deletions torax/sources/tests/register_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the source registry."""


from absl.testing import absltest
from absl.testing import parameterized
from torax.sources import bootstrap_current_source
from torax.sources import bremsstrahlung_heat_sink
from torax.sources import electron_cyclotron_source
from torax.sources import electron_density_sources
from torax.sources import fusion_heat_source
from torax.sources import generic_current_source
from torax.sources import generic_ion_el_heat_source as ion_el_heat
from torax.sources import ion_cyclotron_source
from torax.sources import ohmic_heat_source
from torax.sources import qei_source
from torax.sources import register_source
from torax.sources import source_models as source_models_lib


class SourceTest(parameterized.TestCase):
"""Tests for the source registry."""

@parameterized.parameters(
bootstrap_current_source.SOURCE_NAME,
bremsstrahlung_heat_sink.SOURCE_NAME,
electron_cyclotron_source.SOURCE_NAME,
electron_density_sources.GENERIC_PARTICLE_SOURCE_NAME,
electron_density_sources.GAS_PUFF_SOURCE_NAME,
electron_density_sources.PELLET_SOURCE_NAME,
fusion_heat_source.SOURCE_NAME,
generic_current_source.SOURCE_NAME,
ion_el_heat.SOURCE_NAME,
ohmic_heat_source.SOURCE_NAME,
qei_source.SOURCE_NAME,
)
def test_sources_in_registry_build_successfully(self, source_name: str):
"""Test that all sources in the registry build successfully."""
registered_source = register_source.get_registered_source(source_name)
source_class = registered_source.source_class
source_runtime_params_class = registered_source.default_runtime_params_class
source_builder_class = registered_source.source_builder_class
source_builder = source_builder_class()
self.assertIsInstance(
source_builder.runtime_params, source_runtime_params_class
)
if not source_builder.links_back:
source = source_builder()
self.assertIsInstance(source, source_class)
else:
# If the source links back, we need to create a `SourceModels` object to
# pass to the source builder.
source_models = source_models_lib.SourceModels(
source_builders={
source_name: source_builder,
}
)
source = source_builder(source_models)
self.assertIsInstance(source, source_class)


if __name__ == "__main__":
absltest.main()

0 comments on commit acecc2f

Please sign in to comment.