Skip to content

Commit

Permalink
Maybe better?
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Sep 4, 2024
1 parent e89c854 commit 05ec1d9
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 12 deletions.
7 changes: 5 additions & 2 deletions src/ert/config/ensemble_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ class Refcase:
values: List[Any]

def __post_init__(self):

Check failure on line 60 in src/ert/config/ensemble_config.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Function is missing a return type annotation
self.values = self.values.tolist()
if hasattr(self.values, "tolist"):
self.values = self.values.tolist()

def __eq__(self, other: object) -> bool:
if not isinstance(other, Refcase):
Expand All @@ -78,7 +79,9 @@ def all_dates(self) -> List[datetime]:
@dataclass
class EnsembleConfig:
grid_file: Optional[str] = None
response_configs: Dict[str, ResponseConfig] = field(default_factory=dict)
response_configs: Dict[str, Union[SummaryConfig, GenDataConfig]] = field(
default_factory=dict
)
parameter_configs: Dict[str, ParameterConfig] = field(default_factory=dict)
refcase: Optional[Refcase] = None
eclbase: Optional[str] = None
Expand Down
6 changes: 2 additions & 4 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ class ErtConfig:
DEFAULT_RUNPATH_FILE: ClassVar[str] = ".ert_runpath_list"
PREINSTALLED_FORWARD_MODEL_STEPS: ClassVar[Dict[str, ForwardModelStep]] = {}

substitution_list: Union[SubstitutionList, Dict[str, str]] = field(
default_factory=dict
)
substitution_list: SubstitutionList = field(default_factory=SubstitutionList)
ensemble_config: EnsembleConfig = field(default_factory=EnsembleConfig)
ens_path: str = DEFAULT_ENSPATH
env_vars: Dict[str, str] = field(default_factory=dict)
Expand All @@ -112,7 +110,7 @@ class ErtConfig:
Tuple[str, Union[HistoryValues, SummaryValues, GenObsValues]]
] = field(default_factory=list)

@field_validator("substitution_list")
@field_validator("substitution_list", mode="before")
@classmethod
def convert_to_substitution_list(cls, v: Dict[str, str]) -> SubstitutionList:
if isinstance(v, SubstitutionList):
Expand Down
42 changes: 37 additions & 5 deletions src/ert/substitution_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import re
from typing import TYPE_CHECKING, Any, Optional

from pydantic import GetCoreSchemaHandler
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema
from pydantic_core.core_schema import CoreSchema

logger = logging.getLogger(__name__)
_PATTERN = re.compile("<[^<>]+>")
Expand Down Expand Up @@ -82,9 +82,41 @@ def __str__(self) -> str:

@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_after_validator_function(cls, handler(str))
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
def _serialize(instance: Any, info: Any) -> Any:
# if info.mode == 'json':
# # return json.dumps(dict(instance))
# return dict(instance)
return dict(instance)

from_str_schema = core_schema.chain_schema(
[
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(cls),
]
)

return core_schema.json_or_python_schema(
json_schema=from_str_schema,
python_schema=core_schema.union_schema(
[
from_str_schema,
core_schema.is_instance_schema(cls),
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
_serialize, info_arg=True
),
)

@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return handler(core_schema.str_schema())


def _replace_strings(subst_list: SubstitutionList, string: str) -> Optional[str]:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/config/test_ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def test_that_ert_config_is_serializable(tmp_path_factory, config_generator):
config_values.to_config_dict("config.ert", os.getcwd())
)
config_json = json.loads(RootModel[ErtConfig](ert_config).model_dump_json())
from_json = ErtConfig(config_json)
from_json = ErtConfig(**config_json)
assert from_json == ert_config


Expand Down

0 comments on commit 05ec1d9

Please sign in to comment.