Save & load responses as parquet #8684

Merged 1 commit on Oct 7, 2024
1 change: 1 addition & 0 deletions pyproject.toml
@@ -53,6 +53,7 @@ dependencies = [
"packaging",
"pandas",
"pluggy>=1.3.0",
"polars",
"psutil",
"pyarrow", # extra dependency for pandas (parquet)
"pydantic > 2",
123 changes: 71 additions & 52 deletions src/ert/analysis/_es_update.py
@@ -18,7 +18,7 @@

import iterative_ensemble_smoother as ies
import numpy as np
import pandas as pd
import polars
import psutil
from iterative_ensemble_smoother.experimental import (
AdaptiveESMDA,
@@ -153,56 +153,75 @@ def _get_observations_and_responses(
observation_values = []
observation_errors = []
indexes = []
observations = ensemble.experiment.observations
for obs in selected_observations:
observation = observations[obs]
group = observation.attrs["response"]
all_responses = ensemble.load_responses(group, tuple(iens_active_index))
if "time" in observation.coords:
all_responses = all_responses.reindex(
time=observation.time,
method="nearest",
observations_by_type = ensemble.experiment.observations
for (
response_type,
response_cls,
) in ensemble.experiment.response_configuration.items():
if response_type not in observations_by_type:
continue

observations_for_type = observations_by_type[response_type].filter(
polars.col("observation_key").is_in(list(selected_observations))
)
responses_for_type = ensemble.load_responses(
response_type, realizations=tuple(iens_active_index)
)

# Note that if there are duplicate entries for one
# response at one index, they are aggregated together
# with "mean" by default
pivoted = responses_for_type.pivot(
on="realization",
index=["response_key", *response_cls.primary_key],
aggregate_function="mean",
[Collaborator] What is the implication of "mean"?

[Collaborator @oyvindeide, Oct 2, 2024] It said so in the comment 😅

[Collaborator] Will that be output somewhere? Is it possible to, for example, log it?

[Author @yngve-sk, Oct 2, 2024] It is for the edge case where we end up with duplicate values for one response at one index, for example at a given time. In that case we need to aggregate them for the pivoted table to make sense; otherwise the index used to pivot contains duplicates. Taking the average of the duplicate response values at that timestep seems "close enough" to what we want. We could set it to use min, max, median, first, etc., or make it configurable, but I am not sure that would be interesting to users.

Example from running test_that_duplicate_summary_time_steps_does_not_fail:

responses_for_type.pivot(
            on="realization",
            index=["response_key", *response_cls.primary_key],
            aggregate_function="mean",
        )
Out[9]: 
shape: (1, 5)
┌──────────────┬─────────────────────┬───────────┬────────┬──────────┐
│ response_key ┆ time                ┆ 0         ┆ 1      ┆ 2        │
│ ---          ┆ ---                 ┆ ---       ┆ ---    ┆ ---      │
│ str          ┆ datetime[ms]        ┆ f32       ┆ f32    ┆ f32      │
╞══════════════╪═════════════════════╪═══════════╪════════╪══════════╡
│ FOPR         ┆ 2014-09-10 00:00:00 ┆ -1.603837 ┆ 0.0641 ┆ 0.740891 │
└──────────────┴─────────────────────┴───────────┴────────┴──────────┘
responses_for_type
Out[10]: 
shape: (4, 4)
┌─────────────┬──────────────┬─────────────────────┬───────────┐
│ realization ┆ response_key ┆ time                ┆ values    │
│ ---         ┆ ---          ┆ ---                 ┆ ---       │
│ u16         ┆ str          ┆ datetime[ms]        ┆ f32       │
╞═════════════╪══════════════╪═════════════════════╪═══════════╡
│ 0           ┆ FOPR         ┆ 2014-09-10 00:00:00 ┆ -1.603837 │
│ 1           ┆ FOPR         ┆ 2014-09-10 00:00:00 ┆ 0.0641    │
│ 2           ┆ FOPR         ┆ 2014-09-10 00:00:00 ┆ 0.740891  │
│ 2           ┆ FOPR         ┆ 2014-09-10 00:00:00 ┆ 0.740891  │
└─────────────┴──────────────┴─────────────────────┴───────────┘

Alternatively we could strive to achieve something like this:

┌──────────────┬─────────────────────┬───────────┬────────┬──────────┐
│ response_key ┆ time                ┆ 0         ┆ 1      ┆ 2        │
│ ---          ┆ ---                 ┆ ---       ┆ ---    ┆ ---      │
│ str          ┆ datetime[ms]        ┆ f32       ┆ f32    ┆ f32      │
╞══════════════╪═════════════════════╪═══════════╪════════╪══════════╡
│ FOPR         ┆ 2014-09-10 00:00:00 ┆ -1.603837 ┆ 0.0641 ┆ 0.740891 │
│ FOPR         ┆ 2014-09-10 00:00:00 ┆    NaN    ┆  NaN   ┆ 0.740891 │
└──────────────┴─────────────────────┴───────────┴────────┴──────────┘

[Author @yngve-sk] It could be logged or given as a warning somehow. I'm not so familiar with when/why it happens, which may be relevant to what the warning/logging message should say.

[Author @yngve-sk] (Performance-wise it might be slow to always check whether some values were aggregated; a naive try-catch around the pivot would pass when there are no duplicate values.)

[Collaborator] If there is a good, somewhat performant way of warning the user that this has happened, that would be good. My hunch is that this would typically happen in pressure tests, where the time resolution is quite high and the simulator does not have the same resolution.

[Author @yngve-sk] Would it be OK to do this in a separate PR? I think the try-catch (first trying without an aggregation, then trying with one) should be easy to add and easy to remove if it turns out to have bad side effects. It should maybe be tested as its own thing, just to be sure.
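A minimal sketch of the try-first approach discussed above (illustrative, not part of this diff). It assumes a recent polars, where pivot with aggregate_function=None raises ComputeError on duplicate index entries; the helper and logger names are hypothetical.

import logging

import polars

logger = logging.getLogger(__name__)


def pivot_responses(
    responses_for_type: polars.DataFrame, primary_key: list[str]
) -> polars.DataFrame:
    index = ["response_key", *primary_key]
    try:
        # No aggregation: the common case pays nothing extra, and
        # duplicate index entries surface as an error
        return responses_for_type.pivot(
            on="realization", index=index, aggregate_function=None
        )
    except polars.exceptions.ComputeError:
        logger.warning(
            "Duplicate response values at the same index; "
            "aggregating duplicates with 'mean'"
        )
        return responses_for_type.pivot(
            on="realization", index=index, aggregate_function="mean"
        )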

)

# Note2reviewer:
# We need to either assume that if there is a time column
# we will approx-join that, or we could specify in response configs
# that there is a column that requires an approx "asof" join.
# Suggest we simplify and assume that there is always only
[Collaborator] Agree, if and when we add new response types where this might be relevant, we can add it then.

# one "time" column, which we will reindex towards the response dataset
# with a given resolution
if "time" in pivoted:
joined = observations_for_type.join_asof(
pivoted,
by=["response_key", *response_cls.primary_key],
on="time",
tolerance="1s",
)
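A toy illustration (hypothetical values, not part of the diff) of the asof join above: the observation time is 0.4 s after the response time, so the default backward strategy matches it within the 1 s tolerance.

import datetime

import polars

observations_for_type = polars.DataFrame(
    {
        "response_key": ["FOPR"],
        "observation_key": ["FOPR_OBS_1"],
        "time": [datetime.datetime(2014, 9, 10, 0, 0, 0, 400_000)],
        "observations": [0.1],
        "std": [0.05],
    }
).with_columns(polars.col("time").dt.cast_time_unit("ms"))

pivoted = polars.DataFrame(
    {
        "response_key": ["FOPR"],
        "time": [datetime.datetime(2014, 9, 10)],
        "0": [0.0641],  # realization 0, as produced by the pivot
    }
).with_columns(polars.col("time").dt.cast_time_unit("ms"))

# Both frames must be sorted on the "on" column for join_asof
joined = observations_for_type.sort("time").join_asof(
    pivoted.sort("time"), by="response_key", on="time", tolerance="1s"
)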
try:
observations_and_responses = observation.merge(all_responses, join="left")
except KeyError as e:
raise ErtAnalysisError(
f"Mismatched index for: "
f"Observation: {obs} attached to response: {group}"
) from e

observation_keys.append([obs] * observations_and_responses["observations"].size)

if group == "summary":
indexes.append(
[
np.datetime_as_string(e, unit="s")
for e in observations_and_responses["time"].data
]
)
else:
indexes.append(
[
f"{e[0]}, {e[1]}"
for e in zip(
list(observations_and_responses["report_step"].data)
* len(observations_and_responses["index"].data),
observations_and_responses["index"].data,
)
]
joined = observations_for_type.join(
pivoted,
how="left",
on=["response_key", *response_cls.primary_key],
)

observation_values.append(
observations_and_responses["observations"].data.ravel()
)
observation_errors.append(observations_and_responses["std"].data.ravel())
joined = joined.sort(by="observation_key")

index_1d = joined.with_columns(
polars.concat_str(response_cls.primary_key, separator=", ").alias("index")
)["index"].to_numpy()

obs_keys_1d = joined["observation_key"].to_numpy()
obs_values_1d = joined["observations"].to_numpy()
obs_errors_1d = joined["std"].to_numpy()

# 4 columns are always there:
# [ response_key, observation_key, observations, std ]
# + one column per "primary key" column
num_non_response_value_columns = 4 + len(response_cls.primary_key)
responses = joined.select(
joined.columns[num_non_response_value_columns:]
).to_numpy()

filtered_responses.append(responses)
observation_keys.append(obs_keys_1d)
observation_values.append(obs_values_1d)
observation_errors.append(obs_errors_1d)
indexes.append(index_1d)
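A small standalone illustration (hypothetical data) of how polars.concat_str builds the 1D index strings above; non-string columns are cast to strings before concatenation.

import polars

joined = polars.DataFrame({"report_step": [0, 0], "index": [1, 2]})
index_1d = joined.with_columns(
    polars.concat_str(["report_step", "index"], separator=", ").alias("index")
)["index"].to_numpy()
# array(['0, 1', '0, 2'], dtype=object)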

filtered_responses.append(
observations_and_responses["values"]
.transpose(..., "realization")
.values.reshape((-1, len(observations_and_responses.realization)))
)
ensemble.load_responses.cache_clear()
return (
np.concatenate(filtered_responses),
@@ -288,12 +307,14 @@ def _load_observations_and_responses(
scaling[obs_group_mask] *= scaling_factors

scaling_factors_dfs.append(
pd.DataFrame(
data={
polars.DataFrame(
{
"input_group": [", ".join(input_group)] * len(scaling_factors),
"index": indexes[obs_group_mask],
"obs_key": obs_keys[obs_group_mask],
"scaling_factor": scaling_factors,
"scaling_factor": polars.Series(
scaling_factors, dtype=polars.Float32
),
}
)
)
@@ -322,10 +343,8 @@ def _load_observations_and_responses(
)
)

scaling_factors_df = pd.concat(scaling_factors_dfs).set_index(
["input_group", "obs_key", "index"], verify_integrity=True
)
ensemble.save_observation_scaling_factors(scaling_factors_df.to_xarray())
scaling_factors_df = polars.concat(scaling_factors_dfs)
ensemble.save_observation_scaling_factors(scaling_factors_df)
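Note that the pandas version verified index uniqueness with verify_integrity=True, which polars.concat does not do. If that check is still wanted, a hedged sketch (a suggestion, not in the diff):

# Mirror pandas' verify_integrity=True on the former index columns
key_cols = ["input_group", "obs_key", "index"]
if scaling_factors_df.select(key_cols).is_duplicated().any():
    raise ValueError(f"Duplicate rows for {key_cols} in scaling factors")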

# Recompute with updated scales
scaled_errors = errors * scaling
26 changes: 24 additions & 2 deletions src/ert/config/ert_config.py
@@ -22,7 +22,7 @@
overload,
)

import xarray as xr
import polars
from pydantic import ValidationError as PydanticValidationError
from typing_extensions import Self

@@ -112,6 +112,28 @@ class ErtConfig:
Tuple[str, Union[HistoryValues, SummaryValues, GenObsValues]]
] = field(default_factory=list)

def __eq__(self, other: object) -> bool:
if not isinstance(other, ErtConfig):
return False

for attr in vars(self):
if attr == "observations":
if self.observations.keys() != other.observations.keys():
return False

if not all(
self.observations[k].equals(other.observations[k])
for k in self.observations
):
return False

continue

if getattr(self, attr) != getattr(other, attr):
return False

return True
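The special-casing above is needed because polars DataFrames do not compare down to a single boolean: == is elementwise and returns a frame of booleans, so the generic getattr comparison would not behave as intended. A quick standalone illustration:

import polars

a = polars.DataFrame({"x": [1, 2]})
b = polars.DataFrame({"x": [1, 2]})

print(a == b)       # elementwise: a 2x1 DataFrame of booleans
print(a.equals(b))  # True: whole-frame equality, as used in __eq__ above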

def __post_init__(self) -> None:
self.config_path = (
path.dirname(path.abspath(self.user_config_file))
@@ -120,7 +142,7 @@ def __post_init__(self) -> None:
)
self.enkf_obs: EnkfObs = self._create_observations(self.observation_config)

self.observations: Dict[str, xr.Dataset] = self.enkf_obs.datasets
self.observations: Dict[str, polars.DataFrame] = self.enkf_obs.datasets

@staticmethod
def with_plugins(
34 changes: 20 additions & 14 deletions src/ert/config/gen_data_config.py
@@ -5,7 +5,7 @@
from typing import List, Optional, Tuple

import numpy as np
import xarray as xr
import polars
from typing_extensions import Self

from ert.validation import rangestring_to_list
@@ -107,21 +107,23 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Optional[Self]:
report_steps_list=report_steps,
)

def read_from_file(self, run_path: str, _: int) -> xr.Dataset:
def _read_file(filename: Path, report_step: int) -> xr.Dataset:
def read_from_file(self, run_path: str, _: int) -> polars.DataFrame:
def _read_file(filename: Path, report_step: int) -> polars.DataFrame:
if not filename.exists():
raise ValueError(f"Missing output file: {filename}")
data = np.loadtxt(_run_path / filename, ndmin=1)
active_information_file = _run_path / (str(filename) + "_active")
if active_information_file.exists():
active_list = np.loadtxt(active_information_file)
data[active_list == 0] = np.nan
return xr.Dataset(
{"values": (["report_step", "index"], [data])},
coords={
"index": np.arange(len(data)),
"report_step": [report_step],
},
return polars.DataFrame(
{
"report_step": polars.Series(
[Collaborator] This made it much easier to read!

np.full(len(data), report_step), dtype=polars.UInt16
),
"index": polars.Series(np.arange(len(data)), dtype=polars.UInt16),
"values": polars.Series(data, dtype=polars.Float32),
}
)

errors = []
@@ -150,16 +152,16 @@ def _read_file(filename: Path, report_step: int) -> xr.Dataset:
except ValueError as err:
errors.append(str(err))

ds_all_report_steps = xr.concat(
datasets_per_report_step, dim="report_step"
).expand_dims(name=[name])
ds_all_report_steps = polars.concat(datasets_per_report_step)
ds_all_report_steps.insert_column(
0, polars.Series("response_key", [name] * len(ds_all_report_steps))
)
datasets_per_name.append(ds_all_report_steps)

if errors:
raise ValueError(f"Error reading GEN_DATA: {self.name}, errors: {errors}")

combined = xr.concat(datasets_per_name, dim="name")
combined.attrs["response"] = "gen_data"
combined = polars.concat(datasets_per_name)
return combined
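For reference, a hedged sketch (hypothetical response name and values) of the long-format frame _read_file now builds for a single report step, including the response_key column inserted afterwards:

import numpy as np
import polars

# Three values for report step 0; NaN marks an inactive index
data = np.array([1.0, 2.0, np.nan])
frame = polars.DataFrame(
    {
        "report_step": polars.Series(np.full(len(data), 0), dtype=polars.UInt16),
        "index": polars.Series(np.arange(len(data)), dtype=polars.UInt16),
        "values": polars.Series(data, dtype=polars.Float32),
    }
)
frame.insert_column(0, polars.Series("response_key", ["GEN_DATA_1"] * len(frame)))
# shape: (3, 4), columns: response_key, report_step, index, values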

def get_args_for_key(self, key: str) -> Tuple[Optional[str], Optional[List[int]]]:
@@ -173,5 +175,9 @@ def get_args_for_key(self, key: str) -> Tuple[Optional[str], Optional[List[int]]
def response_type(self) -> str:
return "gen_data"

@property
def primary_key(self) -> List[str]:
return ["report_step", "index"]


responses_index.add_response_type(GenDataConfig)
53 changes: 35 additions & 18 deletions src/ert/config/observation_vector.py
@@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Iterable, List, Union

import xarray as xr
import numpy as np

from .enkf_observation_implementation_type import EnkfObservationImplementationType
from .general_observation import GenObservation
@@ -12,6 +12,8 @@
if TYPE_CHECKING:
from datetime import datetime

import polars


@dataclass
class ObsVector:
@@ -27,28 +29,38 @@ def __iter__(self) -> Iterable[Union[SummaryObservation, GenObservation]]:
def __len__(self) -> int:
return len(self.observations)

def to_dataset(self, active_list: List[int]) -> xr.Dataset:
def to_dataset(self, active_list: List[int]) -> polars.DataFrame:
if self.observation_type == EnkfObservationImplementationType.GEN_OBS:
datasets = []
dataframes = []
for time_step, node in self.observations.items():
if active_list and time_step not in active_list:
continue

assert isinstance(node, GenObservation)
datasets.append(
xr.Dataset(
dataframes.append(
polars.DataFrame(
{
"observations": (["report_step", "index"], [node.values]),
"std": (["report_step", "index"], [node.stds]),
},
coords={"index": node.indices, "report_step": [time_step]},
"response_key": self.data_key,
"observation_key": self.observation_key,
"report_step": polars.Series(
np.full(len(node.indices), time_step),
dtype=polars.UInt16,
),
"index": polars.Series(node.indices, dtype=polars.UInt16),
"observations": polars.Series(
node.values, dtype=polars.Float32
),
"std": polars.Series(node.stds, dtype=polars.Float32),
}
)
)
combined = xr.combine_by_coords(datasets)
combined.attrs["response"] = self.data_key
return combined # type: ignore

combined = polars.concat(dataframes)
return combined
elif self.observation_type == EnkfObservationImplementationType.SUMMARY_OBS:
observations = []
actual_response_key = self.observation_key
actual_observation_keys = []
errors = []
dates = list(self.observations.keys())
if active_list:
@@ -57,15 +69,20 @@ def to_dataset(self, active_list: List[int]) -> xr.Dataset:
for time_step in dates:
n = self.observations[time_step]
assert isinstance(n, SummaryObservation)
actual_observation_keys.append(n.observation_key)
observations.append(n.value)
errors.append(n.std)
return xr.Dataset(

dates_series = polars.Series(dates).dt.cast_time_unit("ms")

return polars.DataFrame(
{
"observations": (["name", "time"], [observations]),
"std": (["name", "time"], [errors]),
},
coords={"time": dates, "name": [self.observation_key]},
attrs={"response": "summary"},
"response_key": actual_response_key,
"observation_key": actual_observation_keys,
"time": dates_series,
"observations": polars.Series(observations, dtype=polars.Float32),
"std": polars.Series(errors, dtype=polars.Float32),
}
)
else:
raise ValueError(f"Unknown observation type {self.observation_type}")
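For completeness, a hedged sketch (hypothetical keys and values) of the frame the SUMMARY_OBS branch builds, including the millisecond time-unit cast that keeps observation timestamps aligned with the response frames; the scalar response_key broadcasts across rows:

import datetime

import polars

dates = [datetime.datetime(2014, 9, 10), datetime.datetime(2014, 9, 11)]
frame = polars.DataFrame(
    {
        "response_key": "FOPR",  # scalar, broadcast to every row
        "observation_key": ["FOPR_OBS_1", "FOPR_OBS_2"],
        # polars stores these datetimes as microseconds by default; cast to ms
        "time": polars.Series(dates).dt.cast_time_unit("ms"),
        "observations": polars.Series([0.1, 0.2], dtype=polars.Float32),
        "std": polars.Series([0.05, 0.05], dtype=polars.Float32),
    }
)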