diff --git a/main.py b/main.py index d9d3eab..084b0dc 100644 --- a/main.py +++ b/main.py @@ -38,7 +38,7 @@ validation_split = 0.20 test_split = 0.20 -dataloader = TimeSeriesDataLoader(X, y, validation_split=validation_split, test_split=test_split, period=1000, batch_size=10) +dataloader = TimeSeriesDataLoader(X, y, validation_split=validation_split, test_split=test_split, period=1000, batch_size=32) model = SimpleLSTM(X.shape[1], 100, 3, batch_first=True, dropout=0.5) if cuda_available: @@ -61,6 +61,8 @@ forecast, _ = model.forecast(X.unsqueeze(0)) forecast = forecast.flatten().cpu().detach().numpy() +print(f"Forecast, {forecast}") +print(f"Expected, {y}") fig, axs = plt.subplots(ncols=2, figsize=(10, 5))