Skip to content

Commit

Permalink
Move inversion types to analysis module
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Jan 15, 2024
1 parent 0f93287 commit 8815e8a
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 65 deletions.
30 changes: 3 additions & 27 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,21 +585,12 @@ def analysis_ES(

num_obs = len(observation_values)

inversion_types = {0: "exact", 1: "subspace", 2: "subspace", 3: "subspace"}
try:
inversion_type = inversion_types[module.ies_inversion]
except KeyError as e:
raise ErtAnalysisError(
f"Mismatched inversion type for: "
f"Specified: {module.ies_inversion}, with possible: {inversion_types}"
) from e

smoother_es = ies.ESMDA(
covariance=observation_errors**2,
observations=observation_values,
alpha=1, # The user is responsible for scaling observation covariance (esmda usage)
seed=rng,
inversion=inversion_type,
inversion=module.inversion.name,
)
truncation = module.enkf_truncation

Expand Down Expand Up @@ -731,7 +722,7 @@ def analysis_ES(
observation_errors=observation_errors,
observation_values=observation_values,
truncation=truncation,
inversion_type=inversion_type,
inversion_type=module.inversion.name,
progress_callback=progress_callback,
rng=rng,
)
Expand Down Expand Up @@ -761,21 +752,6 @@ def analysis_IES(
# This is needed for the SIES library
masking_of_initial_parameters = ens_mask[initial_mask]

# Map paper (current in ERT) inversion-types to SIES inversion-types
inversion_types = {
0: "direct",
1: "subspace_exact",
2: "subspace_projected",
3: "subspace_projected",
}
try:
inversion_type = inversion_types[analysis_config.ies_inversion]
except KeyError as e:
raise ErtAnalysisError(
f"Mismatched inversion type for: "
f"Specified: {analysis_config.ies_inversion}, with possible: {inversion_types}"
) from e

# It is not the iterations relating to IES or ESMDA.
# It is related to functionality for turning on/off groups of parameters and observations.
for update_step in update_config:
Expand Down Expand Up @@ -824,7 +800,7 @@ def analysis_IES(
covariance=observation_errors**2,
observations=observation_values,
seed=rng,
inversion=inversion_type,
inversion=analysis_config.inversion.name,
truncation=analysis_config.enkf_truncation,
)

Expand Down
5 changes: 4 additions & 1 deletion src/ert/config/analysis_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,11 @@ def __init__(
if var_name == "ENKF_FORCE_NCOMP":
continue
if var_name == "INVERSION":
value = str(inversion_str_map[value])
value = inversion_str_map[value]
var_name = "IES_INVERSION"
if var_name == "IES_INVERSION":
value = int(value)
var_name = "inversion"
key = var_name.lower()
options[module_name][key] = value
try:
Expand Down
49 changes: 45 additions & 4 deletions src/ert/config/analysis_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import math
from enum import IntEnum
from typing import TYPE_CHECKING, Optional, Type, TypedDict, Union

from pydantic import BaseModel, Extra, Field
Expand All @@ -23,14 +24,10 @@ class VariableInfo(TypedDict):
DEFAULT_IES_MIN_STEPLENGTH = 0.30
DEFAULT_IES_DEC_STEPLENGTH = 2.50
DEFAULT_ENKF_TRUNCATION = 0.98
DEFAULT_IES_INVERSION = 0
DEFAULT_LOCALIZATION = False


class BaseSettings(BaseModel):
ies_inversion: Annotated[
int, Field(ge=0, le=3, title="Inversion algorithm")
] = DEFAULT_IES_INVERSION
enkf_truncation: Annotated[
float,
Field(gt=0.0, le=1.0, title="Singular value truncation"),
Expand All @@ -41,7 +38,48 @@ class Config:
validate_assignment = True


class InversionTypeES(IntEnum):
"""
The type of inversion used in the algorithm. Every inversion method
scales the variables. The options are:
* `exact`:
Computes an exact inversion which uses a Cholesky factorization in the
case of symmetric, positive definite matrices.
* `subspace`:
This is an approximate solution. The approximation is that when
U, w, V.T = svd(D_delta) then we assume that U @ U.T = I.
"""

exact = 0
subspace = 1


class InversionTypeIES(IntEnum):
"""
The type of inversion used in the algorithm. Every inversion method
scales the variables. The options are:
* `direct`:
Solve Eqn (42) directly, which involves inverting a
matrix of shape (num_parameters, num_parameters).
* `subspace_exact` :
Solve Eqn (42) using Eqn (50), i.e., the Woodbury
lemma to invert a matrix of size (ensemble_size, ensemble_size).
* `subspace_projected` :
Solve Eqn (42) using Section 3.3, i.e., by projecting the covariance
onto S. This approach utilizes the truncation factor `truncation`.
"""

direct = 0
subspace_exact = 1
subspace_projected = 2


class ESSettings(BaseSettings):
inversion: Annotated[
InversionTypeES, Field(title="Inversion algorithm")
] = InversionTypeES.exact
localization: Annotated[bool, Field(title="Adaptive localization")] = False
localization_correlation_threshold: Annotated[
Optional[float],
Expand Down Expand Up @@ -69,6 +107,9 @@ class IESSettings(BaseSettings):
"""A good start is max steplength of 0.6, min steplength of 0.3, and decline of 2.5",
A steplength of 1.0 and one iteration results in ES update"""

inversion: Annotated[
InversionTypeIES, Field(title="Inversion algorithm")
] = InversionTypeIES.subspace_exact
ies_max_steplength: Annotated[
float,
Field(ge=0.1, le=1.0, title="Gauss–Newton maximum steplength"),
Expand Down
33 changes: 13 additions & 20 deletions src/ert/gui/ertwidgets/analysismodulevariablespanel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
from annotated_types import Ge, Gt, Le
from qtpy.QtCore import Qt
from qtpy.QtWidgets import (
QButtonGroup,
QCheckBox,
QComboBox,
QDoubleSpinBox,
QFormLayout,
QFrame,
QHBoxLayout,
QLabel,
QRadioButton,
QWidget,
)

Expand Down Expand Up @@ -61,20 +60,17 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int):
layout.addRow(self.create_horizontal_line())

layout.addRow(QLabel("Inversion Algorithm"))
bg = QButtonGroup(self)
for button_id, s in enumerate(
[
"Exact inversion with diagonal R=I",
"Subspace inversion with exact R",
"Subspace inversion using R=EE'",
"Subspace inversion using E",
],
start=0,
):
b = QRadioButton(s, self)
b.setObjectName("IES_INVERSION_" + str(button_id))
bg.addButton(b, button_id)
layout.addRow(b)
dropdown = QComboBox(self)
options = analysis_module.model_fields["inversion"].annotation
layout.addRow(QLabel(options.__doc__))
default_index = 0
for i, option in enumerate(options):
dropdown.addItem(option.name)
if analysis_module.inversion.name == option.name:
default_index = i
dropdown.setCurrentIndex(default_index)
dropdown.currentIndexChanged.connect(self.update_inversion_algorithm)
layout.addRow(dropdown)
var_name = "enkf_truncation"
metadata = analysis_module.model_fields[var_name]
self.truncation_spinner = self.createDoubleSpinBox(
Expand All @@ -87,9 +83,6 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int):
self.truncation_spinner.setEnabled(False)
layout.addRow("Singular value truncation", self.truncation_spinner)

bg.idClicked.connect(self.update_inversion_algorithm)
bg.buttons()[analysis_module.ies_inversion].click() # update the current value

if not isinstance(analysis_module, IESSettings):
layout.addRow(self.create_horizontal_line())
layout.addRow(QLabel("[EXPERIMENTAL]"))
Expand Down Expand Up @@ -139,7 +132,7 @@ def __init__(self, analysis_module: AnalysisModule, ensemble_size: int):

def update_inversion_algorithm(self, button_id):
self.truncation_spinner.setEnabled(button_id != 0) # not for exact inversion
self.analysis_module.ies_inversion = button_id
self.analysis_module.inversion = button_id

def create_horizontal_line(self) -> QFrame:
hline = QFrame()
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_update_report(
ert_config.ensemble_config.parameters,
),
UpdateSettings(misfit_preprocess=misfit_preprocess),
ESSettings(ies_inversion=1),
ESSettings(inversion=1),
log_path=Path("update_log"),
)
log_file = Path(ert_config.analysis_config.log_path) / "id.txt"
Expand Down Expand Up @@ -237,7 +237,7 @@ def test_update_snapshot(
run_id="id",
update_config=update_configuration,
update_settings=UpdateSettings(),
analysis_config=IESSettings(ies_inversion=1),
analysis_config=IESSettings(inversion=1),
sies_step_length=sies_step_length,
initial_mask=initial_mask,
rng=rng,
Expand All @@ -249,7 +249,7 @@ def test_update_snapshot(
"id",
update_configuration,
UpdateSettings(),
ESSettings(ies_inversion=1),
ESSettings(inversion=1),
rng=rng,
)

Expand Down Expand Up @@ -357,7 +357,7 @@ def test_localization(
"id",
update_config,
UpdateSettings(),
ESSettings(ies_inversion=1),
ESSettings(inversion=1),
rng=np.random.default_rng(42),
log_path=Path("update_log"),
)
Expand Down
17 changes: 8 additions & 9 deletions tests/unit_tests/gui/test_main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,17 +581,16 @@ def test_that_inversion_type_can_be_set_from_gui(qtbot, opened_main_window):
# https://github.com/pytest-dev/pytest-qt/issues/256
def handle_analysis_module_panel():
var_panel = wait_for_child(gui, qtbot, AnalysisModuleVariablesPanel)
rb0 = wait_for_child(var_panel, qtbot, QRadioButton, name="IES_INVERSION_0")
rb1 = wait_for_child(var_panel, qtbot, QRadioButton, name="IES_INVERSION_1")
rb2 = wait_for_child(var_panel, qtbot, QRadioButton, name="IES_INVERSION_2")
rb3 = wait_for_child(var_panel, qtbot, QRadioButton, name="IES_INVERSION_3")
dropdown = wait_for_child(var_panel, qtbot, QComboBox)
spinner = wait_for_child(var_panel, qtbot, QDoubleSpinBox, "enkf_truncation")

for b in [rb0, rb1, rb2, rb3, rb0]:
b.click()
assert b.isChecked()
assert [dropdown.itemText(i) for i in range(dropdown.count())] == [
"exact",
"subspace",
]
for i in range(dropdown.count()):
dropdown.setCurrentIndex(i)
# spinner should be enabled if not rb0 set
assert spinner.isEnabled() == (b != rb0)
assert spinner.isEnabled() == (i != 0)

var_panel.parent().close()

Expand Down

0 comments on commit 8815e8a

Please sign in to comment.