Skip to content

Commit

Permalink
add dimensions parameter to plugin interfaces; test plugin system (#364)
Browse files Browse the repository at this point in the history
  • Loading branch information
dionhaefner authored Aug 11, 2022
1 parent a0b0aa0 commit 7b8ef3e
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 19 deletions.
71 changes: 71 additions & 0 deletions test/plugin_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytest

from veros.plugins import load_plugin
from veros.routines import veros_routine
from veros.state import get_default_state
from veros.variables import Variable
from veros.settings import Setting


@pytest.fixture
def fake_plugin():
class FakePlugin:
pass

def run_setup(state):
plugin._setup_ran = True

def run_main(state):
plugin._main_ran = True

plugin = FakePlugin()
plugin.__name__ = "foobar"
plugin._setup_ran = False
plugin._main_ran = False
plugin.__VEROS_INTERFACE__ = {
"name": "foo",
"setup_entrypoint": run_setup,
"run_entrypoint": run_main,
"settings": dict(mydimsetting=Setting(15, int, "bar")),
"variables": dict(myvar=Variable("myvar", ("xt", "yt", "mydim"))),
"dimensions": dict(mydim="mydimsetting"),
"diagnostics": [],
}
yield plugin


def test_load_plugin(fake_plugin):
plugin_interface = load_plugin(fake_plugin)
assert plugin_interface.name == "foo"


def test_state_plugin(fake_plugin):
plugin_interface = load_plugin(fake_plugin)
state = get_default_state(plugin_interfaces=plugin_interface)
assert "mydimsetting" in state.settings
assert "mydim" in state.dimensions
assert state.dimensions["mydim"] == state.settings.mydimsetting
state.initialize_variables()
assert "myvar" in state.variables
assert state.variables.myvar.shape == (4, 4, state.settings.mydimsetting)


def test_run_plugin(fake_plugin):
from veros.setups.acc_basic import ACCBasicSetup

class FakeSetup(ACCBasicSetup):
__veros_plugins__ = (fake_plugin,)

@veros_routine
def set_diagnostics(self, state):
pass

setup = FakeSetup(override=dict(dt_tracer=100, runlen=100))

assert not fake_plugin._setup_ran
setup.setup()
assert fake_plugin._setup_ran

assert not fake_plugin._main_ran
setup.run()
assert fake_plugin._main_ran
4 changes: 4 additions & 0 deletions test/setup_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ def set_options():
from veros import runtime_settings

object.__setattr__(runtime_settings, "diskless_mode", True)
try:
yield
finally:
object.__setattr__(runtime_settings, "diskless_mode", False)


@pytest.mark.parametrize("float_type", ("float32", "float64"))
Expand Down
13 changes: 10 additions & 3 deletions veros/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"run_entrypoint",
"settings",
"variables",
"dimensions",
"diagnostics",
],
)
Expand All @@ -37,18 +38,23 @@ def load_plugin(module):
if not callable(run_entrypoint):
raise RuntimeError(f"module {modname} is missing a valid run entrypoint")

name = interface.get("name", module.__name__)
name = interface.get("name", modname)

settings = interface.get("settings", [])
settings = interface.get("settings", {})
for setting, val in settings.items():
if not isinstance(val, Setting):
raise TypeError(f"got unexpected type {type(val)} for setting {setting}")

variables = interface.get("variables", [])
variables = interface.get("variables", {})
for variable, val in variables.items():
if not isinstance(val, Variable):
raise TypeError(f"got unexpected type {type(val)} for variable {variable}")

dimensions = interface.get("dimensions", {})
for dim, val in dimensions.items():
if not isinstance(val, (str, int)):
raise TypeError(f"got unexpected type {type(val)} for dimension {dim}")

