diff --git a/src/ert/analysis/_es_update.py b/src/ert/analysis/_es_update.py index 3cbd40f56ed..065264396bc 100644 --- a/src/ert/analysis/_es_update.py +++ b/src/ert/analysis/_es_update.py @@ -585,21 +585,12 @@ def analysis_ES( num_obs = len(observation_values) - inversion_types = {0: "exact", 1: "subspace", 2: "subspace", 3: "subspace"} - try: - inversion_type = inversion_types[module.ies_inversion] - except KeyError as e: - raise ErtAnalysisError( - f"Mismatched inversion type for: " - f"Specified: {module.ies_inversion}, with possible: {inversion_types}" - ) from e - smoother_es = ies.ESMDA( covariance=observation_errors**2, observations=observation_values, alpha=1, # The user is responsible for scaling observation covariance (esmda usage) seed=rng, - inversion=inversion_type, + inversion=module.inversion.name, ) truncation = module.enkf_truncation @@ -731,7 +722,7 @@ def analysis_ES( observation_errors=observation_errors, observation_values=observation_values, truncation=truncation, - inversion_type=inversion_type, + inversion_type=module.inversion.name, progress_callback=progress_callback, rng=rng, ) @@ -761,21 +752,6 @@ def analysis_IES( # This is needed for the SIES library masking_of_initial_parameters = ens_mask[initial_mask] - # Map paper (current in ERT) inversion-types to SIES inversion-types - inversion_types = { - 0: "direct", - 1: "subspace_exact", - 2: "subspace_projected", - 3: "subspace_projected", - } - try: - inversion_type = inversion_types[analysis_config.ies_inversion] - except KeyError as e: - raise ErtAnalysisError( - f"Mismatched inversion type for: " - f"Specified: {analysis_config.ies_inversion}, with possible: {inversion_types}" - ) from e - # It is not the iterations relating to IES or ESMDA. # It is related to functionality for turning on/off groups of parameters and observations. for update_step in update_config: @@ -824,7 +800,7 @@ def analysis_IES( covariance=observation_errors**2, observations=observation_values, seed=rng, - inversion=inversion_type, + inversion=analysis_config.inversion.name, truncation=analysis_config.enkf_truncation, ) diff --git a/src/ert/config/analysis_config.py b/src/ert/config/analysis_config.py index 31494062627..78c27880b3e 100644 --- a/src/ert/config/analysis_config.py +++ b/src/ert/config/analysis_config.py @@ -60,8 +60,11 @@ def __init__( if var_name == "ENKF_FORCE_NCOMP": continue if var_name == "INVERSION": - value = str(inversion_str_map[value]) + value = inversion_str_map[value] var_name = "IES_INVERSION" + if var_name == "IES_INVERSION": + value = int(value) + var_name = "inversion" key = var_name.lower() options[module_name][key] = value try: diff --git a/src/ert/config/analysis_module.py b/src/ert/config/analysis_module.py index c7e50f541f2..a1911e9bf69 100644 --- a/src/ert/config/analysis_module.py +++ b/src/ert/config/analysis_module.py @@ -1,5 +1,6 @@ import logging import math +from enum import IntEnum from typing import TYPE_CHECKING, Optional, Type, TypedDict, Union from pydantic import BaseModel, Extra, Field @@ -23,14 +24,10 @@ class VariableInfo(TypedDict): DEFAULT_IES_MIN_STEPLENGTH = 0.30 DEFAULT_IES_DEC_STEPLENGTH = 2.50 DEFAULT_ENKF_TRUNCATION = 0.98 -DEFAULT_IES_INVERSION = 0 DEFAULT_LOCALIZATION = False class BaseSettings(BaseModel): - ies_inversion: Annotated[ - int, Field(ge=0, le=3, title="Inversion algorithm") - ] = DEFAULT_IES_INVERSION enkf_truncation: Annotated[ float, Field(gt=0.0, le=1.0, title="Singular value truncation"), @@ -41,7 +38,48 @@ class Config: validate_assignment = True +class InversionTypeES(IntEnum): + """ + The type of inversion used in the algorithm. Every inversion method + scales the variables. The options are: + + * `exact`: + Computes an exact inversion which uses a Cholesky factorization in the + case of symmetric, positive definite matrices. + * `subspace`: + This is an approximate solution. The approximation is that when + U, w, V.T = svd(D_delta) then we assume that U @ U.T = I. + """ + + exact = 0 + subspace = 1 + + +class InversionTypeIES(IntEnum): + """ + The type of inversion used in the algorithm. Every inversion method + scales the variables. The options are: + + * `direct`: + Solve Eqn (42) directly, which involves inverting a + matrix of shape (num_parameters, num_parameters). + * `subspace_exact` : + Solve Eqn (42) using Eqn (50), i.e., the Woodbury + lemma to invert a matrix of size (ensemble_size, ensemble_size). + * `subspace_projected` : + Solve Eqn (42) using Section 3.3, i.e., by projecting the covariance + onto S. This approach utilizes the truncation factor `truncation`. + """ + + direct = 0 + subspace_exact = 1 + subspace_projected = 2 + + class ESSettings(BaseSettings): + inversion: Annotated[ + InversionTypeES, Field(title="Inversion algorithm") + ] = InversionTypeES.exact localization: Annotated[bool, Field(title="Adaptive localization")] = False localization_correlation_threshold: Annotated[ Optional[float], @@ -69,6 +107,9 @@ class IESSettings(BaseSettings): """A good start is max steplength of 0.6, min steplength of 0.3, and decline of 2.5", A steplength of 1.0 and one iteration results in ES update""" + inversion: Annotated[ + InversionTypeIES, Field(title="Inversion algorithm") + ] = InversionTypeIES.subspace_exact ies_max_steplength: Annotated[ float, Field(ge=0.1, le=1.0, title="Gauss–Newton maximum steplength"), diff --git a/src/ert/gui/ertwidgets/analysismodulevariablespanel.py b/src/ert/gui/ertwidgets/analysismodulevariablespanel.py index c08d3087c60..fecbe3f0eff 100644 --- a/src/ert/gui/ertwidgets/analysismodulevariablespanel.py +++ b/src/ert/gui/ertwidgets/analysismodulevariablespanel.py @@ -3,14 +3,13 @@ from annotated_types import Ge, Gt, Le from qtpy.QtCore import Qt from qtpy.QtWidgets import ( - QButtonGroup, QCheckBox, + QComboBox, QDoubleSpinBox, QFormLayout, QFrame, QHBoxLayout, QLabel, - QRadioButton, QWidget, ) @@ -61,20 +60,17 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int): layout.addRow(self.create_horizontal_line()) layout.addRow(QLabel("Inversion Algorithm")) - bg = QButtonGroup(self) - for button_id, s in enumerate( - [ - "Exact inversion with diagonal R=I", - "Subspace inversion with exact R", - "Subspace inversion using R=EE'", - "Subspace inversion using E", - ], - start=0, - ): - b = QRadioButton(s, self) - b.setObjectName("IES_INVERSION_" + str(button_id)) - bg.addButton(b, button_id) - layout.addRow(b) + dropdown = QComboBox(self) + options = analysis_module.model_fields["inversion"].annotation + layout.addRow(QLabel(options.__doc__)) + default_index = 0 + for i, option in enumerate(options): + dropdown.addItem(option.name) + if analysis_module.inversion.name == option.name: + default_index = i + dropdown.setCurrentIndex(default_index) + dropdown.currentIndexChanged.connect(self.update_inversion_algorithm) + layout.addRow(dropdown) var_name = "enkf_truncation" metadata = analysis_module.model_fields[var_name] self.truncation_spinner = self.createDoubleSpinBox( @@ -87,9 +83,6 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int): self.truncation_spinner.setEnabled(False) layout.addRow("Singular value truncation", self.truncation_spinner) - bg.idClicked.connect(self.update_inversion_algorithm) - bg.buttons()[analysis_module.ies_inversion].click() # update the current value - if not isinstance(analysis_module, IESSettings): layout.addRow(self.create_horizontal_line()) layout.addRow(QLabel("[EXPERIMENTAL]")) @@ -139,7 +132,7 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int): def update_inversion_algorithm(self, button_id): self.truncation_spinner.setEnabled(button_id != 0) # not for exact inversion - self.analysis_module.ies_inversion = button_id + self.analysis_module.inversion = button_id def create_horizontal_line(self) -> QFrame: hline = QFrame() diff --git a/tests/unit_tests/analysis/test_es_update.py b/tests/unit_tests/analysis/test_es_update.py index e6da9b7900a..8050c9faa06 100644 --- a/tests/unit_tests/analysis/test_es_update.py +++ b/tests/unit_tests/analysis/test_es_update.py @@ -101,7 +101,7 @@ def test_update_report( ert_config.ensemble_config.parameters, ), UpdateSettings(misfit_preprocess=misfit_preprocess), - ESSettings(ies_inversion=1), + ESSettings(inversion=1), log_path=Path("update_log"), ) log_file = Path(ert_config.analysis_config.log_path) / "id.txt" @@ -237,7 +237,7 @@ def test_update_snapshot( run_id="id", update_config=update_configuration, update_settings=UpdateSettings(), - analysis_config=IESSettings(ies_inversion=1), + analysis_config=IESSettings(inversion=1), sies_step_length=sies_step_length, initial_mask=initial_mask, rng=rng, @@ -249,7 +249,7 @@ def test_update_snapshot( "id", update_configuration, UpdateSettings(), - ESSettings(ies_inversion=1), + ESSettings(inversion=1), rng=rng, ) @@ -357,7 +357,7 @@ def test_localization( "id", update_config, UpdateSettings(), - ESSettings(ies_inversion=1), + ESSettings(inversion=1), rng=np.random.default_rng(42), log_path=Path("update_log"), )