Skip to content

Commit

Permalink
Replace configsuite with pydantic in interp_relperm
Browse files Browse the repository at this point in the history
Remove configsuite from subscripts dependencies as
interp_relperm was the last user of it.
  • Loading branch information
berland committed Jan 9, 2024
1 parent 549b6bf commit a54897e
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 256 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"configsuite",
"resdata",
"res2df",
"ert>=2.38.0b7",
Expand Down
228 changes: 58 additions & 170 deletions src/subscript/interp_relperm/interp_relperm.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations

import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional

import configsuite # lgtm [py/import-and-import-from]
import pandas as pd
import pyscal
import yaml
from configsuite import MetaKeys as MK # lgtm [py/import-and-import-from]
from configsuite import types # lgtm [py/import-and-import-from]
from pydantic import BaseModel, Field, FilePath, model_validator
from res2df import satfunc
from typing_extensions import Annotated

import subscript

Expand Down Expand Up @@ -101,146 +102,43 @@
""" # noqa


@configsuite.validator_msg("Valid file name")
def _is_filename(filename: str):
return Path(filename).exists()


@configsuite.validator_msg("Valid interpolator list")
def _is_valid_interpolator_list(interpolators: list):
if len(interpolators) > 0:
return True
return False


@configsuite.validator_msg("Valid interpolator")
def _is_valid_interpolator(interp: dict):
valid = False

try:
if interp["param_w"]:
valid = True
elif interp["param_w"] == 0:
valid = True

except (KeyError, ValueError, TypeError):
pass

try:
if interp["param_w"] > 1.0 or interp["param_w"] < -1.0:
valid = False
except (KeyError, ValueError, TypeError):
pass

try:
if interp["param_g"]:
valid = True
elif interp["param_g"] == 0:
valid = True
except (KeyError, ValueError, TypeError):
pass

try:
if interp["param_g"] > 1.0 or interp["param_g"] < -1.0:
valid = False
except (KeyError, ValueError, TypeError):
pass

return valid


@configsuite.validator_msg("Low, base and high are provided")
def _is_valid_table_entries(schema: dict):
if "base" in schema and "low" in schema and "high" in schema:
if schema["low"] and schema["base"] and schema["high"]:
return (
isinstance(schema["low"], tuple)
and isinstance(schema["base"], tuple)
and isinstance(schema["high"], tuple)
)
if "pyscalfile" in schema:
# If pyscalfile is given, we don't need low/base/high
return True
return False


@configsuite.validator_msg("Valid Eclipse keyword family")
def _is_valid_eclipse_keyword_family(familychoice: int):
return familychoice in [1, 2]


def get_cfg_schema() -> dict:
"""
Defines the yml config schema
"""
schema = {
MK.Type: types.NamedDict,
MK.ElementValidators: (_is_valid_table_entries,),
MK.Content: {
"base": {
MK.Type: types.List,
MK.Content: {
MK.Item: {
MK.Type: types.String,
MK.ElementValidators: (_is_filename,),
}
},
},
"low": {
MK.Type: types.List,
MK.Content: {
MK.Item: {
MK.Type: types.String,
MK.ElementValidators: (_is_filename,),
}
},
},
"high": {
MK.Type: types.List,
MK.Content: {
MK.Item: {
MK.Type: types.String,
MK.ElementValidators: (_is_filename,),
}
},
},
"pyscalfile": {
MK.Type: types.String,
MK.ElementValidators: (_is_filename,),
MK.AllowNone: True,
},
"result_file": {MK.Type: types.String},
"family": {
MK.Type: types.Number,
MK.Default: 1,
MK.ElementValidators: (_is_valid_eclipse_keyword_family,),
},
"delta_s": {MK.Type: types.Number, MK.Default: 0.01},
"interpolations": {
MK.Type: types.List,
MK.ElementValidators: (_is_valid_interpolator_list,),
MK.Content: {
MK.Item: {
MK.Type: types.NamedDict,
MK.ElementValidators: (_is_valid_interpolator,),
MK.Content: {
"tables": {
MK.Type: types.List,
MK.Content: {MK.Item: {MK.Type: types.Integer}},
},
"param_w": {MK.Type: types.Number, MK.AllowNone: True},
"param_g": {MK.Type: types.Number, MK.AllowNone: True},
},
}
},
},
},
}

return schema


def parse_satfunc_files(filenames: List[str]) -> pd.DataFrame:
class Interpolator(BaseModel):
tables: Optional[List[int]] = []
param_w: Optional[Annotated[float, Field(strict=True, ge=-1, le=1)]] = None
param_g: Optional[Annotated[float, Field(strict=True, ge=-1, le=1)]] = None

@model_validator(mode="after")
def check_param_w_or_param_g(self) -> Interpolator:
assert (
self.param_w is not None or self.param_g is not None
), "Provide either param_w or param_g"
return self


class InterpRelpermConfig(BaseModel):
low: Optional[List[FilePath]] = None
base: Optional[List[FilePath]] = None
high: Optional[List[FilePath]] = None
pyscalfile: Optional[FilePath] = None
result_file: str
family: Literal[1, 2] = 1
delta_s: float = 0.01
interpolations: Annotated[List[Interpolator], Field(..., min_length=1)]

@model_validator(mode="after")
def check_lowbasehigh_or_pyscalfile(self) -> InterpRelpermConfig:
if self.pyscalfile is None:
assert self.base is not None, "base is not provided"
assert self.high is not None, "high is not provided"
assert self.low is not None, "low is not provided"
else:
assert self.base is None, "do not specify base when pyscalfile is set"
assert self.high is None, "do not specify high when pyscalfile is set"
assert self.low is None, "do not specify low when pyscalfile is set"
return self


def parse_satfunc_files(filenames: List[Path]) -> pd.DataFrame:
"""
Routine to gather scal tables (SWOF and SGOF) from ecl include files.
Expand All @@ -251,7 +149,6 @@ def parse_satfunc_files(filenames: List[str]) -> pd.DataFrame:
Returns:
dataframe with the tables
"""

return pd.concat(
[
satfunc.df(Path(filename).read_text(encoding="utf8"))
Expand Down Expand Up @@ -457,35 +354,23 @@ def process_config(cfg: Dict[str, Any], root_path: Optional[Path] = None) -> Non
if root_path is not None:
cfg = prepend_root_path_to_relative_files(cfg, root_path)

cfg_schema = get_cfg_schema()
cfg_suite = configsuite.ConfigSuite(cfg, cfg_schema, deduce_required=True)

if not cfg_suite.valid:
logger.error("Sorry, the configuration is invalid.")
sys.exit(cfg_suite.errors)

# set default values
relperm_delta_s = False
if cfg_suite.snapshot.delta_s:
relperm_delta_s = cfg_suite.snapshot.delta_s
config = InterpRelpermConfig(**cfg)

base_df: pd.DataFrame = pd.DataFrame()
low_df: pd.DataFrame = pd.DataFrame()
high_df: pd.DataFrame = pd.DataFrame()

if cfg_suite.snapshot.pyscalfile is not None:
if cfg_suite.snapshot.base or cfg_suite.snapshot.low or cfg_suite.snapshot.high:
if config.pyscalfile is not None:
if config.base or config.low or config.high:
logger.error(
"Inconsistent configuration. "
"You cannot define both pyscalfile and base/low/high"
)
sys.exit(1)

logger.info(
"Loading relperm parametrization from %s", cfg_suite.snapshot.pyscalfile
)
logger.info("Loading relperm parametrization from %s", config.pyscalfile)
param_dframe = pyscal.PyscalFactory.load_relperm_df(
cfg_suite.snapshot.pyscalfile
config.pyscalfile
).set_index("CASE")
base_df = (
pyscal.PyscalFactory.create_pyscal_list(param_dframe.loc["base"])
Expand All @@ -504,9 +389,12 @@ def process_config(cfg: Dict[str, Any], root_path: Optional[Path] = None) -> Non
)
else:
# Parse tables from files
base_df = parse_satfunc_files(cfg_suite.snapshot.base)
low_df = parse_satfunc_files(cfg_suite.snapshot.low)
high_df = parse_satfunc_files(cfg_suite.snapshot.high)
assert config.base is not None
base_df = parse_satfunc_files(config.base)
assert config.low is not None
low_df = parse_satfunc_files(config.low)
assert config.high is not None
high_df = parse_satfunc_files(config.high)

if not (
set(base_df.columns) == set(low_df.columns)
Expand All @@ -523,7 +411,7 @@ def process_config(cfg: Dict[str, Any], root_path: Optional[Path] = None) -> Non
satnums = range(1, base_df.reset_index("SATNUM")["SATNUM"].unique().max() + 1)
for satnum in satnums:
interp_values = {"param_w": 0.0, "param_g": 0.0}
for interp in cfg_suite.snapshot.interpolations:
for interp in config.interpolations:
if not interp.tables or satnum in interp.tables:
if interp.param_w:
interp_values["param_w"] = interp.param_w
Expand All @@ -537,17 +425,17 @@ def process_config(cfg: Dict[str, Any], root_path: Optional[Path] = None) -> Non
high_df.loc[satnum],
interp_values,
satnum,
relperm_delta_s,
config.delta_s,
)
)

Path(cfg_suite.snapshot.result_file).write_text(
interpolants.build_eclipse_data(cfg_suite.snapshot.family), encoding="utf-8"
Path(config.result_file).write_text(
interpolants.build_eclipse_data(config.family), encoding="utf-8"
)

logger.info(
"Done; interpolated relperm curves written to file: %s",
str(cfg_suite.snapshot.result_file),
str(config.result_file),
)


Expand Down
Loading

0 comments on commit a54897e

Please sign in to comment.