Merge pull request #38 from jrmaddison/jrmaddison/network_update
Update `Dynamics`
jrmaddison authored Dec 3, 2024
2 parents fed075f + bea78a4 commit e23bdc6
Showing 3 changed files with 16 additions and 6 deletions.
10 changes: 9 additions & 1 deletion bt_ocean/network.py
@@ -111,6 +111,7 @@ class Dynamics(keras.layers.Layer):
update : callable
Passed `dynamics` and any arguments defined by `args`, and should
update the state of `dynamics`. Evaluated before taking each timestep.
Must not modify elements of `args` and must have no side effects.
args : tuple
Passed as remaining arguments to `update`.
N : Integral
@@ -122,6 +123,13 @@ class Dynamics(keras.layers.Layer):
Weight by which to scale each input.
output_weight : Real or :class:`jax.Array`
Weight by which to scale each output.
Warnings
--------
The `update` callable can modify the `dynamics` argument, but must not
change any elements of `args`, and must have no side effects. This means,
for example, that batch normalization cannot be used in a nested neural
network.
"""

_update_registry = {}
@@ -131,7 +139,7 @@ def __init__(self, dynamics, update, *args, N=1, n_output=1,
if "dtype" not in kwargs:
kwargs["dtype"] = dynamics.grid.fdtype
super().__init__(**kwargs)
self.__dynamics = dynamics
self.__dynamics = dynamics.new(copy_prescribed=True)
self.__update = update
self.__args = args
self.__N = N
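The no-side-effects contract documented above can be illustrated with a plain-Python sketch. `ToyDynamics`, `good_update`, and `bad_update` are hypothetical names for illustration only, not part of the bt_ocean API: the `update` callable may mutate the dynamics object it is handed, but must treat everything in `args` as read-only.

```python
# Hypothetical sketch of the `update` contract; ToyDynamics stands in
# for a bt_ocean dynamics object and is not the real API.
class ToyDynamics:
    def __init__(self):
        self.forcing = 0.0


def good_update(dynamics, scale):
    # Allowed: mutate the state of the `dynamics` argument.
    dynamics.forcing = 2.0 * scale


def bad_update(dynamics, history):
    # Not allowed: mutating an element of `args` is a side effect.
    history.append(dynamics.forcing)


dyn = ToyDynamics()
good_update(dyn, 3.0)
print(dyn.forcing)  # 6.0
```

Because `update` is evaluated before every timestep, any hidden state it touched outside `dynamics` would be modified repeatedly during a single forward pass, and would not be tracked correctly when the layer is differentiated.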
4 changes: 2 additions & 2 deletions docs/source/examples/1_keras_integration.ipynb
@@ -9,7 +9,7 @@
"\n",
"This notebook describes the combination of bt_ocean with Keras.\n",
"\n",
"bt_ocean allows a neural network to be used to define an arbitrary right-hand-side forcing term, which can then be trained using time-dependent data. Here, to demonstrate the principles, we will consider an extremely simple case of a Keras model consisting of a single layer which simply outputs the degrees of freedom for a function – so that training reduces to the problem of finding this function. Specifically we will try to find the wind forcing term $Q$ used on the right-hand-side of the barotropic vorticity equation. This toy problem demonstrates the key ideas – while remaining small enough to run quickly!"
"bt_ocean allows a neural network to be used to define a right-hand-side forcing term, which can then be trained using time-dependent data. Here, to demonstrate the principles, we will consider an extremely simple case of a Keras model consisting of a single layer which simply outputs the degrees of freedom for a function – so that training reduces to the problem of finding this function. Specifically we will try to find the wind forcing term $Q$ used on the right-hand-side of the barotropic vorticity equation. This toy problem demonstrates the key ideas – while remaining small enough to run quickly!"
]
},
{
@@ -217,7 +217,7 @@
"source": [
"## Increasing the complexity\n",
"\n",
"Here we have used Keras to solve a standard variational optimization problem, by defining a very simple Keras model. However we can make the Keras model `Q_network` arbitrarily complex, and can also use the `Dynamics` layer itself as part of a more complicated 'outer' Keras model. That is, we can embed neural networks within bt_ocean, and can also embed bt_ocean within neural networks. Since we also have access to the full functionality of JAX, we can define significantly more complex Keras models, and use these for 'online' training of neural networks, constrained by the dynamics."
"Here we have used Keras to solve a standard variational optimization problem, by defining a very simple Keras model. However we can make the Keras model `Q_network` much more complex, and can also use the `Dynamics` layer itself as part of a more complicated 'outer' Keras model. That is, we can embed neural networks within bt_ocean, and can also embed bt_ocean within neural networks. The main restriction is that use of the embedded neural network (by the `update` callable in this example) can only change the `dynamics` argument, but cannot change other arguments (here `Q_network`) or have other side effects. This means for example that evaluating the embedded neural network cannot change the neural network itself – as occurs e.g. with batch normalization."
]
}
],
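The batch-normalization caveat in the notebook text comes down to layers that update internal state whenever they are evaluated. A minimal Keras-free sketch (`StatefulLayer` is hypothetical, illustrative only) shows the kind of evaluation-time mutation the `update` callable must avoid:

```python
# Illustrative only: a layer that updates a running mean on every call,
# analogous to batch normalization's moving statistics in training mode.
class StatefulLayer:
    def __init__(self):
        self.running_mean = 0.0
        self.n = 0

    def __call__(self, x):
        # Evaluating the layer changes the layer itself: a side effect.
        self.n += 1
        self.running_mean += (x - self.running_mean) / self.n
        return x - self.running_mean


layer = StatefulLayer()
layer(1.0)
layer(3.0)
print(layer.running_mean)  # 2.0
```

A stateless layer, by contrast, returns the same output for the same input on every call, which is what the no-side-effects restriction on the embedded network amounts to.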
8 changes: 5 additions & 3 deletions tests/test_network.py
@@ -93,6 +93,8 @@ def test_dynamics_roundtrip(tmp_path):
Q_input_layer = keras.layers.Input((model.grid.N_x + 1, model.grid.N_y + 1))
Q_network = keras.models.Model(inputs=Q_input_layer, outputs=Q_input_layer)

# Used to test that Q_callback is being called, although technically
# Q_callback is not allowed to have any side effects
n_calls = 0

@Dynamics.register_update("test_dynamics_roundtrip_Q_callback")
@@ -114,7 +116,7 @@ def Q_callback(dynamics, Q_network):
input_model = dynamics_network.layers[1]._Dynamics__dynamics

assert type(input_model) is type(model)
assert input_model.n == model.n
assert input_model.n == 0

assert input_model.grid.L_x == model.grid.L_x
assert input_model.grid.L_y == model.grid.L_y
@@ -128,8 +130,8 @@ def Q_callback(dynamics, Q_network):
assert input_model.parameters[key] == value

assert set(input_model.fields) == set(model.fields)
for key, value in model.fields.items():
assert (input_model.fields[key] == value).all()
for key in model.prescribed_field_keys:
assert (input_model.fields[key] == model.fields[key]).all()

dynamics_network(jnp.zeros((1, model.grid.N_x + 1, model.grid.N_y + 1)))
assert n_calls == 2
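The test changes above follow from the `dynamics.new(copy_prescribed=True)` call in `__init__`: the layer now holds a fresh copy of the model, so the round-trip test expects a reset timestep counter (`n == 0`) and compares only the prescribed fields. A hedged sketch of that copy pattern (`ToyModel` and its attributes are hypothetical, not the bt_ocean implementation):

```python
# Hypothetical sketch of a new()-style copy: evolving state is reset,
# and only prescribed (externally set) fields are carried over.
class ToyModel:
    prescribed_field_keys = ("wind",)

    def __init__(self):
        self.n = 0  # timestep counter
        self.fields = {"wind": 1.5, "vorticity": 0.7}

    def new(self, *, copy_prescribed=False):
        other = ToyModel()
        other.fields = {"wind": 0.0, "vorticity": 0.0}
        if copy_prescribed:
            for key in self.prescribed_field_keys:
                other.fields[key] = self.fields[key]
        return other


model = ToyModel()
model.n = 10  # pretend some timesteps have been taken
copy = model.new(copy_prescribed=True)
print(copy.n, copy.fields["wind"])  # 0 1.5
```

Under this reading, evolving fields such as the vorticity start from their defaults in the copy, which is why the updated test iterates over `prescribed_field_keys` rather than all of `model.fields`.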
