diff --git a/docs/reference/configuration/keywords.rst b/docs/reference/configuration/keywords.rst index 8d49c3ccdaa..721538009a0 100644 --- a/docs/reference/configuration/keywords.rst +++ b/docs/reference/configuration/keywords.rst @@ -1231,33 +1231,68 @@ The keywords to load, select and modify the analysis modules are documented here These can be manipulated from the config file using the ANALYSIS_SET_VAR keyword for either the `STD_ENKF` or `IES_ENKF` module. + **STD_ENKF** - .. list-table:: Inversion Algorithms - :widths: 50 50 50 + + .. list-table:: Inversion Algorithms for Ensemble Smoother + :widths: 50 50 50 50 :header-rows: 1 * - Description - INVERSION - - IES_INVERSION + - IES_INVERSION (deprecated) + - Note * - Exact inversion with diagonal R=I - - EXACT + - EXACT / EXACT - 0 + - * - Subspace inversion with exact R - - SUBSPACE_EXACT_R + - SUBSPACE_EXACT_R / SUBSPACE - 1 + - Preferred name: SUBSPACE * - Subspace inversion using R=EE' - SUBSPACE_EE_R - 2 + - Deprecated, maps to: subspace + * - Subspace inversion using E + - SUBSPACE_RE + - 3 + - Deprecated, maps to: SUBSPACE + + + **IES_ENKF** + + + .. list-table:: Inversion Algorithms for IES + :widths: 50 50 50 50 + :header-rows: 1 + + * - Description + - INVERSION + - IES_INVERSION (deprecated) + - Note + * - Exact inversion with diagonal R=I + - EXACT / DIRECT + - 0 + - Preferred name: DIRECT + * - Subspace inversion with exact R + - SUBSPACE_EXACT_R / SUBSPACE_EXACT + - 1 + - Preferred name: subspace_exact + * - Subspace inversion using R=EE' + - SUBSPACE_EE_R / SUBSPACE_PROJECTED + - 2 + - Preferred name: subspace_projected * - Subspace inversion using E - SUBSPACE_RE - 3 + - Deprecated, maps to: SUBSPACE_PROJECTED - Two ways of setting the same inversion method + Setting the inversion method :: -- Example for the `STD_ENKF` module - ANALYSIS_SET_VAR STD_ENKF INVERSION SUBSPACE_EXACT_R - ANALYSIS_SET_VAR STD_ENKF IES_INVERSION 1 + ANALYSIS_SET_VAR STD_ENKF INVERSION DIRECT .. _ies_max_steplength: diff --git a/src/ert/analysis/_es_update.py b/src/ert/analysis/_es_update.py index 065264396bc..3a4d1c5007a 100644 --- a/src/ert/analysis/_es_update.py +++ b/src/ert/analysis/_es_update.py @@ -590,7 +590,7 @@ def analysis_ES( observations=observation_values, alpha=1, # The user is responsible for scaling observation covariance (esmda usage) seed=rng, - inversion=module.inversion.name, + inversion=module.inversion, ) truncation = module.enkf_truncation @@ -722,7 +722,7 @@ def analysis_ES( observation_errors=observation_errors, observation_values=observation_values, truncation=truncation, - inversion_type=module.inversion.name, + inversion_type=module.inversion, progress_callback=progress_callback, rng=rng, ) @@ -800,7 +800,7 @@ def analysis_IES( covariance=observation_errors**2, observations=observation_values, seed=rng, - inversion=analysis_config.inversion.name, + inversion=analysis_config.inversion, truncation=analysis_config.enkf_truncation, ) diff --git a/src/ert/config/analysis_config.py b/src/ert/config/analysis_config.py index 78c27880b3e..047caaa5fd9 100644 --- a/src/ert/config/analysis_config.py +++ b/src/ert/config/analysis_config.py @@ -46,10 +46,18 @@ def __init__( options: Dict[str, Dict[str, Any]] = {"STD_ENKF": {}, "IES_ENKF": {}} analysis_set_var = [] if analysis_set_var is None else analysis_set_var inversion_str_map: Final = { - "EXACT": 0, - "SUBSPACE_EXACT_R": 1, - "SUBSPACE_EE_R": 2, - "SUBSPACE_RE": 3, + "STD_ENKF": { + **dict.fromkeys(["EXACT", 0], "exact"), + **dict.fromkeys(["SUBSPACE_EXACT_R", 1], "subspace"), + **dict.fromkeys(["SUBSPACE_EE_R", 2], "subspace"), + **dict.fromkeys(["SUBSPACE_RE", 3], "subspace"), + }, + "IES_ENKF": { + **dict.fromkeys(["EXACT", 0], "direct"), + **dict.fromkeys(["SUBSPACE_EXACT_R", 1], "subspace_exact"), + **dict.fromkeys(["SUBSPACE_EE_R", 2], "subspace_projected"), + **dict.fromkeys(["SUBSPACE_RE", 3], "subspace_projected"), + }, } deprecated_keys = ["ENKF_NCOMP", "ENKF_SUBSPACE_DIMENSION"] errors = [] @@ -59,11 +67,9 @@ def __init__( continue if var_name == "ENKF_FORCE_NCOMP": continue - if var_name == "INVERSION": - value = inversion_str_map[value] - var_name = "IES_INVERSION" - if var_name == "IES_INVERSION": - value = int(value) + if var_name in ["INVERSION", "IES_INVERSION"]: + if value in inversion_str_map[module_name]: + value = inversion_str_map[module_name][value] var_name = "inversion" key = var_name.lower() options[module_name][key] = value diff --git a/src/ert/config/analysis_module.py b/src/ert/config/analysis_module.py index b2de81bae23..d753a410d02 100644 --- a/src/ert/config/analysis_module.py +++ b/src/ert/config/analysis_module.py @@ -1,10 +1,9 @@ import logging import math -from enum import IntEnum from typing import TYPE_CHECKING, Optional, Type, TypedDict, Union from pydantic import BaseModel, Extra, Field -from typing_extensions import Annotated +from typing_extensions import Annotated, Literal logger = logging.getLogger(__name__) @@ -38,8 +37,8 @@ class Config: validate_assignment = True -class InversionTypeES(IntEnum): - """ +InversionTypeES = Literal["exact", "subspace"] +es_description = """ The type of inversion used in the algorithm. Every inversion method scales the variables. The options are: @@ -51,12 +50,9 @@ class InversionTypeES(IntEnum): U, w, V.T = svd(D_delta) then we assume that U @ U.T = I. """ - exact = 0 - subspace = 1 - -class InversionTypeIES(IntEnum): - """ +InversionTypeIES = Literal["direct", "subspace_exact", "subspace_projected"] +ies_description = """ The type of inversion used in the algorithm. Every inversion method scales the variables. The options are: @@ -70,15 +66,11 @@ class InversionTypeIES(IntEnum): Invert by projecting the covariance onto S. """ - direct = 0 - subspace_exact = 1 - subspace_projected = 2 - class ESSettings(BaseSettings): inversion: Annotated[ - InversionTypeES, Field(title="Inversion algorithm") - ] = InversionTypeES.exact + InversionTypeES, Field(title="Inversion algorithm", description=es_description) + ] = "exact" localization: Annotated[bool, Field(title="Adaptive localization")] = False localization_correlation_threshold: Annotated[ Optional[float], @@ -107,8 +99,9 @@ class IESSettings(BaseSettings): A steplength of 1.0 and one iteration results in ES update""" inversion: Annotated[ - InversionTypeIES, Field(title="Inversion algorithm") - ] = InversionTypeIES.subspace_exact + InversionTypeIES, + Field(title="Inversion algorithm", description=ies_description), + ] = "subspace_exact" ies_max_steplength: Annotated[ float, Field(ge=0.1, le=1.0, title="Gauss–Newton maximum steplength"), diff --git a/src/ert/gui/ertwidgets/analysismodulevariablespanel.py b/src/ert/gui/ertwidgets/analysismodulevariablespanel.py index fecbe3f0eff..101fc1f2366 100644 --- a/src/ert/gui/ertwidgets/analysismodulevariablespanel.py +++ b/src/ert/gui/ertwidgets/analysismodulevariablespanel.py @@ -12,6 +12,7 @@ QLabel, QWidget, ) +from typing_extensions import get_args from ert.config.analysis_module import AnalysisModule, IESSettings @@ -61,15 +62,15 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int): layout.addRow(QLabel("Inversion Algorithm")) dropdown = QComboBox(self) - options = analysis_module.model_fields["inversion"].annotation - layout.addRow(QLabel(options.__doc__)) + options = analysis_module.model_fields["inversion"] + layout.addRow(QLabel(options.description)) default_index = 0 - for i, option in enumerate(options): - dropdown.addItem(option.name) - if analysis_module.inversion.name == option.name: + for i, option in enumerate(get_args(options.annotation)): + dropdown.addItem(option) + if analysis_module.inversion == option: default_index = i dropdown.setCurrentIndex(default_index) - dropdown.currentIndexChanged.connect(self.update_inversion_algorithm) + dropdown.currentTextChanged.connect(self.update_inversion_algorithm) layout.addRow(dropdown) var_name = "enkf_truncation" metadata = analysis_module.model_fields[var_name] @@ -130,9 +131,11 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int): self.setLayout(layout) self.blockSignals(False) - def update_inversion_algorithm(self, button_id): - self.truncation_spinner.setEnabled(button_id != 0) # not for exact inversion - self.analysis_module.inversion = button_id + def update_inversion_algorithm(self, text): + self.truncation_spinner.setEnabled( + not any(val in text for val in ["direct", "exact"]) + ) + self.analysis_module.inversion = text def create_horizontal_line(self) -> QFrame: hline = QFrame() diff --git a/tests/unit_tests/config/test_analysis_config.py b/tests/unit_tests/config/test_analysis_config.py index 25accfc2d30..711eccaba34 100644 --- a/tests/unit_tests/config/test_analysis_config.py +++ b/tests/unit_tests/config/test_analysis_config.py @@ -157,7 +157,9 @@ def test_setting_case_format(analysis_config): def test_incorrect_variable_raises_validation_error(): - with pytest.raises(ConfigValidationError, match="Input should be a valid integer"): + with pytest.raises( + ConfigValidationError, match="Input should be 'exact' or 'subspace'" + ): _ = AnalysisConfig.from_dict( { ConfigKeys.ANALYSIS_SET_VAR: [["STD_ENKF", "IES_INVERSION", "FOO"]],