From 4687046afb76cd236314ea190c0f3011d558b99b Mon Sep 17 00:00:00 2001
From: Mathias Lechner
Date: Fri, 28 Jul 2023 11:53:19 -0400
Subject: [PATCH] update pytorch docs

---
 docs/examples/torch_first_steps.rst |   4 +-
 examples/pt_example.py              | 136 +++++++++++++---------
 2 files changed, 68 insertions(+), 72 deletions(-)

diff --git a/docs/examples/torch_first_steps.rst b/docs/examples/torch_first_steps.rst
index 689b060..9509584 100644
--- a/docs/examples/torch_first_steps.rst
+++ b/docs/examples/torch_first_steps.rst
@@ -115,6 +115,9 @@ For the wiring we will use the ```AutoNCP`` class, which creates a NCP wiring di
 
 .. code-block:: python
 
+    out_features = 1
+    in_features = 2
+
     wiring = AutoNCP(16, out_features)  # 16 units, 1 motor neuron
     ltc_model = LTC(in_features, wiring, batch_first=True)
 
@@ -123,7 +126,6 @@ For the wiring we will use the ```AutoNCP`` class, which creates a NCP wiring di
         logger=pl.loggers.CSVLogger("log"),
         max_epochs=400,
         gradient_clip_val=1,  # Clip gradient to stabilize training
-        gpus=0,
     )
 
 Draw the wiring diagram of the network
diff --git a/examples/pt_example.py b/examples/pt_example.py
index 70d4654..9e82c98 100644
--- a/examples/pt_example.py
+++ b/examples/pt_example.py
@@ -1,37 +1,42 @@
-# Copyright (2017-2021)
-# The Wormnet project
-# Mathias Lechner (mlechner@ist.ac.at)
 import numpy as np
 import torch.nn as nn
-import kerasncp as kncp
-from kerasncp.torch import LTCCell
+from ncps.wirings import AutoNCP
+from ncps.torch import LTC
 import pytorch_lightning as pl
 import torch
 import torch.utils.data as data
 
-# nn.Module that unfolds a RNN cell into a sequence
-class RNNSequence(nn.Module):
-    def __init__(
-        self,
-        rnn_cell,
-    ):
-        super(RNNSequence, self).__init__()
-        self.rnn_cell = rnn_cell
-
-    def forward(self, x):
-        device = x.device
-        batch_size = x.size(0)
-        seq_len = x.size(1)
-        hidden_state = torch.zeros(
-            (batch_size, self.rnn_cell.state_size), device=device
-        )
-        outputs = []
-        for t in range(seq_len):
-            inputs = x[:, t]
-            new_output, hidden_state = self.rnn_cell.forward(inputs, hidden_state)
-            outputs.append(new_output)
-        outputs = torch.stack(outputs, dim=1)  # return entire sequence
-        return outputs
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+N = 48  # Length of the time-series
+out_features = 1
+in_features = 2
+# Input feature is a sine and a cosine wave
+data_x = np.stack(
+    [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], axis=1
+)
+data_x = np.expand_dims(data_x, axis=0).astype(np.float32)  # Add batch dimension
+# Target output is a sine with double the frequency of the input signal
+data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32)
+print("data_x.shape: ", str(data_x.shape))
+print("data_y.shape: ", str(data_y.shape))
+data_x = torch.Tensor(data_x)
+data_y = torch.Tensor(data_y)
+dataloader = data.DataLoader(
+    data.TensorDataset(data_x, data_y), batch_size=1, shuffle=True, num_workers=4
+)
+
+# Let's visualize the training data
+sns.set()
+plt.figure(figsize=(6, 4))
+plt.plot(data_x[0, :, 0], label="Input feature 1")
+plt.plot(data_x[0, :, 1], label="Input feature 2")
+plt.plot(data_y[0, :, 0], label="Target output")
+plt.ylim((-1, 1))
+plt.title("Training data")
+plt.legend(loc="upper right")
+plt.savefig("pt_plot1.png")
 
 
 # LightningModule for training a RNNSequence module
@@ -43,7 +48,7 @@ def __init__(self, model, lr=0.005):
 
     def training_step(self, batch, batch_idx):
         x, y = batch
-        y_hat = self.model.forward(x)
+        y_hat, _ = self.model.forward(x)
         y_hat = y_hat.view_as(y)
         loss = nn.MSELoss()(y_hat, y)
         self.log("train_loss", loss, prog_bar=True)
@@ -51,7 +56,7 @@ def training_step(self, batch, batch_idx):
 
     def validation_step(self, batch, batch_idx):
         x, y = batch
-        y_hat = self.model.forward(x)
+        y_hat, _ = self.model.forward(x)
         y_hat = y_hat.view_as(y)
         loss = nn.MSELoss()(y_hat, y)
 
@@ -65,52 +70,41 @@ def test_step(self, batch, batch_idx):
 
     def configure_optimizers(self):
         return torch.optim.Adam(self.model.parameters(), lr=self.lr)
 
-    def optimizer_step(
-        self,
-        current_epoch,
-        batch_nb,
-        optimizer,
-        optimizer_idx,
-        closure,
-        on_tpu=False,
-        using_native_amp=False,
-        using_lbfgs=False,
-    ):
-        optimizer.optimizer.step(closure=closure)
-        # Apply weight constraints
-        self.model.rnn_cell.apply_weight_constraints()
 
+wiring = AutoNCP(16, out_features)  # 16 units, 1 motor neuron
-in_features = 2
-out_features = 1
-N = 48  # Length of the time-series
-# Input feature is a sine and a cosine wave
-data_x = np.stack(
-    [np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], axis=1
-)
-data_x = np.expand_dims(data_x, axis=0).astype(np.float32)  # Add batch dimension
-# Target output is a sine with double the frequency of the input signal
-data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32)
-data_x = torch.Tensor(data_x)
-data_y = torch.Tensor(data_y)
-print("data_y.shape: ", str(data_y.shape))
-
-wiring = kncp.wirings.FullyConnected(8, out_features)  # 16 units, 8 motor neurons
-ltc_cell = LTCCell(wiring, in_features)
-dataloader = data.DataLoader(
-    data.TensorDataset(data_x, data_y), batch_size=1, shuffle=True, num_workers=4
-)
-
-ltc_sequence = RNNSequence(
-    ltc_cell,
-)
-learn = SequenceLearner(ltc_sequence, lr=0.01)
+ltc_model = LTC(in_features, wiring, batch_first=True)
+learn = SequenceLearner(ltc_model, lr=0.01)
 trainer = pl.Trainer(
     logger=pl.loggers.CSVLogger("log"),
     max_epochs=400,
-    progress_bar_refresh_rate=1,
     gradient_clip_val=1,  # Clip gradient to stabilize training
-    gpus=1,
 )
+
+
+# Let's visualize how LTC initially performs before the training
+sns.set()
+with torch.no_grad():
+    prediction = ltc_model(data_x)[0].numpy()
+plt.figure(figsize=(6, 4))
+plt.plot(data_y[0, :, 0], label="Target output")
+plt.plot(prediction[0, :, 0], label="NCP output")
+plt.ylim((-1, 1))
+plt.title("Before training")
+plt.legend(loc="upper right")
+plt.savefig("pt_plot2.png")
+
+# Train the model for 400 epochs (= training steps)
 trainer.fit(learn, dataloader)
-results = trainer.test(learn, dataloader)
+
+# How does the trained model now fit to the sinusoidal function?
+sns.set()
+with torch.no_grad():
+    prediction = ltc_model(data_x)[0].numpy()
+plt.figure(figsize=(6, 4))
+plt.plot(data_y[0, :, 0], label="Target output")
+plt.plot(prediction[0, :, 0], label="NCP output")
+plt.ylim((-1, 1))
+plt.title("After training")
+plt.legend(loc="upper right")
+plt.savefig("pt_plot3.png")
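
A quick sanity check of the ncps API this patch switches the example to. This is a minimal sketch, assuming ncps and torch are installed and using the shapes from the example above; the names x, output, and hx are illustrative. It builds the same AutoNCP-wired LTC model and shows that its forward pass returns a pair (output sequence, final hidden state), which is why SequenceLearner now unpacks y_hat, _ = self.model.forward(x) and the plotting code indexes ltc_model(data_x)[0].

    import torch
    from ncps.wirings import AutoNCP
    from ncps.torch import LTC

    in_features = 2   # sine and cosine input features, as in the example
    out_features = 1  # a single motor neuron

    wiring = AutoNCP(16, out_features)  # 16 units, 1 motor neuron
    ltc_model = LTC(in_features, wiring, batch_first=True)

    # A dummy batch shaped like data_x: (batch, time, features)
    x = torch.randn(1, 48, in_features)
    with torch.no_grad():
        output, hx = ltc_model(x)  # returns (output sequence, final hidden state)
    print(output.shape)  # torch.Size([1, 48, 1]), matching data_y

Because the model emits this tuple rather than a bare tensor, any code ported from the old RNNSequence/LTCCell version has to take the first element before computing losses or plotting.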