Helpful docstrings for the repo #11

Open · wants to merge 4 commits into base: master
3 changes: 3 additions & 0 deletions .gitignore
@@ -44,6 +44,9 @@ docs/_rst/*
docs/_build/*
cover/*
MANIFEST
**/runs/**
**/logs/fit/**


# Per-project virtualenvs
.venv*/
304 changes: 211 additions & 93 deletions examples/ODE_Example_coupled_nonlin.ipynb

Large diffs are not rendered by default.

505 changes: 505 additions & 0 deletions examples/ODE_Example_coupled_nonlin_no_norm.ipynb

Large diffs are not rendered by default.

143 changes: 125 additions & 18 deletions examples/ODE_Lotka_Volterra.ipynb

Large diffs are not rendered by default.

99 changes: 64 additions & 35 deletions examples/PDE_2D_Advection-Diffusion.ipynb

Large diffs are not rendered by default.

723 changes: 667 additions & 56 deletions examples/PDE_Burgers.ipynb

Large diffs are not rendered by default.

93 changes: 70 additions & 23 deletions examples/PDE_KdV.ipynb

Large diffs are not rendered by default.

426 changes: 426 additions & 0 deletions examples/PDE_Keller_Segel.ipynb

Large diffs are not rendered by default.

182 changes: 182 additions & 0 deletions examples/test.ipynb
@@ -0,0 +1,182 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import datetime"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"mnist = tf.keras.datasets.mnist\n",
"\n",
"(x_train, y_train),(x_test, y_test) = mnist.load_data()\n",
"x_train, x_test = x_train / 255.0, x_test / 255.0\n",
"\n",
"def create_model():\n",
" return tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28), name='layers_flatten'),\n",
" tf.keras.layers.Dense(512, activation='relu', name='layers_dense'),\n",
" tf.keras.layers.Dropout(0.2, name='layers_dropout'),\n",
" tf.keras.layers.Dense(10, activation='softmax', name='layers_dense_2')\n",
" ])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-10-19 13:22:06.121795: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.\n",
"2023-10-19 13:22:06.121833: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.\n",
"2023-10-19 13:22:06.506406: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.\n",
"2023-10-19 13:22:06.506605: I tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1748] CUPTI activity buffer flushed\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
" 1/1875 [..............................] - ETA: 6:06 - loss: 2.4212 - accuracy: 0.0312"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-10-19 13:22:06.974034: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.\n",
"2023-10-19 13:22:06.974072: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" 67/1875 [>.............................] - ETA: 12s - loss: 0.8793 - accuracy: 0.7411"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-10-19 13:22:07.251452: I tensorflow/core/profiler/lib/profiler_session.cc:66] Profiler session collecting data.\n",
"2023-10-19 13:22:07.251797: I tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1748] CUPTI activity buffer flushed\n",
"2023-10-19 13:22:07.263377: I tensorflow/core/profiler/internal/gpu/cupti_collector.cc:673] GpuTracer has collected 68 callback api events and 65 activity events. \n",
"2023-10-19 13:22:07.265665: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.\n",
"2023-10-19 13:22:07.269267: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: logs/fit/20231019-132206/train/plugins/profile/2023_10_19_13_22_07\n",
"\n",
"2023-10-19 13:22:07.271575: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for trace.json.gz to logs/fit/20231019-132206/train/plugins/profile/2023_10_19_13_22_07/hal5.trace.json.gz\n",
"2023-10-19 13:22:07.274743: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: logs/fit/20231019-132206/train/plugins/profile/2023_10_19_13_22_07\n",
"\n",
"2023-10-19 13:22:07.275741: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for memory_profile.json.gz to logs/fit/20231019-132206/train/plugins/profile/2023_10_19_13_22_07/hal5.memory_profile.json.gz\n",
"2023-10-19 13:22:07.277781: I tensorflow/core/profiler/rpc/client/capture_profile.cc:251] Creating directory: logs/fit/20231019-132206/train/plugins/profile/2023_10_19_13_22_07\n",
"Dumped tool data for xplane.pb to logs/fit/20231019-132206/train/plugins/profile/2023_10_19_13_22_07/hal5.xplane.pb\n",
"Dumped tool data for overview_page.pb to logs/fit/20231019-132206/train/plugins/profile/2023_10_19_13_22_07/hal5.overview_page.pb\n",
"Dumped tool data for input_pipeline.pb to logs/fit/20231019-132206/train/plugins/profile/2023_10_19_13_22_07/hal5.input_pipeline.pb\n",
"Dumped tool data for tensorflow_stats.pb to logs/fit/20231019-132206/train/plugins/profile/2023_10_19_13_22_07/hal5.tensorflow_stats.pb\n",
"Dumped tool data for kernel_stats.pb to logs/fit/20231019-132206/train/plugins/profile/2023_10_19_13_22_07/hal5.kernel_stats.pb\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1875/1875 [==============================] - 6s 3ms/step - loss: 0.2217 - accuracy: 0.9342 - val_loss: 0.1039 - val_accuracy: 0.9682\n",
"Epoch 2/5\n",
"1875/1875 [==============================] - 5s 3ms/step - loss: 0.0990 - accuracy: 0.9698 - val_loss: 0.0788 - val_accuracy: 0.9745\n",
"Epoch 3/5\n",
"1875/1875 [==============================] - 5s 3ms/step - loss: 0.0702 - accuracy: 0.9784 - val_loss: 0.0767 - val_accuracy: 0.9763\n",
"Epoch 4/5\n",
"1875/1875 [==============================] - 5s 3ms/step - loss: 0.0544 - accuracy: 0.9829 - val_loss: 0.0713 - val_accuracy: 0.9784\n",
"Epoch 5/5\n",
"1875/1875 [==============================] - 5s 3ms/step - loss: 0.0424 - accuracy: 0.9858 - val_loss: 0.0723 - val_accuracy: 0.9800\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fb80fdd4f70>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = create_model()\n",
"model.compile(optimizer='adam',\n",
" loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])\n",
"\n",
"log_dir = \"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
"tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
"\n",
"model.fit(x=x_train, \n",
" y=y_train, \n",
" epochs=5, \n",
" validation_data=(x_test, y_test), callbacks=[tensorboard_callback]\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"UsageError: Line magic function `%tensorboard` not found.\n"
]
}
],
"source": [
"%tensorboard --logdir logs/fit"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py3.10tf",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
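A note on the failing `%tensorboard` cell above: the TensorBoard line magic is only registered after the notebook extension is loaded, so (assuming a standard TensorBoard installation) the cell would run as:

%load_ext tensorboard
%tensorboard --logdir logs/fit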
68 changes: 62 additions & 6 deletions src/deepymod_torch/DeepMod.py
@@ -4,20 +4,64 @@


class DeepMod(nn.Module):
''' Class based interface for deepmod.'''
'''Module subclass for data-driven discovery of partial differential equations.

This module implements a neural network architecture for discovering the governing equations of a system
from data. The architecture consists of a fully connected neural network followed by a library of candidate
functions and a sparse regression layer. The library of candidate functions is defined by a user-provided
function and its arguments.

Args:
n_in (int): Number of input features, i.e. the number of temporal plus spatial dimensions.
hidden_dims (list of int): List of dimensions for the hidden layers of the neural network.
n_out (int): Number of output features (the number of equations to discover).
library_function (callable): Function that generates the library of candidate functions.
library_args (tuple or dict): Arguments to pass to the library function.

Attributes:
network (nn.Sequential): The fully connected neural network.
library (Library): The library of candidate functions.
fit (Fitting): The sparse regression layer.
'''
def __init__(self, n_in, hidden_dims, n_out, library_function, library_args):
super().__init__()
self.network = self.build_network(n_in, hidden_dims, n_out)
self.network = self.build_network(n_in, hidden_dims, n_out) # predicts the dynamical field(s) at the given coordinates
self.library = Library(library_function, library_args)
self.fit = self.build_fit_layer(n_in, n_out, library_function, library_args)

def forward(self, input):
prediction = self.network(input)
time_deriv, theta = self.library((prediction, input))
sparse_theta, coeff_vector = self.fit(theta)
"""
Computes the forward pass of the DeepMoD model.

Args:
input (torch.Tensor): Input tensor (typically X_train) of shape (batch_size, input_dim).

Returns:
tuple: A tuple containing:
- prediction (torch.Tensor): Output tensor of shape (batch_size, output_dim).
- time_deriv (torch.Tensor): Time derivative tensor of shape (batch_size, output_dim).
- sparse_theta (torch.Tensor): Sparse theta tensor of shape (n_terms, input_dim).
- coeff_vector (torch.Tensor): Coefficient vector tensor of shape (n_terms, output_dim).
"""
prediction = self.network(input) # predict the fields as a given location (input)
time_deriv, theta = self.library((prediction, input)) # library function returns time_deriv and theta (equation (4) of the manuscript)
sparse_theta, coeff_vector = self.fit(theta) # Note this attribute `fit` of type `Fitting` not a method of NN
# sparse_theta is theta with sparsity mask applied (extracting relevant terms)
# coeff_vector will play role in the loss function (see `losses.py`) which explains how it is optimized
return prediction, time_deriv, sparse_theta, coeff_vector

def build_network(self, n_in, hidden_dims, n_out):
"""
Builds a neural network with the specified number of input, hidden, and output nodes.

Args:
n_in (int): Number of input nodes.
hidden_dims (list): List of integers specifying the number of nodes in each hidden layer.
n_out (int): Number of output nodes.

Returns:
network (nn.Sequential): A PyTorch sequential neural network object.
"""
# NN
network = []
hs = [n_in] + hidden_dims + [n_out]
@@ -30,8 +74,20 @@ def build_network(self, n_in, hidden_dims, n_out):
return network

def build_fit_layer(self, n_in, n_out, library_function, library_args):
"""
Builds and returns the Fitting layer for the DeepMoD model; this is the sparse regression layer that applies the sparsity mask.

Args:
n_in (int): Number of input features.
n_out (int): Number of output features.
library_function (callable): Function that generates the library.
library_args (dict): Arguments to pass to the library function.

Returns:
Fitting: A Fitting layer with the appropriate number of terms for the given input and output sizes.
"""
sample_input = torch.ones((1, n_in), dtype=torch.float32, requires_grad=True)
n_terms = self.library((self.network(sample_input), sample_input))[1].shape[1] # do sample pass to infer shapes
n_terms = self.library((self.network(sample_input), sample_input))[1].shape[1] # do a sample forward pass to infer the shapes: the number of candidate terms in the library
fit_layer = Fitting(n_terms, n_out)

return fit_layer
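To make the interface documented above concrete, here is a minimal usage sketch; the argument values, the pairing with `library_1D_in`, and the assumption that `library_args` is forwarded to the library function as keyword arguments are illustrative, not taken from this diff:

import torch
from deepymod_torch.DeepMod import DeepMod
from deepymod_torch.library_functions import library_1D_in

# Hypothetical setup: inputs (t, x), one field u, monomials up to order 2,
# spatial derivatives up to order 3.
model = DeepMod(n_in=2, hidden_dims=[20, 20, 20], n_out=1,
                library_function=library_1D_in,
                library_args={'poly_order': 2, 'diff_order': 3})

X = torch.randn(100, 2, requires_grad=True)  # columns: (t, x); requires_grad so the library can differentiate
prediction, time_deriv, sparse_theta, coeff_vector = model(X)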
72 changes: 55 additions & 17 deletions src/deepymod_torch/library_functions.py
@@ -5,7 +5,7 @@
from functools import reduce

def library_poly(prediction, max_order):
# Calculate the polynomes of u
# Calculate the polynomials of u (technically these are monomials)
u = torch.ones_like(prediction)
for order in np.arange(1, max_order+1):
u = torch.cat((u, u[:, order-1:order] * prediction), dim=1)
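As a quick, hedged illustration of what `library_poly` returns (assuming the package is importable from the src layout; not part of this diff):

import torch
from deepymod_torch.library_functions import library_poly

u = torch.tensor([[2.0], [3.0]])  # two samples of a single field
theta_poly = library_poly(u, max_order=2)  # stacks the monomial columns [1, u, u^2]
# tensor([[1., 2., 4.],
#         [1., 3., 9.]])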
@@ -14,44 +14,82 @@ def library_poly(prediction, max_order):


def library_deriv(data, prediction, max_order):
dy = grad(prediction, data, grad_outputs=torch.ones_like(prediction), create_graph=True)[0]
time_deriv = dy[:, 0:1]

if max_order == 0:
"""
Computes the time derivative and up to max_order spatial derivatives of the prediction tensor with respect to the data tensor.

Args:
data (torch.Tensor): Input tensor of shape (batch_size, input_dim). Example: X_train.
prediction (torch.Tensor): Output tensor of shape (batch_size, output_dim). Example: the network prediction.
max_order (int): Maximum order of spatial derivatives to compute.

Returns:
time_deriv (torch.Tensor): Time derivative of the prediction with respect to the data.
du (torch.Tensor): Tensor of shape (batch_size, max_order+1) containing the computed spatial derivatives.
"""

dy = grad(prediction, data, grad_outputs=torch.ones_like(prediction), create_graph=True)[0] # calculate first-order derivatives of the prediction with respect to the data
time_deriv = dy[:, 0:1] # the first column is the time derivative

if max_order == 0: # If we only want the time derivative, du is just a column of ones
du = torch.ones_like(time_deriv)
else:
du = torch.cat((torch.ones_like(time_deriv), dy[:, 1:2]), dim=1)
if max_order >1:
else: # Else we calculate the spatial derivatives
du = torch.cat((torch.ones_like(time_deriv), dy[:, 1:2]), dim=1) # the second column of dy gives the first-order spatial derivative
if max_order > 1: # Higher-order derivatives are calculated successively and concatenated to du
for order in np.arange(1, max_order):
du = torch.cat((du, grad(du[:, order:order+1], data, grad_outputs=torch.ones_like(prediction), create_graph=True)[0][:, 1:2]), dim=1)

return time_deriv, du
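A small self-check of the behaviour described in the docstring, on the manufactured field u = t * x^2 (an illustrative sketch, not part of this diff):

import torch
from deepymod_torch.library_functions import library_deriv

data = torch.randn(5, 2, requires_grad=True)  # columns: (t, x)
t, x = data[:, 0:1], data[:, 1:2]
prediction = t * x**2  # u = t * x^2

time_deriv, du = library_deriv(data, prediction, max_order=2)
# time_deriv = u_t = x^2; du columns are [1, u_x, u_xx] = [1, 2*t*x, 2*t]
assert torch.allclose(time_deriv, x**2)
assert torch.allclose(du[:, 1:2], 2 * t * x)
assert torch.allclose(du[:, 2:3], 2 * t)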


def library_1D_in(input, poly_order, diff_order):
"""
Computes the library matrix for a spatial 1D input signal, given the input data, the maximum polynomial order and the maximum derivative order.

Parameters
----------
input : tuple of two torch.Tensor
A tuple containing the prediction tensor and the data tensor, both of shape (samples, features).
poly_order : int
The maximum polynomial order to include in the library.
diff_order : int
The maximum derivative order to include in the library.

Returns
-------
time_deriv_list : list of torch.Tensor
A list containing the time derivative tensors for each output feature, each of shape (samples, 1).
theta : torch.Tensor
The library matrix, of shape (samples, total_terms), where total_terms is the total number of terms in the library.
For example, when poly_order=2 and diff_order=3 with a single output, theta has the columns:
['', 'u_x', 'u_xx', 'u_xxx', 'u', 'uu_x', 'uu_xx', 'uu_xxx', 'u^2', 'u^2u_x', 'u^2u_xx', 'u^2u_xxx']
For more details run utilities.terms_definition().
"""
prediction, data = input
poly_list = []
deriv_list = []
time_deriv_list = []

# Creating lists for all outputs
for output in torch.arange(prediction.shape[1]):
# Creating lists for all outputs
for output in torch.arange(prediction.shape[1]): # loop over all dynamical fields modelled by the PDE (in case of a system of PDEs there is more than one dynamical field)
time_deriv, du = library_deriv(data, prediction[:, output:output+1], diff_order)
u = library_poly(prediction[:, output:output+1], poly_order)

poly_list.append(u)
deriv_list.append(du)
time_deriv_list.append(time_deriv)

samples = time_deriv_list[0].shape[0]
total_terms = poly_list[0].shape[1] * deriv_list[0].shape[1]
samples = time_deriv_list[0].shape[0] # number of samples
total_terms = poly_list[0].shape[1] * deriv_list[0].shape[1] # product of the number of possible polynomials (i.e. monomials) and the number of derivative terms

# Calculating theta
if len(poly_list) == 1:
theta = torch.matmul(poly_list[0][:, :, None], deriv_list[0][:, None, :]).view(samples, total_terms) # If we have a single output, we simply calculate and flatten matrix product between polynomials and derivatives to get library
# Calculating the theta matrix (equation (4) of the manuscript)
if len(poly_list) == 1: # If we have a single output (one dynamical field modelled by the PDE), we simply calculate and flatten the matrix product between polynomials and derivatives to get the library
theta = torch.matmul(poly_list[0][:, :, None], deriv_list[0][:, None, :]).view(samples, total_terms)
# For each sample, the line above is equivalent to np.multiply.outer(poly_list[0][sample, :], deriv_list[0][sample, :]).reshape(-1),
# i.e. we iterate over deriv_list first (fast index) and then over poly_list (slow index).
# The resulting column ordering can be seen by executing
# np.add.outer(np.array(['', 'u', 'u^2'], object), np.array(['', 'u_x', 'u_xx', 'u_xxx'], object)).reshape(-1),
# which gives ['', 'u_x', 'u_xx', 'u_xxx', 'u', 'uu_x', 'uu_xx', 'uu_xxx', 'u^2', 'u^2u_x', 'u^2u_xx', 'u^2u_xxx'], consistent with equation (4).
else:

theta_uv = reduce((lambda x, y: (x[:, :, None] @ y[:, None, :]).view(samples, -1)), poly_list)
theta_uv = reduce((lambda x, y: (x[:, :, None] @ y[:, None, :]).view(samples, -1)), poly_list) # all products of the polynomial libraries across outputs (e.g. u^i v^j terms)
theta_dudv = torch.cat([torch.matmul(du[:, :, None], dv[:, None, :]).view(samples, -1)[:, 1:] for du, dv in combinations(deriv_list, 2)], 1) # calculate all unique combinations of derivatives
theta_udu = torch.cat([torch.matmul(u[:, 1:, None], du[:, None, 1:]).view(samples, (poly_list[0].shape[1]-1) * (deriv_list[0].shape[1]-1)) for u, du in product(poly_list, deriv_list)], 1) # calculate all products of polynomials and derivatives (the constant columns are dropped to avoid duplicates)
theta = torch.cat([theta_uv, theta_dudv, theta_udu], dim=1)
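The fast/slow index ordering described in the comments above can be checked numerically with a short sketch (illustrative, not part of this diff):

import torch

samples, n_poly, n_deriv = 4, 3, 4
poly = torch.randn(samples, n_poly)    # stand-in for [1, u, u^2]
deriv = torch.randn(samples, n_deriv)  # stand-in for [1, u_x, u_xx, u_xxx]
theta = torch.matmul(poly[:, :, None], deriv[:, None, :]).view(samples, n_poly * n_deriv)
# Column k holds poly[:, k // n_deriv] * deriv[:, k % n_deriv]:
# derivatives are the fast index, polynomials the slow index.
assert torch.allclose(theta[:, 5], poly[:, 1] * deriv[:, 1])  # the u * u_x column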