Skip to content

Commit

Permalink
Allow setting of array in highlevel.Message.set (#109)
Browse files Browse the repository at this point in the history
* Set array
* Tidy value checking
* Test set_array and set
  • Loading branch information
jinmannwong authored Nov 25, 2024
1 parent 6e01ef6 commit 3da45c0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
20 changes: 10 additions & 10 deletions eccodes/highlevel/message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import io
from contextlib import contextmanager

import numpy as np

import eccodes

_TYPES_MAP = {
Expand Down Expand Up @@ -95,23 +97,21 @@ def set(self, *args, check_values: bool = True):

for name, value in key_values.items():
with raise_keyerror(name):
eccodes.codes_set(self._handle, name, value)
if np.ndim(value) > 0:
eccodes.codes_set_array(self._handle, name, value)
else:
eccodes.codes_set(self._handle, name, value)

if check_values:
# Check values just set
for name, value in key_values.items():
cast_value = value
if isinstance(value, str):
saved_value = eccodes.codes_get_string(self._handle, name)
elif isinstance(value, int):
saved_value = eccodes.codes_get_long(self._handle, name)
if type(value) in _TYPES_MAP.values():
saved_value = self.get(f"{name}:{type(value).__name__}")
else:
saved_value = self.get(name)
if not isinstance(value, type(saved_value)):
cast_value = type(saved_value)(value)
if saved_value != cast_value:
if not np.all(saved_value == value):
raise ValueError(
f"Unexpected retrieved value {saved_value} for key {name}. Expected {cast_value}"
f"Unexpected retrieved value {saved_value} for key {name}. Expected {value}"
)

def get_array(self, name):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def test_message_set_plain():
message.set("centre", "ecmf")
vals = np.arange(message.get("numberOfValues"), dtype=np.float32)
message.set_array("values", vals)
assert np.all(message.get("values") == vals)
message.set("values", vals)
message.set_missing(missing_key)
assert message.get("centre") == "ecmf"
assert np.all(message.get("values") == vals)
Expand Down

0 comments on commit 3da45c0

Please sign in to comment.