diagnostics = interface.get("diagnostics", [])
for diagnostic in diagnostics:
if not issubclass(diagnostic, VerosDiagnostic):
Expand All @@ -61,5 +67,6 @@ def load_plugin(module):
run_entrypoint=run_entrypoint,
settings=settings,
variables=variables,
dimensions=dimensions,
diagnostics=diagnostics,
)
20 changes: 10 additions & 10 deletions veros/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,24 +438,24 @@ def to_xarray(self):
return xr.Dataset(data_vars, coords=coords, attrs=attrs)


def get_default_state(use_plugins=None):
if use_plugins is not None:
plugin_interfaces = tuple(plugins.load_plugin(p) for p in use_plugins)
else:
plugin_interfaces = tuple()

default_settings = deepcopy(settings_mod.SETTINGS)
def get_default_state(plugin_interfaces=()):
if isinstance(plugin_interfaces, plugins.VerosPlugin):
plugin_interfaces = [plugin_interfaces]

for plugin in plugin_interfaces:
default_settings.update(plugin.settings)
if not isinstance(plugin, plugins.VerosPlugin):
raise TypeError(f"Got unexpected type {type(plugin)}")

default_dimensions = deepcopy(var_mod.DIM_TO_SHAPE_VAR)
settings = deepcopy(settings_mod.SETTINGS)
dimensions = deepcopy(var_mod.DIM_TO_SHAPE_VAR)
var_meta = deepcopy(var_mod.VARIABLES)

for plugin in plugin_interfaces:
settings.update(plugin.settings)
var_meta.update(plugin.variables)
dimensions.update(plugin.dimensions)

return VerosState(var_meta, default_settings, default_dimensions, plugin_interfaces=plugin_interfaces)
return VerosState(var_meta, settings, dimensions, plugin_interfaces=plugin_interfaces)


def veros_state_pytree_flatten(state):
Expand Down
2 changes: 1 addition & 1 deletion veros/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(

self.get_mask = mask

elif dims is not None:
elif isinstance(dims, tuple):
if dims[:3] in DEFAULT_MASKS:
self.get_mask = DEFAULT_MASKS[dims[:3]]
elif dims[:2] in DEFAULT_MASKS:
Expand Down
10 changes: 5 additions & 5 deletions veros/veros.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class VerosSetup(metaclass=abc.ABCMeta):
This class is meant to be subclassed. Subclasses need to implement the
methods :meth:`set_parameter`, :meth:`set_topography`, :meth:`set_grid`,
:meth:`set_coriolis`, :meth:`set_initial_conditions`, :meth:`set_forcing`,
and :meth:`set_diagnostics`.
:meth:`set_diagnostics`, and :meth:`after_timestep`.
Example:
>>> import matplotlib.pyplot as plt
Expand All @@ -42,7 +42,7 @@ def __init__(self, override=None):
self._plugin_interfaces = tuple(load_plugin(p) for p in self.__veros_plugins__)
self._setup_done = False

self.state = get_default_state(use_plugins=self.__veros_plugins__)
self.state = get_default_state(plugin_interfaces=self._plugin_interfaces)

@abc.abstractmethod
def set_parameter(self, state):
Expand Down Expand Up @@ -142,7 +142,7 @@ def set_forcing(self, state):
pass

@abc.abstractmethod
def set_diagnostics(self, vs):
def set_diagnostics(self, state):
"""To be implemented by subclass.
Called before setting up the :ref:`diagnostics <diagnostics>`. Use this method e.g. to
Expand Down Expand Up @@ -204,7 +204,7 @@ def setup(self):

self.state.diagnostics.update(diagnostics.create_default_diagnostics(self.state))

for plugin in self.state.plugin_interfaces:
for plugin in self._plugin_interfaces:
for diagnostic in plugin.diagnostics:
self.state.diagnostics[diagnostic.name] = diagnostic()

Expand Down Expand Up @@ -413,7 +413,7 @@ def _timing_summary(self):
timing_summary.extend(
[
" {:<22} = {:.2f}s".format(plugin.name, self.state.timers[plugin.name].total_time)
for plugin in self.state._plugin_interfaces
for plugin in self._plugin_interfaces
]
)

Expand Down

0 comments on commit 7b8ef3e

Please sign in to comment.