Create pydantic model for stepper.
Drive-by:
- remove a layer of nesting from stepper configs (by removing the newton_raphson_params and optimizer_params wrapper dicts; see the sketch below).
- rename PedestalModel -> Pedestal for consistency with other pydantic config objects.

Follow-up to remove the builder objects completely.
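For illustration, the nesting removal changes a user-facing stepper config roughly as follows (a sketch mirroring the test-config diffs further down):

# Before: Newton-Raphson options nested under 'newton_raphson_params'.
stepper_before = {
    'stepper_type': 'newton_raphson',
    'use_pereverzev': True,
    'newton_raphson_params': {
        'log_iterations': False,
    },
}

# After: the same options sit at the top level of the stepper dict.
stepper_after = {
    'stepper_type': 'newton_raphson',
    'use_pereverzev': True,
    'log_iterations': False,
}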

PiperOrigin-RevId: 731728066
Nush395 authored and Torax team committed Mar 3, 2025
1 parent dcab336 commit 7464dc1
Showing 10 changed files with 148 additions and 45 deletions.
16 changes: 11 additions & 5 deletions docs/configuration.rst
@@ -1171,7 +1171,8 @@ stepper
-------
Select and configure the ``Stepper`` object, which evolves the PDE system by one timestep. See :ref:`solver_details` for further details.
The dictionary consists of keys common to all steppers, and additional nested dictionaries where parameters pertaining to a specific stepper are defined.
The dictionary consists of keys common to all steppers. Additional fields for
parameters pertaining to a specific stepper are defined in the relevant section below.
``stepper_type`` (str = 'linear')
Selected PDE solver algorithm. The current options are:
@@ -1226,8 +1227,6 @@ parent ``Stepper`` class.
newton_raphson
^^^^^^^^^^^^^^
``newton_raphson_params`` dict containing the following configuration parameters for the Newton Raphson stepper.
.. _log_iterations:
``log_iterations`` (bool = False)
@@ -1249,6 +1248,9 @@ newton_raphson
* ``linear_step``
Use the linear solver to obtain an initial guess to warm-start the nonlinear solver.
If used, it is recommended to do so with the predictor_corrector solver and
several corrector steps. It is also strongly recommended to set
use_pereverzev=True if a stiff transport model like qlknn is used.
``tol`` (float = 1e-5)
PDE residual magnitude tolerance for successfully exiting the iterative solver.
@@ -1275,8 +1277,6 @@ newton_raphson
optimizer
^^^^^^^^^
``optimizer_params`` dict containing the following configuration parameters for the Optimizer stepper.
``initial_guess_mode`` (str = 'linear_step')
Sets the approach taken for the initial guess into the Newton-Raphson solver for the first iteration.
Two options are available:
@@ -1286,9 +1286,15 @@ optimizer
* ``linear_step``
Use the linear solver to obtain an initial guess to warm-start the nonlinear solver.
If used, it is recommended to do so with the predictor_corrector solver and
several corrector steps. It is also strongly recommended to set
use_pereverzev=True if a stiff transport model like qlknn is used.
``tol`` (float = 1e-12)
PDE loss magnitude tolerance for successfully exiting the iterative solver.
Note: the default tolerance here is smaller than the default tolerance for
the Newton-Raphson solver because it's a tolerance on the loss (square of the
residual).
``maxiter`` (int = 100)
Maximum number of allowed optimizer iterations.
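As a concrete example of the flattened layout documented above, a minimal sketch of a newton_raphson stepper config (field names and defaults taken from the docs and the new pydantic model below; the combination of values is illustrative):

stepper = {
    'stepper_type': 'newton_raphson',
    'log_iterations': False,   # log internal solver iterations
    'initial_guess_mode': 'linear_step',
    'tol': 1e-5,               # PDE residual tolerance
    'use_pereverzev': True,    # recommended with stiff transport models
}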
34 changes: 8 additions & 26 deletions torax/config/build_sim.py
@@ -32,6 +32,7 @@
from torax.sources import source_models as source_models_lib
from torax.stepper import linear_theta_method
from torax.stepper import nonlinear_theta_method
from torax.stepper import pydantic_model as stepper_pydantic_model
from torax.stepper import stepper as stepper_lib
from torax.time_step_calculator import chi_time_step_calculator
from torax.time_step_calculator import fixed_time_step_calculator
@@ -505,48 +506,29 @@ def build_stepper_builder_from_config(
Raises:
ValueError if the `stepper_type` is unknown.
"""
if isinstance(stepper_config, str):
stepper_config = {'stepper_type': stepper_config}
else:
if 'stepper_type' not in stepper_config:
raise ValueError('stepper_type must be set in the input config.')
# Deep copy so we don't modify the input config.
stepper_config = copy.deepcopy(stepper_config)
stepper_type = stepper_config.pop('stepper_type')
stepper_model = stepper_pydantic_model.Stepper.from_dict(stepper_config)
stepper_model = stepper_model.to_dict()
stepper_type = stepper_model['stepper_config'].pop('stepper_type')

if stepper_type == 'linear':
# Remove params from steppers with nested configs, if present.
stepper_config.pop('newton_raphson_params', None)
stepper_config.pop('optimizer_params', None)
return linear_theta_method.LinearThetaMethodBuilder(
runtime_params=config_args.recursive_replace(
linear_theta_method.LinearRuntimeParams(),
**stepper_config,
**stepper_model['stepper_config'],
)
)
elif stepper_type == 'newton_raphson':
newton_raphson_params = stepper_config.pop('newton_raphson_params', {})
if not isinstance(newton_raphson_params, dict):
raise ValueError('newton_raphson_params must be a dict.')
newton_raphson_params.update(stepper_config)
# Remove params from other steppers with nested configs, if present.
newton_raphson_params.pop('optimizer_params', None)
return nonlinear_theta_method.NewtonRaphsonThetaMethodBuilder(
runtime_params=config_args.recursive_replace(
nonlinear_theta_method.NewtonRaphsonRuntimeParams(),
**newton_raphson_params,
**stepper_model['stepper_config'],
)
)
elif stepper_type == 'optimizer':
optimizer_params = stepper_config.pop('optimizer_params', {})
if not isinstance(optimizer_params, dict):
raise ValueError('optimizer_params must be a dict.')
optimizer_params.update(stepper_config)
# Remove params from other steppers with nested configs, if present.
optimizer_params.pop('newton_raphson_params', None)
return nonlinear_theta_method.OptimizerThetaMethodBuilder(
runtime_params=config_args.recursive_replace(
nonlinear_theta_method.OptimizerRuntimeParams(),
**optimizer_params,
**stepper_model['stepper_config'],
)
)
raise ValueError(f'Unknown stepper type: {stepper_type}')
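A minimal sketch of the new flow through this function, assuming only the from_dict/to_dict round-trip shown above:

from torax.stepper import pydantic_model as stepper_pydantic_model

# A flat user dict is validated against the discriminated stepper union...
stepper_model = stepper_pydantic_model.Stepper.from_dict(
    {'stepper_type': 'newton_raphson', 'maxiter': 0}
)
# ...then dumped back to a plain dict for the legacy builder machinery.
params = stepper_model.to_dict()['stepper_config']
stepper_type = params.pop('stepper_type')  # == 'newton_raphson'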
4 changes: 1 addition & 3 deletions torax/examples/iterhybrid_rampup.py
@@ -192,9 +192,7 @@
'd_per': 15,
# use_pereverzev is only used for the linear solver
'use_pereverzev': True,
'newton_raphson_params': {
'log_iterations': False,
},
'log_iterations': False,
},
'time_step_calculator': {
'calculator_type': 'fixed',
2 changes: 1 addition & 1 deletion torax/pedestal_model/pydantic_model.py
@@ -78,7 +78,7 @@ class SetTpedNped(torax_pydantic.BaseModelMutable):
)


class PedestalModel(torax_pydantic.BaseModelMutable):
class Pedestal(torax_pydantic.BaseModelMutable):
"""Config for a pedestal model."""
pedestal_config: SetPpedTpedRatioNped | SetTpedNped = pydantic.Field(
discriminator='pedestal_model', default_factory=SetTpedNped,
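A short usage sketch of the renamed class (hedged: only the field definition and its default_factory are taken from the diff):

from torax.pedestal_model import pydantic_model as pedestal_pydantic_model

# With no arguments, 'pedestal_config' falls back to SetTpedNped via
# default_factory; a dict input would dispatch on the 'pedestal_model' key.
pedestal = pedestal_pydantic_model.Pedestal()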
119 changes: 119 additions & 0 deletions torax/stepper/pydantic_model.py
@@ -0,0 +1,119 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pydantic config for Stepper."""
from typing import Any, Literal, Union

import pydantic
from torax.fvm import enums
from torax.torax_pydantic import torax_pydantic


# pylint: disable=invalid-name
class LinearThetaMethod(torax_pydantic.BaseModelMutable):
"""Model for the linear stepper.
This is also the base model for the Newton-Raphson and Optimizer steppers as
they share the same parameters.
Attributes:
stepper_type: The type of stepper to use, hardcoded to 'linear'.
theta_imp: The theta value in the theta method 0 = explicit, 1 = fully
implicit, 0.5 = Crank-Nicolson.
predictor_corrector: Enables predictor_corrector iterations with the linear
solver. If False, compilation is faster.
corrector_steps: The number of corrector steps for the predictor-corrector
linear solver. 0 means a pure linear solve with no corrector steps.
convection_dirichlet_mode: See `fvm.convection_terms` docstring,
`dirichlet_mode` argument.
convection_neumann_mode: See `fvm.convection_terms` docstring,
`neumann_mode` argument.
use_pereverzev: Use Pereverzev terms for the linear solver. In the
nonlinear solvers, this is only applied when computing the optional initial
guess with the linear solver.
chi_per: (deliberately) large heat conductivity for Pereverzev rule.
d_per: (deliberately) large particle diffusion for Pereverzev rule.
"""
stepper_type: Literal['linear']
theta_imp: torax_pydantic.UnitInterval = 1.0
predictor_corrector: bool = True
corrector_steps: pydantic.PositiveInt = 1
convection_dirichlet_mode: str = 'ghost'
convection_neumann_mode: str = 'ghost'
use_pereverzev: bool = False
chi_per: pydantic.PositiveFloat = 20.0
d_per: pydantic.NonNegativeFloat = 10.0


class NewtonRaphsonThetaMethod(LinearThetaMethod):
"""Model for non linear NewtonRaphsonThetaMethod stepper.
Attributes:
stepper_type: The type of stepper to use, hardcoded to 'newton_raphson'.
log_iterations: If True, log internal iterations in Newton-Raphson solver.
initial_guess_mode: The initial guess mode for the Newton-Raphson solver.
maxiter: The maximum number of iterations for the Newton-Raphson solver.
tol: The tolerance for the Newton-Raphson solver.
coarse_tol: The coarse tolerance for the Newton-Raphson solver.
delta_reduction_factor: The delta reduction factor for the Newton-Raphson
solver.
tau_min: The minimum value of tau for the Newton-Raphson solver.
"""
stepper_type: Literal['newton_raphson']
log_iterations: bool = False
initial_guess_mode: enums.InitialGuessMode = enums.InitialGuessMode.LINEAR
maxiter: pydantic.PositiveInt = 30
tol: float = 1e-5
coarse_tol: float = 1e-2
delta_reduction_factor: float = 0.5
tau_min: float = 0.01


class OptimizerThetaMethod(LinearThetaMethod):
"""Model for non linear OptimizerThetaMethod stepper.
Attributes:
stepper_type: The type of stepper to use, hardcoded to 'optimizer'.
initial_guess_mode: The initial guess mode for the optimizer.
maxiter: The maximum number of iterations for the optimizer.
tol: The tolerance for the optimizer.
"""

stepper_type: Literal['optimizer']
initial_guess_mode: enums.InitialGuessMode = enums.InitialGuessMode.LINEAR
maxiter: pydantic.PositiveInt = 100
tol: float = 1e-12


StepperConfig = Union[
LinearThetaMethod, NewtonRaphsonThetaMethod, OptimizerThetaMethod
]


class Stepper(torax_pydantic.BaseModelMutable):
"""Config for a stepper.
The `from_dict` method of constructing this class supports the config
described in: https://torax.readthedocs.io/en/latest/configuration.html
"""
stepper_config: StepperConfig = pydantic.Field(discriminator='stepper_type')

@pydantic.model_validator(mode='before')
@classmethod
def _conform_data(cls, data: dict[str, Any]) -> dict[str, Any]:
# If we are running with the standard class constructor we don't need to do
# any custom validation.
if 'stepper_config' in data:
return data

return {'stepper_config': data}
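A brief sketch of what the _conform_data validator buys: a flat dict and the explicit 'stepper_config' form validate to the same model (usage assumed, not shown in the diff):

# Flat dict: _conform_data wraps it as {'stepper_config': ...} before
# validation; 'stepper_type' then discriminates the union member.
s1 = Stepper.from_dict({'stepper_type': 'optimizer', 'maxiter': 100})

# Explicit form: already keyed by 'stepper_config', so it passes through.
s2 = Stepper.from_dict(
    {'stepper_config': {'stepper_type': 'optimizer', 'maxiter': 100}}
)

assert s1 == s2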
1 change: 0 additions & 1 deletion torax/tests/test_data/test_iterhybrid_newton.py
@@ -188,7 +188,6 @@
# (deliberately) large particle diffusion for Pereverzev rule
'd_per': 15,
'use_pereverzev': True,
'newton_raphson_params': {},
},
'time_step_calculator': {
'calculator_type': 'chi',
1 change: 0 additions & 1 deletion torax/tests/test_data/test_iterhybrid_rampup.py
@@ -192,7 +192,6 @@
'd_per': 15,
# use_pereverzev is only used for the linear solver
'use_pereverzev': True,
'newton_raphson_params': {},
},
'time_step_calculator': {
'calculator_type': 'fixed',
1 change: 0 additions & 1 deletion torax/tests/test_data/test_iterhybrid_rampup_short.py
@@ -191,7 +191,6 @@
'd_per': 15,
# use_pereverzev is only used for the linear solver
'use_pereverzev': True,
'newton_raphson_params': {},
},
'time_step_calculator': {
'calculator_type': 'fixed',
4 changes: 1 addition & 3 deletions torax/tests/test_data/test_newton_raphson_zeroiter.py
@@ -69,9 +69,7 @@
'stepper_type': 'newton_raphson',
'predictor_corrector': False,
'use_pereverzev': True,
'newton_raphson_params': {
'maxiter': 0,
},
'maxiter': 0,
},
'time_step_calculator': {
'calculator_type': 'chi',
11 changes: 7 additions & 4 deletions torax/torax_pydantic/model_config.py
@@ -15,7 +15,8 @@
"""Pydantic config for Torax."""

from torax.geometry import pydantic_model as geometry_pydantic_model
from torax.pedestal_model import pydantic_model as pedestal_model_config
from torax.pedestal_model import pydantic_model as pedestal_pydantic_model
from torax.stepper import pydantic_model as stepper_pydantic_model
from torax.time_step_calculator import config as time_step_calculator_config
from torax.torax_pydantic import model_base

@@ -24,10 +25,12 @@ class ToraxConfig(model_base.BaseModelMutable):
"""Base config class for Torax.
Attributes:
time_step_calculator: Config for the time step calculator.
geometry: Config for the geometry.
pedestal: Config for the pedestal model.
stepper: Config for the stepper.
time_step_calculator: Config for the time step calculator.
"""

geometry: geometry_pydantic_model.Geometry
pedestal: pedestal_model_config.PedestalModel
pedestal: pedestal_pydantic_model.Pedestal
stepper: stepper_pydantic_model.Stepper
time_step_calculator: time_step_calculator_config.TimeStepCalculator
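Putting the pieces together, a hedged sketch of a top-level config with the new stepper field (nested values are placeholders, and from_dict is assumed to be available on this base class as on the other pydantic models):

torax_config = ToraxConfig.from_dict({
    'geometry': {'geometry_type': 'circular'},            # placeholder
    'pedestal': {},                                       # Pedestal defaults
    'stepper': {'stepper_type': 'linear'},
    'time_step_calculator': {'calculator_type': 'fixed'},
})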
