-
Notifications
You must be signed in to change notification settings - Fork 316
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
68 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,42 @@ | ||
# Copyright (2017-2021) | ||
# The Wormnet project | ||
# Mathias Lechner ([email protected]) | ||
import numpy as np | ||
import torch.nn as nn | ||
import kerasncp as kncp | ||
from kerasncp.torch import LTCCell | ||
from ncps.wirings import AutoNCP | ||
from ncps.torch import LTC | ||
import pytorch_lightning as pl | ||
import torch | ||
import torch.utils.data as data | ||
|
||
# nn.Module that unfolds a RNN cell into a sequence | ||
class RNNSequence(nn.Module): | ||
def __init__( | ||
self, | ||
rnn_cell, | ||
): | ||
super(RNNSequence, self).__init__() | ||
self.rnn_cell = rnn_cell | ||
|
||
def forward(self, x): | ||
device = x.device | ||
batch_size = x.size(0) | ||
seq_len = x.size(1) | ||
hidden_state = torch.zeros( | ||
(batch_size, self.rnn_cell.state_size), device=device | ||
) | ||
outputs = [] | ||
for t in range(seq_len): | ||
inputs = x[:, t] | ||
new_output, hidden_state = self.rnn_cell.forward(inputs, hidden_state) | ||
outputs.append(new_output) | ||
outputs = torch.stack(outputs, dim=1) # return entire sequence | ||
return outputs | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
|
||
N = 48 # Length of the time-series | ||
out_features = 1 | ||
in_features = 2 | ||
# Input feature is a sine and a cosine wave | ||
data_x = np.stack( | ||
[np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], axis=1 | ||
) | ||
data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension | ||
# Target output is a sine with double the frequency of the input signal | ||
data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) | ||
print("data_x.shape: ", str(data_x.shape)) | ||
print("data_y.shape: ", str(data_y.shape)) | ||
data_x = torch.Tensor(data_x) | ||
data_y = torch.Tensor(data_y) | ||
dataloader = data.DataLoader( | ||
data.TensorDataset(data_x, data_y), batch_size=1, shuffle=True, num_workers=4 | ||
) | ||
|
||
# Let's visualize the training data | ||
sns.set() | ||
plt.figure(figsize=(6, 4)) | ||
plt.plot(data_x[0, :, 0], label="Input feature 1") | ||
plt.plot(data_x[0, :, 1], label="Input feature 1") | ||
plt.plot(data_y[0, :, 0], label="Target output") | ||
plt.ylim((-1, 1)) | ||
plt.title("Training data") | ||
plt.legend(loc="upper right") | ||
plt.savefig("pt_plot1.png") | ||
|
||
|
||
# LightningModule for training a RNNSequence module | ||
|
@@ -43,15 +48,15 @@ def __init__(self, model, lr=0.005): | |
|
||
def training_step(self, batch, batch_idx): | ||
x, y = batch | ||
y_hat = self.model.forward(x) | ||
y_hat, _ = self.model.forward(x) | ||
y_hat = y_hat.view_as(y) | ||
loss = nn.MSELoss()(y_hat, y) | ||
self.log("train_loss", loss, prog_bar=True) | ||
return {"loss": loss} | ||
|
||
def validation_step(self, batch, batch_idx): | ||
x, y = batch | ||
y_hat = self.model.forward(x) | ||
y_hat, _ = self.model.forward(x) | ||
y_hat = y_hat.view_as(y) | ||
loss = nn.MSELoss()(y_hat, y) | ||
|
||
|
@@ -65,52 +70,41 @@ def test_step(self, batch, batch_idx): | |
def configure_optimizers(self): | ||
return torch.optim.Adam(self.model.parameters(), lr=self.lr) | ||
|
||
def optimizer_step( | ||
self, | ||
current_epoch, | ||
batch_nb, | ||
optimizer, | ||
optimizer_idx, | ||
closure, | ||
on_tpu=False, | ||
using_native_amp=False, | ||
using_lbfgs=False, | ||
): | ||
optimizer.optimizer.step(closure=closure) | ||
# Apply weight constraints | ||
self.model.rnn_cell.apply_weight_constraints() | ||
|
||
wiring = AutoNCP(16, out_features) # 16 units, 1 motor neuron | ||
|
||
in_features = 2 | ||
out_features = 1 | ||
N = 48 # Length of the time-series | ||
# Input feature is a sine and a cosine wave | ||
data_x = np.stack( | ||
[np.sin(np.linspace(0, 3 * np.pi, N)), np.cos(np.linspace(0, 3 * np.pi, N))], axis=1 | ||
) | ||
data_x = np.expand_dims(data_x, axis=0).astype(np.float32) # Add batch dimension | ||
# Target output is a sine with double the frequency of the input signal | ||
data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32) | ||
data_x = torch.Tensor(data_x) | ||
data_y = torch.Tensor(data_y) | ||
print("data_y.shape: ", str(data_y.shape)) | ||
|
||
wiring = kncp.wirings.FullyConnected(8, out_features) # 16 units, 8 motor neurons | ||
ltc_cell = LTCCell(wiring, in_features) | ||
dataloader = data.DataLoader( | ||
data.TensorDataset(data_x, data_y), batch_size=1, shuffle=True, num_workers=4 | ||
) | ||
|
||
ltc_sequence = RNNSequence( | ||
ltc_cell, | ||
) | ||
learn = SequenceLearner(ltc_sequence, lr=0.01) | ||
ltc_model = LTC(in_features, wiring, batch_first=True) | ||
learn = SequenceLearner(ltc_model, lr=0.01) | ||
trainer = pl.Trainer( | ||
logger=pl.loggers.CSVLogger("log"), | ||
max_epochs=400, | ||
progress_bar_refresh_rate=1, | ||
gradient_clip_val=1, # Clip gradient to stabilize training | ||
gpus=1, | ||
) | ||
|
||
|
||
# Train the model for 400 epochs (= training steps) | ||
trainer.fit(learn, dataloader) | ||
results = trainer.test(learn, dataloader) | ||
|
||
# Let's visualize how LTC initialy performs before the training | ||
sns.set() | ||
with torch.no_grad(): | ||
prediction = ltc_model(data_x)[0].numpy() | ||
plt.figure(figsize=(6, 4)) | ||
plt.plot(data_y[0, :, 0], label="Target output") | ||
plt.plot(prediction[0, :, 0], label="NCP output") | ||
plt.ylim((-1, 1)) | ||
plt.title("Before training") | ||
plt.legend(loc="upper right") | ||
plt.savefig("pt_plot2.png") | ||
|
||
# How does the trained model now fit to the sinusoidal function? | ||
sns.set() | ||
with torch.no_grad(): | ||
prediction = ltc_model(data_x)[0].numpy() | ||
plt.figure(figsize=(6, 4)) | ||
plt.plot(data_y[0, :, 0], label="Target output") | ||
plt.plot(prediction[0, :, 0], label="NCP output") | ||
plt.ylim((-1, 1)) | ||
plt.title("After training") | ||
plt.legend(loc="upper right") | ||
plt.savefig("pt_plot3.png") |