From 456d495b557b3b4e873f0ab936a9121450312fe6 Mon Sep 17 00:00:00 2001 From: Maisa Ben Salah <76703998+MaiBe-ctrl@users.noreply.github.com> Date: Wed, 31 Jul 2024 13:40:01 -0700 Subject: [PATCH] [Minor] Torchify timenet (#1620) * Convert to tensors * clarify ID drop * fixed tests * added vectorization * added sequential components * fixed linters * fixed cml plotting * added newlines to CML markdowns * fixed newlines rendering --------- Co-authored-by: ourownstory --- .github/workflows/metrics.yml | 25 +++++- neuralprophet/time_net.py | 146 +++++++++++++++++----------------- 2 files changed, 94 insertions(+), 77 deletions(-) diff --git a/.github/workflows/metrics.yml b/.github/workflows/metrics.yml index 2ea4c8fce..b3befc60c 100644 --- a/.github/workflows/metrics.yml +++ b/.github/workflows/metrics.yml @@ -11,6 +11,7 @@ on: - main - develop workflow_dispatch: + jobs: metrics: runs-on: ubuntu-latest # container: docker://ghcr.io/iterative/cml:0-dvc2-base1 @@ -19,24 +20,32 @@ jobs: uses: actions/checkout@v3 with: ref: ${{ github.event.pull_request.head.sha }} + - name: Install Python 3.12 uses: actions/setup-python@v5 with: python-version: "3.12" + - name: Setup NodeJS (for CML) uses: actions/setup-node@v3 # For CML with: node-version: '16' + - name: Setup CML uses: iterative/setup-cml@v1 + - name: Install Poetry uses: snok/install-poetry@v1 + - name: Install Dependencies run: poetry install --no-interaction --no-root --with=pytest,metrics --without=dev,docs,linters + - name: Install Project run: poetry install --no-interaction --with=pytest,metrics --without=dev,docs,linters + - name: Train model run: poetry run pytest tests/test_model_performance.py -n 1 --durations=0 + - name: Download metrics from main uses: dawidd6/action-download-artifact@v2 with: @@ -45,28 +54,40 @@ jobs: name: metrics path: tests/metrics-main/ if_no_artifact_found: warn + - name: Open Benchmark Report run: echo "## Model Benchmark" >> report.md + - name: Write Benchmark Report run: poetry run 
python tests/metrics/compareMetrics.py >> report.md + - name: Publish Report with CML env: REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | - echo "
<details>\n<summary>Model training plots</summary>\n" >> report.md + echo "<details><summary>
Model training plots</summary>" >> report.md + echo "" >> report.md echo "## Model Training" >> report.md + echo "" >> report.md echo "### PeytonManning" >> report.md cml asset publish tests/metrics/PeytonManning.svg --md >> report.md + echo "" >> report.md echo "### YosemiteTemps" >> report.md cml asset publish tests/metrics/YosemiteTemps.svg --md >> report.md + echo "" >> report.md echo "### AirPassengers" >> report.md cml asset publish tests/metrics/AirPassengers.svg --md >> report.md + echo "" >> report.md echo "### EnergyPriceDaily" >> report.md cml asset publish tests/metrics/EnergyPriceDaily.svg --md >> report.md - echo "\n
</details>" >> report.md + echo "" >> report.md + echo "</details>
" >> report.md + echo "" >> report.md cml comment update --target=pr report.md # Post reports as comments in GitHub PRs cml check create --title=ModelReport report.md # update status of check in PR + - name: Upload metrics if on main + if: github.ref == 'refs/heads/main' uses: actions/upload-artifact@v3 with: name: metrics diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index 7aeecd058..a4fbfee3a 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -268,28 +268,34 @@ def __init__( self.ar_layers = ar_layers self.max_lags = max_lags if self.n_lags > 0: - self.ar_net = nn.ModuleList() + ar_net_layers = [] d_inputs = self.n_lags for d_hidden_i in self.ar_layers: - self.ar_net.append(nn.Linear(d_inputs, d_hidden_i, bias=True)) + ar_net_layers.append(nn.Linear(d_inputs, d_hidden_i, bias=True)) + ar_net_layers.append(nn.ReLU()) d_inputs = d_hidden_i # final layer has input size d_inputs and output size equal to no. of forecasts * no. of quantiles - self.ar_net.append(nn.Linear(d_inputs, self.n_forecasts * len(self.quantiles), bias=False)) + ar_net_layers.append(nn.Linear(d_inputs, self.n_forecasts * len(self.quantiles), bias=False)) + self.ar_net = nn.Sequential(*ar_net_layers) for lay in self.ar_net: - nn.init.kaiming_normal_(lay.weight, mode="fan_in") + if isinstance(lay, nn.Linear): + nn.init.kaiming_normal_(lay.weight, mode="fan_in") # Lagged regressors self.lagged_reg_layers = lagged_reg_layers self.config_lagged_regressors = config_lagged_regressors if self.config_lagged_regressors is not None: - self.covar_net = nn.ModuleList() + covar_net_layers = [] d_inputs = sum([covar.n_lags for _, covar in self.config_lagged_regressors.items()]) for d_hidden_i in self.lagged_reg_layers: - self.covar_net.append(nn.Linear(d_inputs, d_hidden_i, bias=True)) + covar_net_layers.append(nn.Linear(d_inputs, d_hidden_i, bias=True)) + covar_net_layers.append(nn.ReLU()) d_inputs = d_hidden_i - self.covar_net.append(nn.Linear(d_inputs, 
self.n_forecasts * len(self.quantiles), bias=False)) + covar_net_layers.append(nn.Linear(d_inputs, self.n_forecasts * len(self.quantiles), bias=False)) + self.covar_net = nn.Sequential(*covar_net_layers) for lay in self.covar_net: - nn.init.kaiming_normal_(lay.weight, mode="fan_in") + if isinstance(lay, nn.Linear): + nn.init.kaiming_normal_(lay.weight, mode="fan_in") # Regressors self.config_regressors = config_regressors @@ -310,7 +316,9 @@ def __init__( def ar_weights(self) -> torch.Tensor: """sets property auto-regression weights for regularization. Update if AR is modelled differently""" # TODO: this is wrong for deep networks, use utils_torch.interprete_model - return self.ar_net[0].weight + for layer in self.ar_net: + if isinstance(layer, nn.Linear): + return layer.weight def get_covar_weights(self, covar_input=None) -> torch.Tensor: """ @@ -393,49 +401,50 @@ def _compute_quantile_forecasts_from_diffs(self, diffs: torch.Tensor, predict_mo dim (batch, n_forecasts, no_quantiles) final forecasts """ - if len(self.quantiles) > 1: - # generate the actual quantile forecasts from predicted differences - if any(quantile > 0.5 for quantile in self.quantiles): - quantiles_divider_index = next(i for i, quantile in enumerate(self.quantiles) if quantile > 0.5) - else: - quantiles_divider_index = len(self.quantiles) - - n_upper_quantiles = diffs.shape[-1] - quantiles_divider_index - n_lower_quantiles = quantiles_divider_index - 1 - - out = torch.zeros_like(diffs) - out[:, :, 0] = diffs[:, :, 0] # set the median where 0 is the median quantile index - - if n_upper_quantiles > 0: # check if upper quantiles exist - upper_quantile_diffs = diffs[:, :, quantiles_divider_index:] - if predict_mode: # check for quantile crossing and correct them in predict mode - upper_quantile_diffs[:, :, 0] = torch.max( - torch.tensor(0, device=self.device), upper_quantile_diffs[:, :, 0] - ) - for i in range(n_upper_quantiles - 1): - next_diff = upper_quantile_diffs[:, :, i + 1] - diff = 
upper_quantile_diffs[:, :, i] - upper_quantile_diffs[:, :, i + 1] = torch.max(next_diff, diff) - out[:, :, quantiles_divider_index:] = ( - upper_quantile_diffs + diffs[:, :, 0].unsqueeze(dim=2).repeat(1, 1, n_upper_quantiles).detach() - ) # set the upper quantiles - - if n_lower_quantiles > 0: # check if lower quantiles exist - lower_quantile_diffs = diffs[:, :, 1:quantiles_divider_index] - if predict_mode: # check for quantile crossing and correct them in predict mode - lower_quantile_diffs[:, :, -1] = torch.max( - torch.tensor(0, device=self.device), lower_quantile_diffs[:, :, -1] - ) - for i in range(n_lower_quantiles - 1, 0, -1): - next_diff = lower_quantile_diffs[:, :, i - 1] - diff = lower_quantile_diffs[:, :, i] - lower_quantile_diffs[:, :, i - 1] = torch.max(next_diff, diff) - lower_quantile_diffs = -lower_quantile_diffs - out[:, :, 1:quantiles_divider_index] = ( - lower_quantile_diffs + diffs[:, :, 0].unsqueeze(dim=2).repeat(1, 1, n_lower_quantiles).detach() - ) # set the lower quantiles + + if len(self.quantiles) <= 1: + return diffs + # generate the actual quantile forecasts from predicted differences + if any(quantile > 0.5 for quantile in self.quantiles): + quantiles_divider_index = next(i for i, quantile in enumerate(self.quantiles) if quantile > 0.5) else: - out = diffs + quantiles_divider_index = len(self.quantiles) + + n_upper_quantiles = diffs.shape[-1] - quantiles_divider_index + n_lower_quantiles = quantiles_divider_index - 1 + + out = torch.zeros_like(diffs) + out[:, :, 0] = diffs[:, :, 0] # set the median where 0 is the median quantile index + + if n_upper_quantiles > 0: # check if upper quantiles exist + upper_quantile_diffs = diffs[:, :, quantiles_divider_index:] + if predict_mode: # check for quantile crossing and correct them in predict mode + upper_quantile_diffs[:, :, 0] = torch.max( + torch.tensor(0, device=self.device), upper_quantile_diffs[:, :, 0] + ) + for i in range(n_upper_quantiles - 1): + next_diff = upper_quantile_diffs[:, :, i 
+ 1] + diff = upper_quantile_diffs[:, :, i] + upper_quantile_diffs[:, :, i + 1] = torch.max(next_diff, diff) + out[:, :, quantiles_divider_index:] = ( + upper_quantile_diffs + diffs[:, :, 0].unsqueeze(dim=2).repeat(1, 1, n_upper_quantiles).detach() + ) # set the upper quantiles + + if n_lower_quantiles > 0: # check if lower quantiles exist + lower_quantile_diffs = diffs[:, :, 1:quantiles_divider_index] + if predict_mode: # check for quantile crossing and correct them in predict mode + lower_quantile_diffs[:, :, -1] = torch.max( + torch.tensor(0, device=self.device), lower_quantile_diffs[:, :, -1] + ) + for i in range(n_lower_quantiles - 1, 0, -1): + next_diff = lower_quantile_diffs[:, :, i - 1] + diff = lower_quantile_diffs[:, :, i] + lower_quantile_diffs[:, :, i - 1] = torch.max(next_diff, diff) + lower_quantile_diffs = -lower_quantile_diffs + out[:, :, 1:quantiles_divider_index] = ( + lower_quantile_diffs + diffs[:, :, 0].unsqueeze(dim=2).repeat(1, 1, n_lower_quantiles).detach() + ) # set the lower quantiles + return out def scalar_features_effects(self, features: torch.Tensor, params: nn.Parameter, indices=None) -> torch.Tensor: @@ -474,14 +483,9 @@ def auto_regression(self, lags: Union[torch.Tensor, float]) -> torch.Tensor: torch.Tensor Forecast component of dims: (batch, n_forecasts) """ - x = lags - for i in range(len(self.ar_layers) + 1): - if i > 0: - x = nn.functional.relu(x) - x = self.ar_net[i](x) - + x = self.ar_net(lags) # segment the last dimension to match the quantiles - x = x.reshape(x.shape[0], self.n_forecasts, len(self.quantiles)) + x = x.view(x.shape[0], self.n_forecasts, len(self.quantiles)) return x def forward_covar_net(self, covariates): @@ -501,13 +505,9 @@ def forward_covar_net(self, covariates): x = torch.cat([covar for _, covar in covariates.items()], axis=1) else: x = covariates - for i in range(len(self.lagged_reg_layers) + 1): - if i > 0: - x = nn.functional.relu(x) - x = self.covar_net[i](x) - + x = self.covar_net(x) # segment the 
last dimension to match the quantiles - x = x.reshape(x.shape[0], self.n_forecasts, len(self.quantiles)) + x = x.view(x.shape[0], self.n_forecasts, len(self.quantiles)) return x def forward(self, inputs: Dict, meta: Dict = None, compute_components_flag: bool = False) -> torch.Tensor: @@ -880,8 +880,7 @@ def _get_time_based_sample_weight(self, t): end_w = self.config_train.newer_samples_weight start_t = self.config_train.newer_samples_start time = (t.detach() - start_t) / (1.0 - start_t) - time = torch.maximum(torch.zeros_like(time), time) - time = torch.minimum(torch.ones_like(time), time) # time = 0 to 1 + time = torch.clamp(time, 0.0, 1.0) # time = 0 to 1 time = np.pi * (time - 1.0) # time = -pi to 0 time = 0.5 * torch.cos(time) + 0.5 # time = 0 to 1 # scales end to be end weight times bigger than start weight @@ -1019,11 +1018,13 @@ class DeepNet(nn.Module): def __init__(self, d_inputs, d_outputs, lagged_reg_layers=[]): # Perform initialization of the pytorch superclass super(DeepNet, self).__init__() - self.layers = nn.ModuleList() + layers = [] for d_hidden_i in lagged_reg_layers: - self.layers.append(nn.Linear(d_inputs, d_hidden_i, bias=True)) + layers.append(nn.Linear(d_inputs, d_hidden_i, bias=True)) + layers.append(nn.ReLU()) d_inputs = d_hidden_i - self.layers.append(nn.Linear(d_inputs, d_outputs, bias=True)) + layers.append(nn.Linear(d_inputs, d_outputs, bias=True)) + self.layers = nn.Sequential(*layers) for lay in self.layers: nn.init.kaiming_normal_(lay.weight, mode="fan_in") @@ -1031,12 +1032,7 @@ def forward(self, x): """ This method defines the network layering and activation functions """ - activation = nn.functional.relu - for i in range(len(self.layers)): - if i > 0: - x = activation(x) - x = self.layers[i](x) - return x + return self.layers(x) @property def ar_weights(self):