Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609860237
  • Loading branch information
The swirl_dynamics Authors committed Feb 27, 2024
1 parent 1a32c7f commit f2a01eb
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 1 deletion.
2 changes: 1 addition & 1 deletion swirl_dynamics/projects/evolve_smoothly/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Abstract

We present a data-driven, space-time continuous framework to learn surrogate models for complex physical systems described by advection-dominated partial differential equations. Those systems have slow-decaying Kolmogorov $n$-width that hinders standard methods, including reduced order modeling, from producing high-fidelity simulations at low cost. In this work, we construct hypernetwork-based latent dynamical models directly on the parameter space of a compact representation network. We leverage the expressive power of the network and a specially designed consistency-inducing regularization to obtain latent trajectories that are both low-dimensional and smooth. These properties render our surrogate models highly efficient at inference time. We show the efficacy of our framework by learning models that generate accurate multi-step rollout predictions at much faster inference speed compared to competitors, for several challenging examples.
We present a data-driven, space-time continuous framework to learn surrogate models for complex physical systems described by advection-dominated partial differential equations. Those systems have slow-decaying Kolmogorov $$n$$-width that hinders standard methods, including reduced order modeling, from producing high-fidelity simulations at low cost. In this work, we construct hypernetwork-based latent dynamical models directly on the parameter space of a compact representation network. We leverage the expressive power of the network and a specially designed consistency-inducing regularization to obtain latent trajectories that are both low-dimensional and smooth. These properties render our surrogate models highly efficient at inference time. We show the efficacy of our framework by learning models that generate accurate multi-step rollout predictions at much faster inference speed compared to competitors, for several challenging examples.

## Datasets

Expand Down
163 changes: 163 additions & 0 deletions swirl_dynamics/projects/weno_nn/colabs/init_network.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OYIaUeqU_GoF"
},
"outputs": [],
"source": [
"!pip install git+https://github.com/google-research/swirl-dynamics.git@main"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "t-c9iNBAy75u"
},
"outputs": [],
"source": [
"#@title Imports\n",
"from typing import Any, Literal, Optional\n",
"import flax.linen as nn\n",
"import functools\n",
"import jax\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"from swirl_dynamics.data import hdf5_utils\n",
"from swirl_dynamics.lib.networks import rational_networks\n",
"from swirl_dynamics.projects.weno_nn import weno\n",
"from swirl_dynamics.projects.weno_nn import weno_nn\n",
"from swirl_dynamics.projects.weno_nn import utils"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kF-DZ55uzdp6"
},
"outputs": [],
"source": [
"flax_model_main_folder = '../model_weights/'\n",
"xid=94741459\n",
"model_num=113\n",
"filename = flax_model_main_folder+f'/xid_{xid}_model_{model_num}.hdf5'\n",
"network_vars = hdf5_utils.read_all_arrays_as_dict(filename)\n",
"mlp_model = weno_nn.OmegaNN(\n",
" features=tuple(network_vars['config']['features'].astype(int)),\n",
" features_fun=utils.get_feature_func(network_vars['config']['features_fun'].decode()),\n",
" act_fun=utils.get_act_func(network_vars['config']['act_fun'].decode()),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dLBi5s2GbwOz"
},
"outputs": [],
"source": [
"x=np.linspace(0.0, 1.0, 101)\n",
"u=np.sin(np.pi*x)\n",
"# Stack neighbor information for [u_{i-1}, u_{i}, u_{i+1}].\n",
"u_nb=np.stack([u[:-2], u[1:-1], u[2:]], axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tC_fW3IicqHV"
},
"outputs": [],
"source": [
"# Individual functions are written over scalar inputs of\n",
"# [u_{i-1}, u_{i}, u_{i+1}]. Hence we vectorize over the axis for u_nb above.\n",
"# Function to perform interpolation:\n",
"weno_interp_func_vmap = jax.vmap(weno.weno_interpolation, in_axes=(0, None))\n",
"model_apply_func = functools.partial(mlp_model.apply, test=True)\n",
"# Function to calculate WENO-weights:\n",
"weno_nn_wt_func_vmap = jax.vmap(model_apply_func, in_axes=(None,0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lK3lUqPmd1hp"
},
"outputs": [],
"source": [
"# Estimate WENO-weights on the negative side:\n",
"wt_neg = weno_nn_wt_func_vmap({\"params\": network_vars[\"params\"]}, u_nb)\n",
"# Perform WENO interpolation on both positive and negative sides.\n",
"u_interp = weno_interp_func_vmap(\n",
" u_nb,\n",
" lambda x, params: model_apply_func({\"params\": network_vars[\"params\"]}, x),\n",
")\n",
"# Unstack the positive and negative side interpolations.\n",
"u_interp_pos = u_interp[:, 0]\n",
"u_interp_neg = u_interp[:, 1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2U3qVwNSeGE0"
},
"outputs": [],
"source": [
"plt.figure()\n",
"plt.plot(x[1:-1], wt_neg[:,0]); plt.ylim([0.0,1.0]);\n",
"plt.plot(x[1:-1], wt_neg[:,1]); plt.ylim([0.0,1.0]);\n",
"plt.xlabel('X'); plt.ylabel('WENO Weights');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cXdHHtYVe5Jy"
},
"outputs": [],
"source": [
"x_half = x+(x[1]-x[0])*0.5\n",
"plt.figure()\n",
"plt.plot(x_half[1:-1], u_interp_pos, '-.b', label='Pos');\n",
"plt.plot(x_half[1:-1], u_interp_neg, '-.g', label='Neg');\n",
"plt.plot(x, u, '--r', label='Cell');\n",
"plt.xlabel('X'); plt.ylabel('WENO Weight'); plt.legend();"
]
}
],
"metadata": {
"colab": {
"last_runtime": {
"build_target": "//research/simulation/tools:notebook",
"kind": "shared"
},
"provenance": [
{
"file_id": "12h2QSHNj9FUYRqR6axIbK9dFZoQ29-IH",
"timestamp": 1707952544896
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

0 comments on commit f2a01eb

Please sign in to comment.