Skip to content

Commit

Permalink
Add step plot for rate vectors
Browse files Browse the repository at this point in the history
This adds a key util function is_rate which determines whether the key
represents a rate or not (Source: resdata).
Additionally, adding tests that validate that the is_rate gives
consinsent results with Summary.is_rate function.
  • Loading branch information
xjules committed Sep 11, 2024
1 parent 22e95fc commit 2353822
Show file tree
Hide file tree
Showing 3 changed files with 268 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/ert/gui/plottery/plots/ensemble.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, Optional

import numpy as np
import pandas as pd

from ert.gui.plottery.plots.history import plotHistory
from ert.gui.tools.plot.plot_api import EnsembleObject
from ert.shared.storage.summary_key_utils import is_rate

from .observations import plotObservations
from .plot_tools import PlotTools
Expand Down Expand Up @@ -45,11 +46,13 @@ def plot(
plot_context.deactivateDateSupport()
plot_context.x_axis = plot_context.INDEX_AXIS

draw_style = "steps-pre" if is_rate(plot_context.key()) else None
self._plotLines(
axes,
config,
data,
f"{ensemble.experiment_name} : {ensemble.name}",
draw_style,
)
config.nextColor()

Expand All @@ -71,6 +74,7 @@ def _plotLines(
plot_config: PlotConfig,
data: pd.DataFrame,
ensemble_label: str,
draw_style: Optional[str] = None,
) -> None:
style = plot_config.defaultStyle()

Expand All @@ -86,6 +90,7 @@ def _plotLines(
linewidth=style.width,
linestyle=style.line_style,
markersize=style.size,
drawstyle=draw_style,
)

if len(lines) > 0:
Expand Down
200 changes: 200 additions & 0 deletions src/ert/shared/storage/summary_key_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
from enum import Enum, auto
from typing import List

special_keys = [
"NAIMFRAC",
"NBAKFL",
"NBYTOT",
"NCPRLINS",
"NEWTFL",
"NEWTON",
"NLINEARP",
"NLINEARS",
"NLINSMAX",
"NLINSMIN",
"NLRESMAX",
"NLRESSUM",
"NMESSAGE",
"NNUMFL",
"NNUMST",
"NTS",
"NTSECL",
"NTSMCL",
"NTSPCL",
"ELAPSED",
"MAXDPR",
"MAXDSO",
"MAXDSG",
"MAXDSW",
"STEPTYPE",
"WNEWTON",
]
rate_keys = [
"OPR",
"OIR",
"OVPR",
"OVIR",
"OFR",
"OPP",
"OPI",
"OMR",
"GPR",
"GIR",
"GVPR",
"GVIR",
"GFR",
"GPP",
"GPI",
"GMR",
"WGPR",
"WGIR",
"WPR",
"WIR",
"WVPR",
"WVIR",
"WFR",
"WPP",
"WPI",
"WMR",
"LPR",
"LFR",
"VPR",
"VIR",
"VFR",
"GLIR",
"RGR",
"EGR",
"EXGR",
"SGR",
"GSR",
"FGR",
"GIMR",
"GCR",
"NPR",
"NIR",
"CPR",
"CIR",
"SIR",
"SPR",
"TIR",
"TPR",
"GOR",
"WCT",
"OGR",
"WGR",
"GLR",
]

seg_rate_keys = [
"OFR",
"GFR",
"WFR",
"CFR",
"SFR",
"TFR",
"CVPR",
"WCT",
"GOR",
"OGR",
"WGR",
]


class SummaryKeyType(Enum):
INVALID = auto()
FIELD = auto()
REGION = auto()
GROUP = auto()
WELL = auto()
SEGMENT = auto()
BLOCK = auto()
AQUIFER = auto()
COMPLETION = auto()
NETWORK = auto()
REGION_2_REGION = auto()
LOCAL_BLOCK = auto()
LOCAL_COMPLETION = auto()
LOCAL_WELL = auto()
MISC = auto()

@staticmethod
def determine_key_type(key: str) -> "SummaryKeyType":
if key in special_keys:
return SummaryKeyType.MISC

if key.startswith("L"):
secondary = key[1] if len(key) > 1 else ""
return {
"B": SummaryKeyType.LOCAL_BLOCK,
"C": SummaryKeyType.LOCAL_COMPLETION,
"W": SummaryKeyType.LOCAL_WELL,
}.get(secondary, SummaryKeyType.MISC)

if key.startswith("R"):
if len(key) == 3 and key[2] == "F":
return SummaryKeyType.REGION_2_REGION
if key == "RNLF":
return SummaryKeyType.REGION_2_REGION
if key == "RORFR":
return SummaryKeyType.REGION
if len(key) >= 4 and key[2] == "F" and key[3] in {"T", "R"}:
return SummaryKeyType.REGION_2_REGION
if len(key) >= 5 and key[3] == "F" and key[4] in {"T", "R"}:
return SummaryKeyType.REGION_2_REGION
return SummaryKeyType.REGION

# default cases or miscellaneous if not matched
return {
"A": SummaryKeyType.AQUIFER,
"B": SummaryKeyType.BLOCK,
"C": SummaryKeyType.COMPLETION,
"F": SummaryKeyType.FIELD,
"G": SummaryKeyType.GROUP,
"N": SummaryKeyType.NETWORK,
"S": SummaryKeyType.SEGMENT,
"W": SummaryKeyType.WELL,
}.get(key[0], SummaryKeyType.MISC)


def match_keyword_vector(start: int, rate_keys: List[str], keyword: str) -> bool:
if len(keyword) < start:
return False
return any(keyword[start:].startswith(key) for key in rate_keys)


def match_keyword_string(start: int, rate_string: str, keyword: str) -> bool:
if len(keyword) < start:
return False
return keyword[start:].startswith(rate_string)


def is_rate(key: str) -> bool:
key_type = SummaryKeyType.determine_key_type(key)
if key_type in {
SummaryKeyType.WELL,
SummaryKeyType.GROUP,
SummaryKeyType.FIELD,
SummaryKeyType.REGION,
SummaryKeyType.COMPLETION,
SummaryKeyType.LOCAL_WELL,
SummaryKeyType.LOCAL_COMPLETION,
SummaryKeyType.NETWORK,
}:
if key_type in {
SummaryKeyType.LOCAL_WELL,
SummaryKeyType.LOCAL_COMPLETION,
SummaryKeyType.NETWORK,
}:
return match_keyword_vector(2, rate_keys, key)
return match_keyword_vector(1, rate_keys, key)

if key_type == SummaryKeyType.SEGMENT:
return match_keyword_vector(1, seg_rate_keys, key)

if key_type == SummaryKeyType.REGION_2_REGION:
# Region to region rates are identified by R*FR or R**FR
if match_keyword_string(2, "FR", key):
return True
return match_keyword_string(3, "FR", key)

return False
62 changes: 62 additions & 0 deletions tests/unit_tests/shared/test_rate_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import hypothesis.strategies as st
import pytest
from hypothesis import given
from resdata.summary import Summary

from ert.shared.storage.summary_key_utils import is_rate
from tests.unit_tests.config.summary_generator import summary_variables


def nonempty_string_without_whitespace():
return st.text(
st.characters(whitelist_categories=("Lu", "Ll", "Nd", "P")), min_size=1
)


@given(key=nonempty_string_without_whitespace())
def test_is_rate_does_not_raise_error(key):
is_rate_bool = is_rate(key)
assert isinstance(is_rate_bool, bool)


examples = [
("OPR", False),
("WOPR:OP_4", True),
("WGIR", True),
("FOPT", False),
("GGPT", False),
("RWPT", False),
("COPR", True),
("LPR", False),
("LWPR", False),
("LCOPR", True),
("RWGIR", True),
("RTPR", True),
("RXFR", True),
("XXX", False),
("YYYY", False),
("ZZT", False),
("SGPR", False),
("AAPR", False),
("JOPR", False),
("ROPRT", True),
("RNFT", False),
("RFR", False),
("RRFRT", True),
("ROC", False),
("BPR:123", False),
("FWIR", True),
]


@pytest.mark.parametrize("key, rate", examples)
def test_is_rate_determines_rate_key_correctly(key, rate):
is_rate_bool = is_rate(key)
assert is_rate_bool == rate


@given(key=summary_variables())
def test_rate_determination_is_consistent(key):
# Here we verify that the determination of rate keys is the same
# as provided by resdata api
assert Summary.is_rate(key) == is_rate(key)

0 comments on commit 2353822

Please sign in to comment.