Skip to content

Commit

Permalink
Merge branch 'main' into feat/intermediate-plotting/epidemic-demo
Browse files Browse the repository at this point in the history
  • Loading branch information
carynbear authored Jul 22, 2024
2 parents dbc7483 + 447ab6a commit 3d4b218
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 6 deletions.
7 changes: 6 additions & 1 deletion src/dynadojo/baselines/aug_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def fit(self, x: np.ndarray, epochs=100, **kwargs):
x = torch.tensor(x, dtype=torch.float32)
state = x[:, 0, :]
step = end = epochs / self._timesteps

losses = []

for _ in range(epochs):
if _ % step == 0:
t = torch.linspace(0.0, end, self._timesteps)
Expand All @@ -40,6 +41,10 @@ def fit(self, x: np.ndarray, epochs=100, **kwargs):
loss = self.mse_loss(pred, x).float()
loss.backward()
self.opt.step()
losses.append(loss)
return {
"train_losses": losses
}

def predict(self, x0, timesteps):
x0 = torch.tensor(x0, dtype=torch.float32)
Expand Down
8 changes: 7 additions & 1 deletion src/dynadojo/baselines/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def fit(self, x: np.ndarray, epochs=100, **kwargs):
loss_BCEL = nn.BCELoss()
lossMSE = nn.MSELoss()

losses = []
for i in range(epochs):
opt.zero_grad()
pred_states = state[:, 0, :].unsqueeze(1)
Expand All @@ -41,10 +42,15 @@ def fit(self, x: np.ndarray, epochs=100, **kwargs):
next_state = self.lin(self.c1(state_t)) # doesn't even call forward

loss += lossMSE((next_state * (state[:, t, :].unsqueeze(1))), state[:, t+1, :].unsqueeze(1)) #loss_BCEL(next_state, state[:, t + 1, :].unsqueeze(1))

# print(loss.item())
loss.backward()
opt.step()
losses.append(loss.item())

return {
"train_loss": losses
}

def predict(self, x0: np.ndarray, timesteps: int, **kwargs) -> np.ndarray:
state = torch.tensor(x0, dtype=torch.float32)
Expand Down
9 changes: 9 additions & 0 deletions src/dynadojo/baselines/dmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def __init__(self, embed_dim: int, timesteps: int, max_control_cost: float = 0,

def fit(self, x: np.ndarray, **kwargs) -> None:
self._model = self._model.fit(x[0].T)
pred = self._model.predict(x[0].T)
loss = self.mse(x[0].T, pred)
return {
"train_losses": loss
}

def predict(self, x0: np.ndarray, timesteps: int, **kwargs) -> np.ndarray:
result = [x0.T]
Expand All @@ -52,3 +57,7 @@ def predict(self, x0: np.ndarray, timesteps: int, **kwargs) -> np.ndarray:
result = np.array(result)
result = np.transpose(result, axes=(2, 0, 1))
return result

def mse(self, actual, pred):
actual, pred = np.array(actual), np.array(pred)
return np.square(np.subtract(actual,pred)).mean()
11 changes: 9 additions & 2 deletions src/dynadojo/baselines/dnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def fit(self, x: np.ndarray,
train = np.array(x)
train_size = len(x)
early_stopper = None
val_losses = None
else:
if verbose > 0:
print(f'Training on {1-validation_split} of the data, validating on the rest')
Expand All @@ -73,7 +74,7 @@ def fit(self, x: np.ndarray,
train, val = random_split(x, [len(x)-validation_size, validation_size])
train = np.array(train)
val = np.array(train)

val_losses = []
#Validation dataset
x_val = torch.tensor(np.array(val[:, :-1, :]), dtype=torch.float32).to(self.device)
y_val = torch.tensor(np.array(val[:, 1:, :]), dtype=torch.float32).to(self.device)
Expand Down Expand Up @@ -120,14 +121,20 @@ def fit(self, x: np.ndarray,
if verbose > 0 and (epoch+1) % 10 == 0:
print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss/len(dataloader):.4f}, Val Loss: {val_loss:.4f}, , took {time.time() - training_start_time:.2f}s')
training_start_time = time.time()

val_losses.append(val_loss)

if early_stopper.early_stop(epoch, val_loss, self.state_dict()):
if verbose > 0:
print(f'Early stopping at epoch {epoch+1}')
print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss/len(dataloader):.4f}, Val Loss: {val_loss:.4f}')
break
if early_stopper is not None:
self.load_state_dict(early_stopper.best_weights)
return losses
return {
"train_loss": losses,
"val_loss": val_losses
}

def predict(self, x0: np.ndarray, timesteps: int, **kwargs) -> np.ndarray:
self.eval()
Expand Down
9 changes: 8 additions & 1 deletion src/dynadojo/baselines/dnn_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,14 @@ def fit(self, x: np.ndarray, epochs=2000, verbose=0, **kwargs):
head = x[:, :-1, :]
tail = x[:, 1:, :]
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
self.model.fit(head, tail, validation_split=0.2, epochs=epochs, callbacks=[callback], verbose=verbose)
history = self.model.fit(head, tail, validation_split=0.2, epochs=epochs, callbacks=[callback], verbose=verbose)
# print(history.history.keys())
train_losses = history.history['loss']
val_losses = history.history['val_loss']
return {
"train_loss": train_losses,
"val_loss": val_losses
}

def predict(self, x0: np.ndarray, timesteps: int, **kwargs) -> np.ndarray:
preds = [x0]
Expand Down
11 changes: 11 additions & 0 deletions src/dynadojo/baselines/lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def fit(self, x: np.ndarray, **kwargs):

self.model.fit(X_train, y_train)
self.A_hat = self.model.coef_

pred = self.model.predict(X_train)
loss = self.mse(y_train, pred)
# print(loss)
return{
"train_losses": loss
}

def act(self, x, **kwargs):
self.U = self._rng.uniform(-1, 1, [len(x[0]), self._timesteps, self.embed_dim])
Expand All @@ -45,3 +52,7 @@ def predict(self, x0, timesteps, **kwargs):
preds = np.squeeze(np.array(preds), 0)
preds = np.transpose(preds, (2, 0, 1))
return preds

def mse(self, actual, pred):
actual, pred = np.array(actual), np.array(pred)
return np.square(np.subtract(actual,pred)).mean()
6 changes: 5 additions & 1 deletion src/dynadojo/baselines/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,18 @@ def fit(self, x: np.ndarray, epochs=100, **kwargs):
x = torch.tensor(x, dtype=torch.float32)
state = x[:, 0, :]
t = torch.linspace(0.0, self._timesteps, self._timesteps)
losses = []
for _ in range(epochs):
self.opt.zero_grad()
pred = odeint(self.forward, state, t, method='rk4')
pred = pred.transpose(0, 1)
loss = self.mse_loss(pred, x).float()
loss.backward()
self.opt.step()

losses.append(loss)
return {
"train_losses": losses
}
def predict(self, x0: np.ndarray, timesteps: int, **kwargs) -> np.ndarray:
x0 = torch.tensor(x0, dtype=torch.float32)
t = torch.linspace(0.0, timesteps, timesteps)
Expand Down

0 comments on commit 3d4b218

Please sign in to comment.