diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a8c807b9..8479c6ac 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # hooks for checking files - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.6.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -9,15 +9,18 @@ repos: # hooks for linting code - repo: https://github.com/psf/black - rev: 22.10.0 + rev: 24.8.0 hooks: - id: black + args: [ + --line-length=120, # refer to pyproject.toml + ] - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 + rev: 7.1.1 hooks: - id: flake8 args: [ --max-line-length=120, # refer to pyproject.toml - --extend-ignore=E203, # why ignore E203? Refer to https://github.com/PyCQA/pycodestyle/issues/373 + --extend-ignore=E203,E231 ] diff --git a/docs/conf.py b/docs/conf.py index 4a22ce4e..5b8e3822 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -108,8 +108,7 @@ html_context["READTHEDOCS"] = True html_favicon = ( - "https://raw.githubusercontent.com/" - "PyPOTS/pypots.github.io/main/static/figs/pypots_logos/PyPOTS/logo_FFBG.svg" + "https://raw.githubusercontent.com/PyPOTS/pypots.github.io/main/static/figs/pypots_logos/PyPOTS/logo_FFBG.svg" ) html_sidebars = { diff --git a/pypots/base.py b/pypots/base.py index 15e1f64b..bb0a27e1 100644 --- a/pypots/base.py +++ b/pypots/base.py @@ -106,9 +106,7 @@ def _setup_device(self, device: Union[None, str, torch.device, list]) -> None: self.device = device elif isinstance(device, list): if len(device) == 0: - raise ValueError( - "The list of devices should have at least 1 device, but got 0." - ) + raise ValueError("The list of devices should have at least 1 device, but got 0.") elif len(device) == 1: return self._setup_device(device[0]) # parallely training on multiple CUDA devices @@ -179,18 +177,14 @@ def _setup_path(self, saving_path) -> None: logger.info(f"Model files will be saved to {self.saving_path}") logger.info(f"Tensorboard file will be saved to {tb_saving_path}") else: - logger.warning( - "‼️ saving_path not given. Model files and tensorboard file will not be saved." - ) + logger.warning("‼️ saving_path not given. Model files and tensorboard file will not be saved.") def _send_model_to_given_device(self) -> None: if isinstance(self.device, list): # parallely training on multiple devices self.model = torch.nn.DataParallel(self.model, device_ids=self.device) self.model = self.model.cuda() - logger.info( - f"Model has been allocated to the given multiple devices: {self.device}" - ) + logger.info(f"Model has been allocated to the given multiple devices: {self.device}") else: self.model = self.model.to(self.device) @@ -291,9 +285,7 @@ def save( if os.path.exists(saving_path): if overwrite: - logger.warning( - f"‼️ File {saving_path} exists. Argument `overwrite` is True. Overwriting now..." - ) + logger.warning(f"‼️ File {saving_path} exists. Argument `overwrite` is True. Overwriting now...") else: logger.error( f"❌ File {saving_path} exists. Saving operation aborted. " @@ -309,9 +301,7 @@ def save( torch.save(self.model, saving_path) logger.info(f"Saved the model to {saving_path}") except Exception as e: - raise RuntimeError( - f'Failed to save the model to "{saving_path}" because of the below error! \n{e}' - ) + raise RuntimeError(f'Failed to save the model to "{saving_path}" because of the below error! \n{e}') def load(self, path: str) -> None: """Load the saved model from a disk file. @@ -519,9 +509,7 @@ def __init__( def _print_model_size(self) -> None: """Print the number of trainable parameters in the initialized NN model.""" - self.num_params = sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) + self.num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) logger.info( f"{self.__class__.__name__} initialized with the given hyperparameters, " f"the number of trainable parameters: {self.num_params:,}" diff --git a/pypots/classification/base.py b/pypots/classification/base.py index ca587c29..e1848602 100644 --- a/pypots/classification/base.py +++ b/pypots/classification/base.py @@ -313,9 +313,7 @@ def _train_model( for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) results = self.model.forward(inputs) - epoch_val_loss_collector.append( - results["loss"].sum().item() - ) + epoch_val_loss_collector.append(results["loss"].sum().item()) mean_val_loss = np.mean(epoch_val_loss_collector) @@ -333,15 +331,11 @@ def _train_model( ) mean_loss = mean_val_loss else: - logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" - ) + logger.info(f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}") mean_loss = mean_train_loss if np.isnan(mean_loss): - logger.warning( - f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors." - ) + logger.warning(f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors.") if mean_loss < self.best_loss: self.best_epoch = epoch @@ -363,9 +357,7 @@ def _train_model( nni.report_final_result(self.best_loss) if self.patience == 0: - logger.info( - "Exceeded the training patience. Terminating the training procedure..." - ) + logger.info("Exceeded the training patience. Terminating the training procedure...") break except KeyboardInterrupt: # if keyboard interrupt, only warning @@ -386,9 +378,7 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info( - f"Finished training. The best model is from epoch#{self.best_epoch}." - ) + logger.info(f"Finished training. The best model is from epoch#{self.best_epoch}.") @abstractmethod def fit( diff --git a/pypots/classification/grud/core.py b/pypots/classification/grud/core.py index ed656a39..ca2b635d 100644 --- a/pypots/classification/grud/core.py +++ b/pypots/classification/grud/core.py @@ -58,9 +58,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: empirical_mean = inputs["empirical_mean"] X_filledLOCF = inputs["X_filledLOCF"] - _, hidden_state = self.model( - X, missing_mask, deltas, empirical_mean, X_filledLOCF - ) + _, hidden_state = self.model(X, missing_mask, deltas, empirical_mean, X_filledLOCF) logits = self.classifier(hidden_state) classification_pred = torch.softmax(logits, dim=1) @@ -68,9 +66,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: - classification_loss = F.nll_loss( - torch.log(classification_pred), inputs["label"] - ) + classification_loss = F.nll_loss(torch.log(classification_pred), inputs["label"]) results["loss"] = classification_loss return results diff --git a/pypots/classification/grud/data.py b/pypots/classification/grud/data.py index 3287a6f6..5028b7d3 100644 --- a/pypots/classification/grud/data.py +++ b/pypots/classification/grud/data.py @@ -60,9 +60,9 @@ def __init__( self.X_filledLOCF = locf_torch(self.X) self.X = torch.nan_to_num(self.X) self.deltas = _parse_delta_torch(self.missing_mask) - self.empirical_mean = torch.sum( - self.missing_mask * self.X, dim=[0, 1] - ) / torch.sum(self.missing_mask, dim=[0, 1]) + self.empirical_mean = torch.sum(self.missing_mask * self.X, dim=[0, 1]) / torch.sum( + self.missing_mask, dim=[0, 1] + ) # fill nan with 0, in case some features have no observations self.empirical_mean = torch.nan_to_num(self.empirical_mean, 0) @@ -134,9 +134,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: X_filledLOCF = locf_torch(X.unsqueeze(dim=0)).squeeze() X = torch.nan_to_num(X) deltas = _parse_delta_torch(missing_mask) - empirical_mean = torch.sum(missing_mask * X, dim=[0]) / torch.sum( - missing_mask, dim=[0] - ) + empirical_mean = torch.sum(missing_mask * X, dim=[0]) / torch.sum(missing_mask, dim=[0]) sample = [ torch.tensor(idx), diff --git a/pypots/classification/raindrop/core.py b/pypots/classification/raindrop/core.py index 5e6deb99..24d9f814 100644 --- a/pypots/classification/raindrop/core.py +++ b/pypots/classification/raindrop/core.py @@ -3,7 +3,6 @@ and takes over the forward progress of the algorithm. """ - # Created by Wenjie Du # License: BSD-3-Clause @@ -84,21 +83,13 @@ def forward(self, inputs, training=True): lengths2 = lengths.unsqueeze(1).to(device) mask2 = mask.permute(1, 0).unsqueeze(2).long() if self.sensor_wise_mask: - output = torch.zeros( - [batch_size, self.n_features, self.d_ob + 16], device=device - ) + output = torch.zeros([batch_size, self.n_features, self.d_ob + 16], device=device) extended_missing_mask = missing_mask.view(-1, batch_size, self.n_features) for se in range(self.n_features): - representation = representation.view( - -1, batch_size, self.n_features, (self.d_ob + 16) - ) + representation = representation.view(-1, batch_size, self.n_features, (self.d_ob + 16)) out = representation[:, :, se, :] - l_ = torch.sum(extended_missing_mask[:, :, se], dim=0).unsqueeze( - 1 - ) # length - out_sensor = torch.sum( - out * (1 - extended_missing_mask[:, :, se].unsqueeze(-1)), dim=0 - ) / (l_ + 1) + l_ = torch.sum(extended_missing_mask[:, :, se], dim=0).unsqueeze(1) # length + out_sensor = torch.sum(out * (1 - extended_missing_mask[:, :, se].unsqueeze(-1)), dim=0) / (l_ + 1) output[:, se, :] = out_sensor output = output.view([-1, self.n_features * (self.d_ob + 16)]) elif self.aggregation == "mean": @@ -116,9 +107,7 @@ def forward(self, inputs, training=True): # if in training mode, return results with losses if training: - classification_loss = F.nll_loss( - torch.log(classification_pred), inputs["label"] - ) + classification_loss = F.nll_loss(torch.log(classification_pred), inputs["label"]) results["loss"] = classification_loss return results diff --git a/pypots/classification/raindrop/model.py b/pypots/classification/raindrop/model.py index c773f16a..f599b204 100644 --- a/pypots/classification/raindrop/model.py +++ b/pypots/classification/raindrop/model.py @@ -3,7 +3,6 @@ """ - # Created by Wenjie Du # License: BSD-3-Clause diff --git a/pypots/cli/dev.py b/pypots/cli/dev.py index 85a5dbf1..3522fdbe 100644 --- a/pypots/cli/dev.py +++ b/pypots/cli/dev.py @@ -131,10 +131,9 @@ def checkup(self): ) if self._cleanup: - assert not self._run_tests and not self._lint_code, ( - "Argument `--cleanup` should be used alone. " - "Try `pypots-cli dev --cleanup`" - ) + assert ( + not self._run_tests and not self._lint_code + ), "Argument `--cleanup` should be used alone. Try `pypots-cli dev --cleanup`" def run(self): """Execute the given command.""" @@ -149,14 +148,8 @@ def run(self): elif self._build: self.execute_command("python -m build") elif self._run_tests: - pytest_command = ( - f"pytest -k {self._k}" if self._k is not None else "pytest" - ) - command_to_run_test = ( - f"coverage run -m {pytest_command}" - if self._show_coverage - else pytest_command - ) + pytest_command = f"pytest -k {self._k}" if self._k is not None else "pytest" + command_to_run_test = f"coverage run -m {pytest_command}" if self._show_coverage else pytest_command self.execute_command(command_to_run_test) if self._show_coverage and os.path.exists(".coverage"): self.execute_command("coverage report -m") diff --git a/pypots/cli/doc.py b/pypots/cli/doc.py index d525b169..e0c01d51 100644 --- a/pypots/cli/doc.py +++ b/pypots/cli/doc.py @@ -46,9 +46,7 @@ def doc_command_factory(args: Namespace): def purge_temp_files(): - logger.info( - f"Directories _build and {CLONED_LATEST_PYPOTS} will be deleted if exist" - ) + logger.info(f"Directories _build and {CLONED_LATEST_PYPOTS} will be deleted if exist") shutil.rmtree("docs/_build", ignore_errors=True) shutil.rmtree(CLONED_LATEST_PYPOTS, ignore_errors=True) @@ -148,10 +146,9 @@ def checkup(self): self.check_if_under_root_dir(strict=True) if self._cleanup: - assert not self._gene_rst and not self._gene_html and not self._view_doc, ( - "Argument `--cleanup` should be used alone. " - "Try `pypots-cli doc --cleanup`" - ) + assert ( + not self._gene_rst and not self._gene_html and not self._view_doc + ), "Argument `--cleanup` should be used alone. Try `pypots-cli doc --cleanup`" def run(self): """Execute the given command.""" @@ -166,9 +163,7 @@ def run(self): if self._gene_rst: if os.path.exists(CLONED_LATEST_PYPOTS): - logger.info( - f"Directory {CLONED_LATEST_PYPOTS} exists, deleting it..." - ) + logger.info(f"Directory {CLONED_LATEST_PYPOTS} exists, deleting it...") shutil.rmtree(CLONED_LATEST_PYPOTS, ignore_errors=True) # Download the latest code from GitHub @@ -185,18 +180,12 @@ def run(self): for f_ in files_to_move: shutil.move(os.path.join(code_dir, f_), destination_dir) # delete code in tests because we don't need its doc - shutil.rmtree( - f"{CLONED_LATEST_PYPOTS}/pypots/tests", ignore_errors=True - ) + shutil.rmtree(f"{CLONED_LATEST_PYPOTS}/pypots/tests", ignore_errors=True) # Generate the docs according to the cloned code logger.info("Generating rst files...") - os.environ[ - "SPHINX_APIDOC_OPTIONS" - ] = "members,undoc-members,show-inheritance,inherited-members" - self.execute_command( - f"sphinx-apidoc {CLONED_LATEST_PYPOTS} -o {CLONED_LATEST_PYPOTS}/rst" - ) + os.environ["SPHINX_APIDOC_OPTIONS"] = "members,undoc-members,show-inheritance,inherited-members" + self.execute_command(f"sphinx-apidoc {CLONED_LATEST_PYPOTS} -o {CLONED_LATEST_PYPOTS}/rst") # Only save the files we need. logger.info("Updating the old documentation...") @@ -217,9 +206,7 @@ def run(self): "docs/_build/html" ), "docs/_build/html does not exists, please run `pypots-cli doc --gene_html` first" logger.info(f"Deploying HTML to http://127.0.0.1:{self._port}...") - self.execute_command( - f"python -m http.server {self._port} -d docs/_build/html -b 127.0.0.1" - ) + self.execute_command(f"python -m http.server {self._port} -d docs/_build/html -b 127.0.0.1") except ImportError: raise ImportError(IMPORT_ERROR_MESSAGE) diff --git a/pypots/cli/env.py b/pypots/cli/env.py index be1330cd..028377ff 100644 --- a/pypots/cli/env.py +++ b/pypots/cli/env.py @@ -94,18 +94,14 @@ def run(self): # run checks first self.checkup() - logger.info( - f"Installing the dependencies in scope `{self._install}` for you..." - ) + logger.info(f"Installing the dependencies in scope `{self._install}` for you...") if self._tool == "conda": assert ( self.execute_command("which conda").returncode == 0 ), "Conda not installed, cannot set --tool=conda, please check your conda." - self.execute_command( - "conda install pyg pytorch-scatter pytorch-sparse -c pyg" - ) + self.execute_command("conda install pyg pytorch-scatter pytorch-sparse -c pyg") else: # self._tool == "pip" torch_version = torch.__version__ diff --git a/pypots/cli/pypots_cli.py b/pypots/cli/pypots_cli.py index 7fbf3108..c116755a 100644 --- a/pypots/cli/pypots_cli.py +++ b/pypots/cli/pypots_cli.py @@ -14,9 +14,7 @@ def main(): - parser = ArgumentParser( - "PyPOTS Command-Line-Interface tool", usage="pypots-cli []" - ) + parser = ArgumentParser("PyPOTS Command-Line-Interface tool", usage="pypots-cli []") commands_parser = parser.add_subparsers(help="pypots-cli command helpers") # Register commands here diff --git a/pypots/cli/tuning.py b/pypots/cli/tuning.py index 23cb2b43..2af0a863 100644 --- a/pypots/cli/tuning.py +++ b/pypots/cli/tuning.py @@ -249,12 +249,8 @@ def run(self): model_arguments_set = set(model_all_arguments) if_hyperparameter_match = tuner_params_set.issubset(model_arguments_set) if not if_hyperparameter_match: # raise runtime error if mismatch - hyperparameter_intersection = tuner_params_set.intersection( - model_arguments_set - ) - mismatched = tuner_params_set.difference( - set(hyperparameter_intersection) - ) + hyperparameter_intersection = tuner_params_set.intersection(model_arguments_set) + mismatched = tuner_params_set.difference(set(hyperparameter_intersection)) raise RuntimeError( f"Hyperparameters do not match. Mismatched hyperparameters " f"(in the tuning configuration but not in {model_class.__name__}'s arguments): {list(mismatched)}" @@ -277,9 +273,7 @@ def run(self): if self._lazy_load: train_set, val_set = self._train_set, self._val_set else: - logger.info( - "Option lazy_load is set as False, hence loading all data from file..." - ) + logger.info("Option lazy_load is set as False, hence loading all data from file...") train_set = load_dict_from_h5(self._train_set) val_set = load_dict_from_h5(self._val_set) diff --git a/pypots/clustering/base.py b/pypots/clustering/base.py index 528d8fc2..d0781c89 100644 --- a/pypots/clustering/base.py +++ b/pypots/clustering/base.py @@ -320,9 +320,7 @@ def _train_model( for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) results = self.model.forward(inputs) - epoch_val_loss_collector.append( - results["loss"].sum().item() - ) + epoch_val_loss_collector.append(results["loss"].sum().item()) mean_val_loss = np.mean(epoch_val_loss_collector) logger.info( @@ -332,15 +330,11 @@ def _train_model( ) mean_loss = mean_val_loss else: - logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" - ) + logger.info(f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}") mean_loss = mean_train_loss if np.isnan(mean_loss): - logger.warning( - f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors." - ) + logger.warning(f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors.") if mean_loss < self.best_loss: self.best_epoch = epoch @@ -356,9 +350,7 @@ def _train_model( nni.report_final_result(self.best_loss) if self.patience == 0: - logger.info( - "Exceeded the training patience. Terminating the training procedure..." - ) + logger.info("Exceeded the training patience. Terminating the training procedure...") break except KeyboardInterrupt: # if keyboard interrupt, only warning @@ -379,9 +371,7 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info( - f"Finished training. The best model is from epoch#{self.best_epoch}." - ) + logger.info(f"Finished training. The best model is from epoch#{self.best_epoch}.") @abstractmethod def fit( diff --git a/pypots/clustering/crli/core.py b/pypots/clustering/crli/core.py index 755d9ff7..baf6f5fc 100644 --- a/pypots/clustering/crli/core.py +++ b/pypots/clustering/crli/core.py @@ -58,9 +58,7 @@ def forward( training: bool = True, ) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] - imputation_latent, discrimination, reconstruction, fcn_latent = self.backbone( - X, missing_mask - ) + imputation_latent, discrimination, reconstruction, fcn_latent = self.backbone(X, missing_mask) results = { "imputation_latent": imputation_latent, "discrimination": discrimination, @@ -77,23 +75,16 @@ def forward( results["discrimination_loss"] = l_D else: # discrimination = discrimination.detach() - l_G = F.binary_cross_entropy_with_logits( - discrimination, 1 - missing_mask, weight=1 - missing_mask - ) + l_G = F.binary_cross_entropy_with_logits(discrimination, 1 - missing_mask, weight=1 - missing_mask) l_pre = calc_mse(imputation_latent, X, missing_mask) l_rec = calc_mse(reconstruction, X, missing_mask) HTH = torch.matmul(fcn_latent, fcn_latent.permute(1, 0)) - if ( - self.counter_for_updating_F == 0 - or self.counter_for_updating_F % 10 == 0 - ): + if self.counter_for_updating_F == 0 or self.counter_for_updating_F % 10 == 0: U, s, V = torch.linalg.svd(fcn_latent) self.term_F = U[:, : self.n_clusters] - FTHTHF = torch.matmul( - torch.matmul(self.term_F.permute(1, 0), HTH), self.term_F - ) + FTHTHF = torch.matmul(torch.matmul(self.term_F.permute(1, 0), HTH), self.term_F) l_kmeans = torch.trace(HTH) - torch.trace(FTHTHF) # k-means loss loss_gene = l_G + l_pre + l_rec + l_kmeans * self.lambda_kmeans results["generation_loss"] = loss_gene diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index 8210fc6d..f1838af3 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -216,25 +216,17 @@ def _train_model( step_train_loss_D_collector = [] for _ in range(self.D_steps): self.D_optimizer.zero_grad() - results = self.model.forward( - inputs, training_object="discriminator" - ) + results = self.model.forward(inputs, training_object="discriminator") results["discrimination_loss"].backward(retain_graph=True) self.D_optimizer.step() - step_train_loss_D_collector.append( - results["discrimination_loss"].sum().item() - ) + step_train_loss_D_collector.append(results["discrimination_loss"].sum().item()) for _ in range(self.G_steps): self.G_optimizer.zero_grad() - results = self.model.forward( - inputs, training_object="generator" - ) + results = self.model.forward(inputs, training_object="generator") results["generation_loss"].backward() self.G_optimizer.step() - step_train_loss_G_collector.append( - results["generation_loss"].sum().item() - ) + step_train_loss_G_collector.append(results["generation_loss"].sum().item()) mean_step_train_D_loss = np.mean(step_train_loss_D_collector) mean_step_train_G_loss = np.mean(step_train_loss_G_collector) @@ -250,9 +242,7 @@ def _train_model( "generation_loss": mean_step_train_G_loss, "discrimination_loss": mean_step_train_D_loss, } - self._save_log_into_tb_file( - training_step, "training", loss_results - ) + self._save_log_into_tb_file(training_step, "training", loss_results) mean_epoch_train_D_loss = np.mean(epoch_train_loss_D_collector) mean_epoch_train_G_loss = np.mean(epoch_train_loss_G_collector) @@ -264,9 +254,7 @@ def _train_model( for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) results = self.model.forward(inputs, training=True) - epoch_val_loss_G_collector.append( - results["generation_loss"].sum().item() - ) + epoch_val_loss_G_collector.append(results["generation_loss"].sum().item()) mean_val_G_loss = np.mean(epoch_val_loss_G_collector) # save validation loss logs into the tensorboard file for every epoch if in need if self.summary_writer is not None: @@ -290,9 +278,7 @@ def _train_model( mean_loss = mean_epoch_train_G_loss if np.isnan(mean_loss): - logger.warning( - f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors." - ) + logger.warning(f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors.") if mean_loss < self.best_loss: self.best_epoch = epoch @@ -314,9 +300,7 @@ def _train_model( ) if self.patience == 0: - logger.info( - "Exceeded the training patience. Terminating the training procedure..." - ) + logger.info("Exceeded the training patience. Terminating the training procedure...") break except KeyboardInterrupt: # if keyboard interrupt, only warning @@ -337,9 +321,7 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info( - f"Finished training. The best model is from epoch#{self.best_epoch}." - ) + logger.info(f"Finished training. The best model is from epoch#{self.best_epoch}.") def fit( self, @@ -426,9 +408,7 @@ def predict( if return_latent_vars: imputation_collector.append(inputs["imputation_latent"]) - clustering_latent = ( - torch.cat(clustering_latent_collector).cpu().detach().numpy() - ) + clustering_latent = torch.cat(clustering_latent_collector).cpu().detach().numpy() clustering = self.model.kmeans.fit_predict(clustering_latent) result_dict = { diff --git a/pypots/clustering/vader/core.py b/pypots/clustering/vader/core.py index 52843b72..4ab9cb70 100644 --- a/pypots/clustering/vader/core.py +++ b/pypots/clustering/vader/core.py @@ -102,12 +102,7 @@ def forward( # calculate the reconstruction loss unscaled_reconstruction_loss = calc_mse(X_reconstructed, X, missing_mask) - reconstruction_loss = ( - unscaled_reconstruction_loss - * self.n_steps - * self.d_input - / missing_mask.sum() - ) + reconstruction_loss = unscaled_reconstruction_loss * self.n_steps * self.d_input / missing_mask.sum() if pretrain: results["loss"] = reconstruction_loss @@ -136,9 +131,7 @@ def forward( sc_b = var_c.index_select(dim=0, index=ii) z_b = z.index_select(dim=0, index=jj) log_pdf_z = -0.5 * (lsc_b + log_2pi + torch.square(z_b - mc_b) / sc_b) - log_pdf_z = log_pdf_z.reshape( - [batch_size, self.n_clusters, self.d_mu_stddev] - ) + log_pdf_z = log_pdf_z.reshape([batch_size, self.n_clusters, self.d_mu_stddev]) log_p = log_phi_c + log_pdf_z.sum(dim=2) lse_p = log_p.logsumexp(dim=1, keepdim=True) @@ -159,9 +152,7 @@ def forward( [batch_size, self.n_clusters, self.d_mu_stddev], ) - latent_loss1 = 0.5 * torch.sum( - gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1 - ) + latent_loss1 = 0.5 * torch.sum(gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1) latent_loss2 = -torch.sum(gamma_c * (log_phi_c - log_gamma_c), dim=1) latent_loss3 = -0.5 * torch.sum(1 + stddev_tilde, dim=1) diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index cb5f3201..0a6e6418 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -122,18 +122,14 @@ def __init__( verbose, ) - assert ( - pretrain_epochs > 0 - ), f"pretrain_epochs must be a positive integer, but got {pretrain_epochs}" + assert pretrain_epochs > 0, f"pretrain_epochs must be a positive integer, but got {pretrain_epochs}" self.n_steps = n_steps self.n_features = n_features self.pretrain_epochs = pretrain_epochs # set up the model - self.model = _VaDER( - n_steps, n_features, n_clusters, rnn_hidden_size, d_mu_stddev - ) + self.model = _VaDER(n_steps, n_features, n_clusters, rnn_hidden_size, d_mu_stddev) self._send_model_to_given_device() self._print_model_size() @@ -181,9 +177,7 @@ def _train_model( # save pre-training loss logs into the tensorboard file for every step if in need if self.summary_writer is not None: - self._save_log_into_tb_file( - pretraining_step, "pretraining", results - ) + self._save_log_into_tb_file(pretraining_step, "pretraining", results) with torch.no_grad(): sample_collector = [] @@ -212,9 +206,7 @@ def _train_model( flag = 1 except ValueError as e: logger.error(f"❌ Exception: {e}") - logger.warning( - "‼️ Met with ValueError, double `reg_covar` to re-train the GMM model." - ) + logger.warning("‼️ Met with ValueError, double `reg_covar` to re-train the GMM model.") flag -= 1 if flag == -5: @@ -277,9 +269,7 @@ def _train_model( for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) results = self.model.forward(inputs) - epoch_val_loss_collector.append( - results["loss"].sum().item() - ) + epoch_val_loss_collector.append(results["loss"].sum().item()) mean_val_loss = np.mean(epoch_val_loss_collector) @@ -297,15 +287,11 @@ def _train_model( ) mean_loss = mean_val_loss else: - logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" - ) + logger.info(f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}") mean_loss = mean_train_loss if np.isnan(mean_loss): - logger.warning( - f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors." - ) + logger.warning(f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors.") if mean_loss < self.best_loss: self.best_epoch = epoch @@ -327,9 +313,7 @@ def _train_model( nni.report_final_result(self.best_loss) if self.patience == 0: - logger.info( - "Exceeded the training patience. Terminating the training procedure..." - ) + logger.info("Exceeded the training patience. Terminating the training procedure...") break except KeyboardInterrupt: # if keyboard interrupt, only warning @@ -350,9 +334,7 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info( - f"Finished training. The best model is from epoch#{self.best_epoch}." - ) + logger.info(f"Finished training. The best model is from epoch#{self.best_epoch}.") def fit( self, @@ -457,16 +439,10 @@ def func_to_apply( ) -> np.ndarray: # the covariance matrix is diagonal, so we can just take the product return np.log(1e-9 + phi_) + np.log( - 1e-9 - + multivariate_normal.pdf(mu_t_, mean=mu_, cov=np.diag(stddev_)) + 1e-9 + multivariate_normal.pdf(mu_t_, mean=mu_, cov=np.diag(stddev_)) ) - p = np.array( - [ - func_to_apply(mu_tilde, mu[i], var[i], phi[i]) - for i in np.arange(mu.shape[0]) - ] - ) + p = np.array([func_to_apply(mu_tilde, mu[i], var[i], phi[i]) for i in np.arange(mu.shape[0])]) clustering_results = np.argmax(p, axis=0) clustering_results_collector.append(clustering_results) diff --git a/pypots/data/checking.py b/pypots/data/checking.py index 4f0e7767..d807f983 100644 --- a/pypots/data/checking.py +++ b/pypots/data/checking.py @@ -35,6 +35,4 @@ def key_in_data_set(key: str, dataset: Union[str, dict]) -> bool: elif isinstance(dataset, dict): return key in dataset.keys() else: - raise TypeError( - f"dataset must be a str or a Python dictionary, but got {type(dataset)}" - ) + raise TypeError(f"dataset must be a str or a Python dictionary, but got {type(dataset)}") diff --git a/pypots/data/dataset/base.py b/pypots/data/dataset/base.py index b2cbbbf7..a9b309c3 100644 --- a/pypots/data/dataset/base.py +++ b/pypots/data/dataset/base.py @@ -112,9 +112,7 @@ def __init__( # open the file handle self.file_handle = self._open_file_handle() # check if X exists in the file - assert ( - "X" in self.file_handle.keys() - ), "The given dataset file doesn't contains X. Please double check." + assert "X" in self.file_handle.keys(), "The given dataset file doesn't contains X. Please double check." # check whether X_ori, X_pred, and y exist in the file if they are required if self.return_X_ori: assert ( @@ -125,18 +123,14 @@ def __init__( "X_pred" in self.file_handle.keys() ), "The given dataset file doesn't contains X_pred. Please double check." if self.return_y: - assert ( - "y" in self.file_handle.keys() - ), "The given dataset file doesn't contains y. Please double check." + assert "y" in self.file_handle.keys(), "The given dataset file doesn't contains y. Please double check." # set up the function fetch_data() to fetch data from file self.fetch_data = self._fetch_data_from_file else: # data from array # check if X exists in the dictionary - assert ( - "X" in self.data.keys() - ), "The given dataset dictionary doesn't contains X. Please double check." + assert "X" in self.data.keys(), "The given dataset dictionary doesn't contains X. Please double check." # check whether X_ori, X_pred, and y exist in the file if they are required if self.return_X_ori: assert ( @@ -147,17 +141,13 @@ def __init__( "X_pred" in self.data.keys() ), "The given dataset dictionary doesn't contains X_pred. Please double check." if self.return_y: - assert ( - "y" in self.data.keys() - ), "The given dataset dictionary doesn't contains y. Please double check." + assert "y" in self.data.keys(), "The given dataset dictionary doesn't contains y. Please double check." X = data["X"] X_ori = None if "X_ori" not in data.keys() else data["X_ori"] X_pred = None if "X_pred" not in data.keys() else data["X_pred"] y = None if "y" not in data.keys() else data["y"] - self.X, self.X_ori, self.X_pred, self.y = self._check_array_input( - X, X_ori, X_pred, y, "tensor" - ) + self.X, self.X_ori, self.X_pred, self.y = self._check_array_input(X, X_ori, X_pred, y, "tensor") if self.return_X_ori: # Only when X_ori is given and fixed, we fill the missing values in X here in advance. @@ -169,9 +159,7 @@ def __init__( self.indicating_mask = indicating_mask.to(torch.float32) if self.return_X_pred: - self.X_pred, self.X_pred_missing_mask = fill_and_get_mask_torch( - self.X_pred - ) + self.X_pred, self.X_pred_missing_mask = fill_and_get_mask_torch(self.X_pred) # set up the function fetch_data() to fetch data from array self.fetch_data = self._fetch_data_from_array @@ -295,8 +283,7 @@ def _check_array_input( # check the shape of X here X_shape = X.shape assert len(X_shape) == 3, ( - f"input should have 3 dimensions [n_samples, seq_len, n_features]," - f"but got X: {X_shape}" + f"input should have 3 dimensions [n_samples, seq_len, n_features]," f"but got X: {X_shape}" ) if X_ori is not None: X_ori = turn_data_into_specified_dtype(X_ori, out_dtype) @@ -313,9 +300,7 @@ def _check_array_input( ), f"X and X_pred must have the same number of samples, but got X: f{X.shape} and X_pred: {X_pred.shape}" if y is not None: - assert len(X) == len(y), ( - f"lengths of X and y must match, " f"but got f{len(X)} and {len(y)}" - ) + assert len(X) == len(y), f"lengths of X and y must match, " f"but got f{len(X)} and {len(y)}" y = turn_data_into_specified_dtype(y, out_dtype) y = y.to(torch.long) if out_dtype == "tensor" else y @@ -383,9 +368,7 @@ def _open_file_handle(self) -> h5py.File: "r", ) # set swmr=True if the h5 file need to be written into new content during reading except ImportError: - raise ImportError( - "h5py is missing and cannot be imported. Please install it first." - ) + raise ImportError("h5py is missing and cannot be imported. Please install it first.") except FileNotFoundError as e: raise FileNotFoundError(f"{e}") except OSError as e: diff --git a/pypots/data/generating.py b/pypots/data/generating.py index f50b5276..b9ab93f5 100644 --- a/pypots/data/generating.py +++ b/pypots/data/generating.py @@ -115,9 +115,7 @@ def gene_complete_random_walk_for_classification( std = 1 for c_ in range(n_classes): - ts_samples = gene_complete_random_walk( - n_samples_each_class, n_steps, n_features, mu, std, random_state - ) + ts_samples = gene_complete_random_walk(n_samples_each_class, n_steps, n_features, mu, std, random_state) label_samples = np.asarray([1 for _ in range(n_samples_each_class)]) * c_ ts_collector.extend(ts_samples) label_collector.extend(label_samples) @@ -186,12 +184,8 @@ def gene_complete_random_walk_for_anomaly_detection( y : array, shape of [n_samples] Labels indicating if time-series samples are anomalies. """ - assert ( - 0 < anomaly_proportion < 1 - ), f"anomaly_proportion should be >0 and <1, but got {anomaly_proportion}" - assert ( - 0 < anomaly_fraction < 1 - ), f"anomaly_fraction should be >0 and <1, but got {anomaly_fraction}" + assert 0 < anomaly_proportion < 1, f"anomaly_proportion should be >0 and <1, but got {anomaly_proportion}" + assert 0 < anomaly_fraction < 1, f"anomaly_fraction should be >0 and <1, but got {anomaly_fraction}" seed = check_random_state(random_state) X = seed.randn(n_samples, n_steps, n_features) * std + mu n_anomaly = math.floor(n_samples * anomaly_proportion) @@ -204,9 +198,7 @@ def gene_complete_random_walk_for_anomaly_detection( max_difference = min_val - max_val n_points = n_steps * n_features n_anomaly_points = int(n_points * anomaly_fraction) - point_indices = np.random.choice( - a=n_points, size=n_anomaly_points, replace=False - ) + point_indices = np.random.choice(a=n_points, size=n_anomaly_points, replace=False) for p_i in point_indices: anomaly_sample[p_i] = mu + np.random.uniform( low=min_val - anomaly_scale_factor * max_difference, @@ -304,9 +296,7 @@ def gene_random_walk( if missing_rate > 0: # mask values in the test set as ground truth - train_X_ori = scaler.transform(train_X_ori.reshape(-1, n_features)).reshape( - -1, n_steps, n_features - ) + train_X_ori = scaler.transform(train_X_ori.reshape(-1, n_features)).reshape(-1, n_steps, n_features) data["train_X_ori"] = train_X_ori val_X_ori = val_X @@ -324,9 +314,7 @@ def gene_random_walk( def gene_physionet2012(artificially_missing_rate: float = 0.1): - dataset_from_benchpots = preprocess_physionet2012( - subset="all", rate=artificially_missing_rate - ) + dataset_from_benchpots = preprocess_physionet2012(subset="all", rate=artificially_missing_rate) logger.warning( "🚨 Due to the full release of BenchPOTS package, " "gene_physionet2012() has been deprecated and will be removed in pypots v0.8" diff --git a/pypots/data/load_specific_datasets.py b/pypots/data/load_specific_datasets.py index 50c6c297..6a6a246a 100644 --- a/pypots/data/load_specific_datasets.py +++ b/pypots/data/load_specific_datasets.py @@ -55,9 +55,7 @@ def load_specific_dataset(dataset_name: str, use_cache: bool = True) -> dict: e.g. standardizing and splitting. """ - logger.info( - f"Loading the dataset {dataset_name} with TSDB (https://github.com/WenjieDu/Time_Series_Data_Beans)..." - ) + logger.info(f"Loading the dataset {dataset_name} with TSDB (https://github.com/WenjieDu/Time_Series_Data_Beans)...") assert dataset_name in SUPPORTED_DATASETS, ( f"Dataset {dataset_name} is not supported. " f"If you believe this dataset is valuable to be supported by PyPOTS," diff --git a/pypots/data/saving/h5.py b/pypots/data/saving/h5.py index 820a6c93..57717ac1 100644 --- a/pypots/data/saving/h5.py +++ b/pypots/data/saving/h5.py @@ -48,12 +48,8 @@ def save_set(handle, name, data): handle.create_dataset(name, data=data) # check typing - assert isinstance( - data_dict, dict - ), f"`data_dict` should be a Python dictionary, but got {type(data_dict)}" - assert isinstance( - saving_path, str - ), f"`saving_path` should be a string, but got {type(saving_path)}" + assert isinstance(data_dict, dict), f"`data_dict` should be a Python dictionary, but got {type(data_dict)}" + assert isinstance(saving_path, str), f"`saving_path` should be a string, but got {type(saving_path)}" if file_name is None: # if file_name is not given # check suffix @@ -64,9 +60,7 @@ def save_set(handle, name, data): ) else: # if file_name is given # check typing - assert isinstance( - file_name, str - ), f"`file_name` should be a string, but got {type(file_name)}." + assert isinstance(file_name, str), f"`file_name` should be a string, but got {type(file_name)}." # check suffix if not file_name.endswith(".h5") or file_name.endswith(".hdf5"): logger.warning( @@ -107,9 +101,7 @@ def load_dict_from_h5( The data loaded from the given h5 file. """ - assert isinstance( - file_path, str - ), f"`file_path` should be a string, but got {type(file_path)}." + assert isinstance(file_path, str), f"`file_path` should be a string, but got {type(file_path)}." assert os.path.exists(file_path), f"file_path {file_path} does not exist." def load_set(handle, datadict): diff --git a/pypots/data/saving/pickle.py b/pypots/data/saving/pickle.py index c8ef9129..8cac631a 100644 --- a/pypots/data/saving/pickle.py +++ b/pypots/data/saving/pickle.py @@ -34,9 +34,7 @@ def pickle_dump(data: object, path: str) -> None: pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) logger.info(f"Successfully saved to {path}") except Exception as e: - logger.error( - f"❌ Pickling failed. No cache data saved. Investigate the error below:\n{e}" - ) + logger.error(f"❌ Pickling failed. No cache data saved. Investigate the error below:\n{e}") return None @@ -59,9 +57,7 @@ def pickle_load(path: str) -> object: with open(path, "rb") as f: data = pickle.load(f) except Exception as e: - logger.error( - f"❌ Loading data failed. Operation aborted. Investigate the error below:\n{e}" - ) + logger.error(f"❌ Loading data failed. Operation aborted. Investigate the error below:\n{e}") return None return data diff --git a/pypots/data/utils.py b/pypots/data/utils.py index 7762ff7f..40919fa8 100644 --- a/pypots/data/utils.py +++ b/pypots/data/utils.py @@ -25,9 +25,7 @@ def turn_data_into_specified_dtype( elif isinstance(data, np.ndarray): data = torch.from_numpy(data) if dtype == "tensor" else data else: - raise TypeError( - f"data should be an instance of list/np.ndarray/torch.Tensor, but got {type(data)}" - ) + raise TypeError(f"data should be an instance of list/np.ndarray/torch.Tensor, but got {type(data)}") return data @@ -61,9 +59,7 @@ def cal_delta_for_single_sample(mask: torch.Tensor) -> torch.Tensor: d = [torch.zeros(1, n_features, device=device)] for step in range(1, n_steps): - d.append( - torch.ones(1, n_features, device=device) + (1 - mask[step - 1]) * d[-1] - ) + d.append(torch.ones(1, n_features, device=device) + (1 - mask[step - 1]) * d[-1]) d = torch.concat(d, dim=0) return d @@ -129,9 +125,7 @@ def cal_delta_for_single_sample(mask: np.ndarray) -> np.ndarray: return delta -def parse_delta( - missing_mask: Union[np.ndarray, torch.Tensor] -) -> Union[np.ndarray, torch.Tensor]: +def parse_delta(missing_mask: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: """Generate the time-gap matrix (i.e. the delta metrix) from the missing mask. Please refer to :cite:`che2018GRUD` for its math definition. diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py index 1cf41c1b..5113876d 100644 --- a/pypots/forecasting/base.py +++ b/pypots/forecasting/base.py @@ -330,15 +330,11 @@ def _train_model( ) mean_loss = mean_val_loss else: - logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" - ) + logger.info(f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}") mean_loss = mean_train_loss if np.isnan(mean_loss): - logger.warning( - f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors." - ) + logger.warning(f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors.") if mean_loss < self.best_loss: self.best_epoch = epoch @@ -360,9 +356,7 @@ def _train_model( nni.report_final_result(self.best_loss) if self.patience == 0: - logger.info( - "Exceeded the training patience. Terminating the training procedure..." - ) + logger.info("Exceeded the training patience. Terminating the training procedure...") break except KeyboardInterrupt: # if keyboard interrupt, only warning @@ -383,9 +377,7 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info( - f"Finished training. The best model is from epoch#{self.best_epoch}." - ) + logger.info(f"Finished training. The best model is from epoch#{self.best_epoch}.") @abstractmethod def fit( diff --git a/pypots/forecasting/bttf/core.py b/pypots/forecasting/bttf/core.py index f0fcccc2..0a87e248 100644 --- a/pypots/forecasting/bttf/core.py +++ b/pypots/forecasting/bttf/core.py @@ -65,9 +65,7 @@ def _BTTF( U = sample_factor_u(tau_sparse_tensor, tau_ind, U, V, X) V = sample_factor_v(tau_sparse_tensor, tau_ind, U, V, X) A, Sigma = sample_var_coefficient(X, time_lags) - X = sample_factor_x( - tau_sparse_tensor, tau_ind, time_lags, U, V, X, A, inv(Sigma) - ) + X = sample_factor_x(tau_sparse_tensor, tau_ind, time_lags, U, V, X, A, inv(Sigma)) tensor_hat = np.einsum("is, js, ts -> ijt", U, V, X) tau = np.random.gamma( 1e-6 + 0.5 * np.sum(ind), @@ -99,9 +97,7 @@ def _BTTF( return tensor_hat, U_plus, V_plus, X_plus, A_plus, Sigma_plus, tau_plus -def sample_factor_x_partial( - tau_sparse_tensor, tau_ind, time_lags, U, V, X, A, Lambda_x, back_step -): +def sample_factor_x_partial(tau_sparse_tensor, tau_ind, time_lags, U, V, X, A, Lambda_x, back_step): """Sampling T-by-R factor matrix X.""" dim3, rank = X.shape @@ -117,9 +113,7 @@ def sample_factor_x_partial( var1 = kr_prod(V, U).T var2 = kr_prod(var1, var1) - var3 = (var2 @ ten2mat(tau_ind[:, :, -back_step:], 2).T).reshape( - [rank, rank, back_step] - ) + Lambda_x[:, :, None] + var3 = (var2 @ ten2mat(tau_ind[:, :, -back_step:], 2).T).reshape([rank, rank, back_step]) + Lambda_x[:, :, None] var4 = var1 @ ten2mat(tau_sparse_tensor[:, :, -back_step:], 2).T for t in range(dim3 - back_step, dim3): Mt = np.zeros((rank, rank)) @@ -135,9 +129,7 @@ def sample_factor_x_partial( for k in index: temp[:, n] = X[t + time_lags[k] - time_lags, :].reshape(rank * d) n += 1 - temp0 = X[t + time_lags[index], :].T - np.einsum( - "ijk, ik -> jk", A0[:, :, index], temp - ) + temp0 = X[t + time_lags[index], :].T - np.einsum("ijk, ik -> jk", A0[:, :, index], temp) Nt = np.einsum("kij, jk -> i", mat1[index, :, :], temp0) var3[:, :, t + back_step - dim3] = var3[:, :, t + back_step - dim3] + Mt X[t, :] = mvnrnd_pre( @@ -150,9 +142,7 @@ def sample_factor_x_partial( return X -def _BTTF_partial( - sparse_tensor, init, rank, time_lags, gibbs_iter, multi_step=1, gamma=10 -): +def _BTTF_partial(sparse_tensor, init, rank, time_lags, gibbs_iter, multi_step=1, gamma=10): """Bayesian Temporal Tensor Factorization, BTTF.""" dim1, dim2, dim3 = sparse_tensor.shape @@ -186,9 +176,7 @@ def _BTTF_partial( ) X0 = ar4cast(A_plus[:, :, it], X, Sigma_plus[:, :, it], time_lags, multi_step) X_new_plus[:, :, it] = X0 - tensor_new_plus += np.einsum( - "is, js, ts -> ijt", U_plus[:, :, it], V_plus[:, :, it], X0[-multi_step:, :] - ) + tensor_new_plus += np.einsum("is, js, ts -> ijt", U_plus[:, :, it], V_plus[:, :, it], X0[-multi_step:, :]) tensor_hat = tensor_new_plus / gibbs_iter tensor_hat[tensor_hat < 0] = 0 @@ -252,7 +240,5 @@ def BTTF_forecast( multi_step, gamma, ) - tensor_hat[:, :, t * multi_step : (t + 1) * multi_step] = tensor[ - :, :, -multi_step: - ] + tensor_hat[:, :, t * multi_step : (t + 1) * multi_step] = tensor[:, :, -multi_step:] return tensor_hat diff --git a/pypots/forecasting/bttf/submodules.py b/pypots/forecasting/bttf/submodules.py index 3a73408e..e1abb875 100644 --- a/pypots/forecasting/bttf/submodules.py +++ b/pypots/forecasting/bttf/submodules.py @@ -48,21 +48,14 @@ def sample_factor_u(tau_sparse_tensor, tau_ind, U, V, X, beta0=1): U_bar = np.mean(U, axis=0) temp = dim1 / (dim1 + beta0) var_mu_hyper = temp * U_bar - var_U_hyper = inv( - np.eye(rank) + cov_mat(U, U_bar) + temp * beta0 * np.outer(U_bar, U_bar) - ) + var_U_hyper = inv(np.eye(rank) + cov_mat(U, U_bar) + temp * beta0 * np.outer(U_bar, U_bar)) var_Lambda_hyper = wishart.rvs(df=dim1 + rank, scale=var_U_hyper) var_mu_hyper = mvnrnd_pre(var_mu_hyper, (dim1 + beta0) * var_Lambda_hyper) var1 = kr_prod(X, V).T var2 = kr_prod(var1, var1) - var3 = (var2 @ ten2mat(tau_ind, 0).T).reshape( - [rank, rank, dim1] - ) + var_Lambda_hyper[:, :, None] - var4 = ( - var1 @ ten2mat(tau_sparse_tensor, 0).T - + (var_Lambda_hyper @ var_mu_hyper)[:, None] - ) + var3 = (var2 @ ten2mat(tau_ind, 0).T).reshape([rank, rank, dim1]) + var_Lambda_hyper[:, :, None] + var4 = var1 @ ten2mat(tau_sparse_tensor, 0).T + (var_Lambda_hyper @ var_mu_hyper)[:, None] for i in range(dim1): U[i, :] = mvnrnd_pre(solve(var3[:, :, i], var4[:, i]), var3[:, :, i]) @@ -76,21 +69,14 @@ def sample_factor_v(tau_sparse_tensor, tau_ind, U, V, X, beta0=1): V_bar = np.mean(V, axis=0) temp = dim2 / (dim2 + beta0) var_mu_hyper = temp * V_bar - var_V_hyper = inv( - np.eye(rank) + cov_mat(V, V_bar) + temp * beta0 * np.outer(V_bar, V_bar) - ) + var_V_hyper = inv(np.eye(rank) + cov_mat(V, V_bar) + temp * beta0 * np.outer(V_bar, V_bar)) var_Lambda_hyper = wishart.rvs(df=dim2 + rank, scale=var_V_hyper) var_mu_hyper = mvnrnd_pre(var_mu_hyper, (dim2 + beta0) * var_Lambda_hyper) var1 = kr_prod(X, U).T var2 = kr_prod(var1, var1) - var3 = (var2 @ ten2mat(tau_ind, 1).T).reshape( - [rank, rank, dim2] - ) + var_Lambda_hyper[:, :, None] - var4 = ( - var1 @ ten2mat(tau_sparse_tensor, 1).T - + (var_Lambda_hyper @ var_mu_hyper)[:, None] - ) + var3 = (var2 @ ten2mat(tau_ind, 1).T).reshape([rank, rank, dim2]) + var_Lambda_hyper[:, :, None] + var4 = var1 @ ten2mat(tau_sparse_tensor, 1).T + (var_Lambda_hyper @ var_mu_hyper)[:, None] for j in range(dim2): V[j, :] = mvnrnd_pre(solve(var3[:, :, j], var4[:, j]), var3[:, :, j]) @@ -118,9 +104,7 @@ def sample_var_coefficient(X, time_lags): Z_mat = X[tmax:dim, :] Q_mat = np.zeros((dim - tmax, rank * d)) for k in range(d): - Q_mat[:, k * rank : (k + 1) * rank] = X[ - tmax - time_lags[k] : dim - time_lags[k], : - ] + Q_mat[:, k * rank : (k + 1) * rank] = X[tmax - time_lags[k] : dim - time_lags[k], :] var_Psi0 = np.eye(rank * d) + Q_mat.T @ Q_mat var_Psi = inv(var_Psi0) var_M = var_Psi @ Q_mat.T @ Z_mat @@ -146,9 +130,7 @@ def sample_factor_x(tau_sparse_tensor, tau_ind, time_lags, U, V, X, A, Lambda_x) var1 = kr_prod(V, U).T var2 = kr_prod(var1, var1) - var3 = (var2 @ ten2mat(tau_ind, 2).T).reshape([rank, rank, dim3]) + Lambda_x[ - :, :, None - ] + var3 = (var2 @ ten2mat(tau_ind, 2).T).reshape([rank, rank, dim3]) + Lambda_x[:, :, None] var4 = var1 @ ten2mat(tau_sparse_tensor, 2).T for t in range(dim3): Mt = np.zeros((rank, rank)) @@ -167,9 +149,7 @@ def sample_factor_x(tau_sparse_tensor, tau_ind, time_lags, U, V, X, A, Lambda_x) for k in index: temp[:, n] = X[t + time_lags[k] - time_lags, :].reshape(rank * d) n += 1 - temp0 = X[t + time_lags[index], :].T - np.einsum( - "ijk, ik -> jk", A0[:, :, index], temp - ) + temp0 = X[t + time_lags[index], :].T - np.einsum("ijk, ik -> jk", A0[:, :, index], temp) Nt = np.einsum("kij, jk -> i", mat1[index, :, :], temp0) var3[:, :, t] = var3[:, :, t] + Mt diff --git a/pypots/forecasting/csdi/core.py b/pypots/forecasting/csdi/core.py index e488cb20..4b497b89 100644 --- a/pypots/forecasting/csdi/core.py +++ b/pypots/forecasting/csdi/core.py @@ -59,9 +59,7 @@ def __init__( def time_embedding(pos, d_model=128): pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(pos.device) position = pos.unsqueeze(2) - div_term = 1 / torch.pow( - 10000.0, torch.arange(0, d_model, 2, device=pos.device) / d_model - ) + div_term = 1 / torch.pow(10000.0, torch.arange(0, d_model, 2, device=pos.device) / d_model) pe[:, :, 0::2] = torch.sin(position * div_term) pe[:, :, 1::2] = torch.cos(position * div_term) return pe @@ -69,25 +67,17 @@ def time_embedding(pos, d_model=128): def get_side_info(self, observed_tp, cond_mask, feature_id): B, K, L = cond_mask.shape device = observed_tp.device - time_embed = self.time_embedding( - observed_tp, self.d_time_embedding - ) # (B,L,emb) + time_embed = self.time_embedding(observed_tp, self.d_time_embedding) # (B,L,emb) time_embed = time_embed.to(device) time_embed = time_embed.unsqueeze(2).expand(-1, -1, self.n_pred_features, -1) if self.n_pred_features == self.n_features: - feature_embed = self.embed_layer( - torch.arange(self.n_pred_features).to(device) - ) # (K,emb) + feature_embed = self.embed_layer(torch.arange(self.n_pred_features).to(device)) # (K,emb) feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1) else: - feature_embed = ( - self.embed_layer(feature_id).unsqueeze(1).expand(-1, L, -1, -1) - ) + feature_embed = self.embed_layer(feature_id).unsqueeze(1).expand(-1, L, -1, -1) - side_info = torch.cat( - [time_embed, feature_embed], dim=-1 - ) # (B,L,K,emb+d_feature_embedding) + side_info = torch.cat([time_embed, feature_embed], dim=-1) # (B,L,K,emb+d_feature_embedding) side_info = side_info.permute(0, 3, 2, 1) # (B,*,K,L) if not self.is_unconditional: @@ -107,9 +97,7 @@ def forward(self, inputs, training=True, n_sampling_times=1): inputs["feature_id"], ) side_info = self.get_side_info(observed_tp, cond_mask, feature_id) - training_loss = self.backbone.calc_loss( - observed_data, cond_mask, indicating_mask, side_info, training - ) + training_loss = self.backbone.calc_loss(observed_data, cond_mask, indicating_mask, side_info, training) results["loss"] = training_loss elif not training and n_sampling_times == 0: # for validating (observed_data, indicating_mask, cond_mask, observed_tp, feature_id) = ( diff --git a/pypots/forecasting/csdi/data.py b/pypots/forecasting/csdi/data.py index 5f91f842..d10eb87b 100644 --- a/pypots/forecasting/csdi/data.py +++ b/pypots/forecasting/csdi/data.py @@ -77,9 +77,7 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: # apply specifically given mask or the hist masking strategy, rather than the random masking strategy if "for_pattern_mask" in self.data.keys(): - for_pattern_mask = torch.from_numpy(self.data["for_pattern_mask"][idx]).to( - torch.float32 - ) + for_pattern_mask = torch.from_numpy(self.data["for_pattern_mask"][idx]).to(torch.float32) else: previous_sample = self.X[idx - 1] for_pattern_mask = (~torch.isnan(previous_sample)).to(torch.float32) @@ -93,9 +91,7 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: observed_mask, feature_id, cond_mask, - ) = self.sample_features( - observed_data, observed_mask, feature_id, cond_mask - ) + ) = self.sample_features(observed_data, observed_mask, feature_id, cond_mask) X_pred = self.X_pred[idx] X_pred_missing_mask = self.X_pred_missing_mask[idx] @@ -103,9 +99,7 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: observed_data = torch.concat([observed_data, X_pred], dim=0) indicating_mask = torch.concat([indicating_mask, X_pred_missing_mask], dim=0) cond_mask = torch.concat([cond_mask, torch.zeros(X_pred.shape)], dim=0) - observed_tp = torch.arange( - 0, self.n_steps + self.n_pred_steps, dtype=torch.float32 - ) + observed_tp = torch.arange(0, self.n_steps + self.n_pred_steps, dtype=torch.float32) sample = [ torch.tensor(idx), @@ -161,13 +155,9 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: # apply specifically given mask or the hist masking strategy, rather than the random masking strategy if "for_pattern_mask" in self.file_handle.keys(): - for_pattern_mask = torch.from_numpy( - self.file_handle["for_pattern_mask"][idx] - ).to(torch.float32) + for_pattern_mask = torch.from_numpy(self.file_handle["for_pattern_mask"][idx]).to(torch.float32) else: - previous_sample = torch.from_numpy(self.file_handle["X"][idx - 1]).to( - torch.float32 - ) + previous_sample = torch.from_numpy(self.file_handle["X"][idx - 1]).to(torch.float32) for_pattern_mask = (~torch.isnan(previous_sample)).to(torch.float32) cond_mask = observed_mask * for_pattern_mask @@ -179,9 +169,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: observed_mask, feature_id, cond_mask, - ) = self.sample_features( - observed_data, observed_mask, feature_id, cond_mask - ) + ) = self.sample_features(observed_data, observed_mask, feature_id, cond_mask) X_pred = torch.from_numpy(self.file_handle["X_pred"][idx]).to(torch.float32) X_pred, X_pred_missing_mask = fill_and_get_mask_torch(X_pred) @@ -189,9 +177,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: observed_data = torch.concat([observed_data, X_pred], dim=0) indicating_mask = torch.concat([indicating_mask, X_pred_missing_mask], dim=0) cond_mask = torch.concat([cond_mask, torch.zeros(X_pred.shape)], dim=0) - observed_tp = torch.arange( - 0, self.n_steps + self.n_pred_steps, dtype=torch.float32 - ) + observed_tp = torch.arange(0, self.n_steps + self.n_pred_steps, dtype=torch.float32) sample = [ torch.tensor(idx), @@ -262,21 +248,15 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: observed_mask, feature_id, cond_mask, - ) = self.sample_features( - observed_data, observed_mask, feature_id, cond_mask - ) + ) = self.sample_features(observed_data, observed_mask, feature_id, cond_mask) observed_data = torch.concat( [observed_data, torch.zeros([self.n_pred_steps, self.n_pred_features])], dim=0, ) - cond_mask = torch.concat( - [cond_mask, torch.zeros([self.n_pred_steps, self.n_pred_features])], dim=0 - ) - observed_tp = torch.arange( - 0, self.n_steps + self.n_pred_steps, dtype=torch.float32 - ) + cond_mask = torch.concat([cond_mask, torch.zeros([self.n_pred_steps, self.n_pred_features])], dim=0) + observed_tp = torch.arange(0, self.n_steps + self.n_pred_steps, dtype=torch.float32) sample = [ torch.tensor(idx), @@ -333,21 +313,15 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: observed_mask, feature_id, cond_mask, - ) = self.sample_features( - observed_data, observed_mask, feature_id, cond_mask - ) + ) = self.sample_features(observed_data, observed_mask, feature_id, cond_mask) observed_data = torch.concat( [observed_data, torch.zeros([self.n_pred_steps, self.n_pred_features])], dim=0, ) - cond_mask = torch.concat( - [cond_mask, torch.zeros([self.n_pred_steps, self.n_pred_features])], dim=0 - ) - observed_tp = torch.arange( - 0, self.n_steps + self.n_pred_steps, dtype=torch.float32 - ) + cond_mask = torch.concat([cond_mask, torch.zeros([self.n_pred_steps, self.n_pred_features])], dim=0) + observed_tp = torch.arange(0, self.n_steps + self.n_pred_steps, dtype=torch.float32) feature_id = torch.arange(self.n_pred_features) diff --git a/pypots/forecasting/csdi/model.py b/pypots/forecasting/csdi/model.py index 734d3870..8492f87b 100644 --- a/pypots/forecasting/csdi/model.py +++ b/pypots/forecasting/csdi/model.py @@ -272,9 +272,7 @@ def _train_model( with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) - results = self.model.forward( - inputs, training=False, n_sampling_times=0 - ) + results = self.model.forward(inputs, training=False, n_sampling_times=0) val_loss_collector.append(results["loss"].sum().item()) mean_val_loss = np.asarray(val_loss_collector).mean() @@ -293,15 +291,11 @@ def _train_model( ) mean_loss = mean_val_loss else: - logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" - ) + logger.info(f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}") mean_loss = mean_train_loss if np.isnan(mean_loss): - logger.warning( - f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors." - ) + logger.warning(f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors.") if mean_loss < self.best_loss: self.best_epoch = epoch @@ -323,9 +317,7 @@ def _train_model( nni.report_final_result(self.best_loss) if self.patience == 0: - logger.info( - "Exceeded the training patience. Terminating the training procedure..." - ) + logger.info("Exceeded the training patience. Terminating the training procedure...") break except KeyboardInterrupt: # if keyboard interrupt, only warning @@ -346,9 +338,7 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info( - f"Finished training. The best model is from epoch#{self.best_epoch}." - ) + logger.info(f"Finished training. The best model is from epoch#{self.best_epoch}.") def fit( self, @@ -450,9 +440,7 @@ def predict( training=False, n_sampling_times=n_sampling_times, ) - forecasting_data = results["forecasting_data"][ - :, :, -self.n_pred_steps : - ] + forecasting_data = results["forecasting_data"][:, :, -self.n_pred_steps :] forecasting_collector.append(forecasting_data) # Step 3: output collection and return diff --git a/pypots/imputation/autoformer/core.py b/pypots/imputation/autoformer/core.py index fb883c4e..0f3bcc37 100644 --- a/pypots/imputation/autoformer/core.py +++ b/pypots/imputation/autoformer/core.py @@ -76,9 +76,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/autoformer/model.py b/pypots/imputation/autoformer/model.py index e102814b..38a044e5 100644 --- a/pypots/imputation/autoformer/model.py +++ b/pypots/imputation/autoformer/model.py @@ -209,9 +209,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForAutoformer( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForAutoformer(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -222,9 +220,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForAutoformer( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForAutoformer(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index 6ca8bcb2..1a20dc72 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -330,15 +330,11 @@ def _train_model( ) mean_loss = mean_val_loss else: - logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" - ) + logger.info(f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}") mean_loss = mean_train_loss if np.isnan(mean_loss): - logger.warning( - f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors." - ) + logger.warning(f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors.") if mean_loss < self.best_loss: self.best_epoch = epoch @@ -360,9 +356,7 @@ def _train_model( nni.report_final_result(self.best_loss) if self.patience == 0: - logger.info( - "Exceeded the training patience. Terminating the training procedure..." - ) + logger.info("Exceeded the training patience. Terminating the training procedure...") break except KeyboardInterrupt: # if keyboard interrupt, only warning @@ -383,9 +377,7 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info( - f"Finished training. The best model is from epoch#{self.best_epoch}." - ) + logger.info(f"Finished training. The best model is from epoch#{self.best_epoch}.") @abstractmethod def fit( diff --git a/pypots/imputation/brits/model.py b/pypots/imputation/brits/model.py index 0391d3e4..06ec6f4e 100644 --- a/pypots/imputation/brits/model.py +++ b/pypots/imputation/brits/model.py @@ -191,9 +191,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForBRITS( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForBRITS(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -204,9 +202,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForBRITS( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForBRITS(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, @@ -228,9 +224,7 @@ def predict( file_type: str = "hdf5", ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForBRITS( - test_set, return_X_ori=False, return_y=False, file_type=file_type - ) + test_set = DatasetForBRITS(test_set, return_X_ori=False, return_y=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, diff --git a/pypots/imputation/crossformer/core.py b/pypots/imputation/crossformer/core.py index e26f27ca..1832c5df 100644 --- a/pypots/imputation/crossformer/core.py +++ b/pypots/imputation/crossformer/core.py @@ -50,9 +50,7 @@ def __init__( pad_in_len - n_steps, 0, ) - self.enc_pos_embedding = nn.Parameter( - torch.randn(1, d_model, in_seg_num, d_model) - ) + self.enc_pos_embedding = nn.Parameter(torch.randn(1, d_model, in_seg_num, d_model)) self.pre_norm = nn.LayerNorm(d_model) # Encoder @@ -94,9 +92,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: input_X = self.saits_embedding(X, missing_mask) x_enc = self.enc_value_embedding(input_X.permute(0, 2, 1)) - x_enc = rearrange( - x_enc, "(b d) seg_num d_model -> b d seg_num d_model", d=self.d_model - ) + x_enc = rearrange(x_enc, "(b d) seg_num d_model -> b d seg_num d_model", d=self.d_model) x_enc += self.enc_pos_embedding # Crossformer processing @@ -115,9 +111,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/crossformer/model.py b/pypots/imputation/crossformer/model.py index 41ecabe0..5e8c3016 100644 --- a/pypots/imputation/crossformer/model.py +++ b/pypots/imputation/crossformer/model.py @@ -215,9 +215,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForCrossformer( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForCrossformer(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -228,9 +226,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForCrossformer( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForCrossformer(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/csdi/core.py b/pypots/imputation/csdi/core.py index a80acce3..639727b4 100644 --- a/pypots/imputation/csdi/core.py +++ b/pypots/imputation/csdi/core.py @@ -57,9 +57,7 @@ def __init__( def time_embedding(pos, d_model=128): pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(pos.device) position = pos.unsqueeze(2) - div_term = 1 / torch.pow( - 10000.0, torch.arange(0, d_model, 2, device=pos.device) / d_model - ) + div_term = 1 / torch.pow(10000.0, torch.arange(0, d_model, 2, device=pos.device) / d_model) pe[:, :, 0::2] = torch.sin(position * div_term) pe[:, :, 1::2] = torch.cos(position * div_term) return pe @@ -67,19 +65,13 @@ def time_embedding(pos, d_model=128): def get_side_info(self, observed_tp, cond_mask): B, K, L = cond_mask.shape device = observed_tp.device - time_embed = self.time_embedding( - observed_tp, self.d_time_embedding - ) # (B,L,emb) + time_embed = self.time_embedding(observed_tp, self.d_time_embedding) # (B,L,emb) time_embed = time_embed.to(device) time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1) - feature_embed = self.embed_layer( - torch.arange(self.n_features).to(device) - ) # (K,emb) + feature_embed = self.embed_layer(torch.arange(self.n_features).to(device)) # (K,emb) feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1) - side_info = torch.cat( - [time_embed, feature_embed], dim=-1 - ) # (B,L,K,emb+d_feature_embedding) + side_info = torch.cat([time_embed, feature_embed], dim=-1) # (B,L,K,emb+d_feature_embedding) side_info = side_info.permute(0, 3, 2, 1) # (B,*,K,L) if not self.is_unconditional: @@ -98,9 +90,7 @@ def forward(self, inputs, training=True, n_sampling_times=1): inputs["observed_tp"], ) side_info = self.get_side_info(observed_tp, cond_mask) - training_loss = self.backbone.calc_loss( - observed_data, cond_mask, indicating_mask, side_info, training - ) + training_loss = self.backbone.calc_loss(observed_data, cond_mask, indicating_mask, side_info, training) results["loss"] = training_loss elif not training and n_sampling_times == 0: # for validating (observed_data, indicating_mask, cond_mask, observed_tp) = ( diff --git a/pypots/imputation/csdi/data.py b/pypots/imputation/csdi/data.py index 491a8d0a..2738a977 100644 --- a/pypots/imputation/csdi/data.py +++ b/pypots/imputation/csdi/data.py @@ -106,16 +106,12 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: cond_mask = self.get_rand_mask(observed_mask) else: if "for_pattern_mask" in self.data.keys(): - for_pattern_mask = torch.from_numpy( - self.data["for_pattern_mask"][idx] - ).to(torch.float32) + for_pattern_mask = torch.from_numpy(self.data["for_pattern_mask"][idx]).to(torch.float32) else: previous_sample = self.X[idx - 1] for_pattern_mask = (~torch.isnan(previous_sample)).to(torch.float32) - cond_mask = self.get_hist_mask( - observed_mask, for_pattern_mask=for_pattern_mask - ) + cond_mask = self.get_hist_mask(observed_mask, for_pattern_mask=for_pattern_mask) indicating_mask = observed_mask - cond_mask observed_tp = ( @@ -172,42 +168,30 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: self.file_handle = self._open_file_handle() if self.return_X_ori: - observed_data = torch.from_numpy(self.file_handle["X_ori"][idx]).to( - torch.float32 - ) + observed_data = torch.from_numpy(self.file_handle["X_ori"][idx]).to(torch.float32) observed_data, observed_mask = fill_and_get_mask_torch(observed_data) X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32) _, cond_mask = fill_and_get_mask_torch(X) indicating_mask = observed_mask - cond_mask else: - observed_data = torch.from_numpy(self.file_handle["X"][idx]).to( - torch.float32 - ) + observed_data = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32) observed_data, observed_mask = fill_and_get_mask_torch(observed_data) if self.target_strategy == "random": cond_mask = self.get_rand_mask(observed_mask) else: if "for_pattern_mask" in self.data.keys(): - for_pattern_mask = torch.from_numpy( - self.file_handle["for_pattern_mask"][idx] - ).to(torch.float32) + for_pattern_mask = torch.from_numpy(self.file_handle["for_pattern_mask"][idx]).to(torch.float32) else: - previous_sample = torch.from_numpy( - self.file_handle["X"][idx - 1] - ).to(torch.float32) + previous_sample = torch.from_numpy(self.file_handle["X"][idx - 1]).to(torch.float32) for_pattern_mask = (~torch.isnan(previous_sample)).to(torch.float32) - cond_mask = self.get_hist_mask( - observed_mask, for_pattern_mask=for_pattern_mask - ) + cond_mask = self.get_hist_mask(observed_mask, for_pattern_mask=for_pattern_mask) indicating_mask = observed_mask - cond_mask observed_tp = ( torch.arange(0, self.n_steps, dtype=torch.float32) if "time_points" not in self.file_handle.keys() - else torch.from_numpy(self.file_handle["time_points"][idx]).to( - torch.float32 - ) + else torch.from_numpy(self.file_handle["time_points"][idx]).to(torch.float32) ) sample = [ @@ -321,9 +305,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: observed_tp = ( torch.arange(0, self.n_steps, dtype=torch.float32) if "time_points" not in self.file_handle.keys() - else torch.from_numpy(self.file_handle["time_points"][idx]).to( - torch.float32 - ) + else torch.from_numpy(self.file_handle["time_points"][idx]).to(torch.float32) ) sample = [ diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py index 7d7138e1..19c3ecfd 100644 --- a/pypots/imputation/csdi/model.py +++ b/pypots/imputation/csdi/model.py @@ -252,9 +252,7 @@ def _train_model( with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) - results = self.model.forward( - inputs, training=False, n_sampling_times=0 - ) + results = self.model.forward(inputs, training=False, n_sampling_times=0) val_loss_collector.append(results["loss"].sum().item()) mean_val_loss = np.asarray(val_loss_collector).mean() @@ -273,15 +271,11 @@ def _train_model( ) mean_loss = mean_val_loss else: - logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" - ) + logger.info(f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}") mean_loss = mean_train_loss if np.isnan(mean_loss): - logger.warning( - f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors." - ) + logger.warning(f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors.") if mean_loss < self.best_loss: self.best_epoch = epoch @@ -303,9 +297,7 @@ def _train_model( nni.report_final_result(self.best_loss) if self.patience == 0: - logger.info( - "Exceeded the training patience. Terminating the training procedure..." - ) + logger.info("Exceeded the training patience. Terminating the training procedure...") break except KeyboardInterrupt: # if keyboard interrupt, only warning @@ -326,9 +318,7 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info( - f"Finished training. The best model is from epoch#{self.best_epoch}." - ) + logger.info(f"Finished training. The best model is from epoch#{self.best_epoch}.") def fit( self, diff --git a/pypots/imputation/dlinear/core.py b/pypots/imputation/dlinear/core.py index 78d3bcbd..aa957bb2 100644 --- a/pypots/imputation/dlinear/core.py +++ b/pypots/imputation/dlinear/core.py @@ -36,12 +36,8 @@ def __init__( self.backbone = BackboneDLinear(n_steps, n_features, individual, d_model) if not individual: - self.seasonal_saits_embedding = SaitsEmbedding( - n_features * 2, d_model, with_pos=False - ) - self.trend_saits_embedding = SaitsEmbedding( - n_features * 2, d_model, with_pos=False - ) + self.seasonal_saits_embedding = SaitsEmbedding(n_features * 2, d_model, with_pos=False) + self.trend_saits_embedding = SaitsEmbedding(n_features * 2, d_model, with_pos=False) self.linear_seasonal_output = nn.Linear(d_model, n_features) self.linear_trend_output = nn.Linear(d_model, n_features) @@ -80,9 +76,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/dlinear/model.py b/pypots/imputation/dlinear/model.py index 1ba9fae6..ea65df87 100644 --- a/pypots/imputation/dlinear/model.py +++ b/pypots/imputation/dlinear/model.py @@ -186,9 +186,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForDLinear( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForDLinear(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -199,9 +197,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForDLinear( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForDLinear(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/etsformer/core.py b/pypots/imputation/etsformer/core.py index 92c61f5d..0793044f 100644 --- a/pypots/imputation/etsformer/core.py +++ b/pypots/imputation/etsformer/core.py @@ -100,9 +100,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/etsformer/model.py b/pypots/imputation/etsformer/model.py index dc19ba01..7ecb0c03 100644 --- a/pypots/imputation/etsformer/model.py +++ b/pypots/imputation/etsformer/model.py @@ -209,9 +209,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForETSformer( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForETSformer(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -222,9 +220,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForETSformer( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForETSformer(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/fedformer/core.py b/pypots/imputation/fedformer/core.py index 617a1462..061089be 100644 --- a/pypots/imputation/fedformer/core.py +++ b/pypots/imputation/fedformer/core.py @@ -80,9 +80,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/fedformer/model.py b/pypots/imputation/fedformer/model.py index 5dccaaa6..05d8e7cd 100644 --- a/pypots/imputation/fedformer/model.py +++ b/pypots/imputation/fedformer/model.py @@ -223,9 +223,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForFEDformer( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForFEDformer(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -236,9 +234,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForFEDformer( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForFEDformer(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/film/core.py b/pypots/imputation/film/core.py index 2e48f8c2..1c660f7f 100644 --- a/pypots/imputation/film/core.py +++ b/pypots/imputation/film/core.py @@ -71,9 +71,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/film/model.py b/pypots/imputation/film/model.py index ae2c1513..1f505e64 100644 --- a/pypots/imputation/film/model.py +++ b/pypots/imputation/film/model.py @@ -203,9 +203,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForFiLM( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForFiLM(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -216,9 +214,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForFiLM( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForFiLM(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/frets/core.py b/pypots/imputation/frets/core.py index 488880d9..1f6ac157 100644 --- a/pypots/imputation/frets/core.py +++ b/pypots/imputation/frets/core.py @@ -67,9 +67,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/frets/model.py b/pypots/imputation/frets/model.py index 4a667159..0fc730b7 100644 --- a/pypots/imputation/frets/model.py +++ b/pypots/imputation/frets/model.py @@ -185,9 +185,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForFreTS( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForFreTS(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -198,9 +196,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForFreTS( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForFreTS(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index 85314b28..f8ff2193 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -150,9 +150,7 @@ def __init__( verbose, ) available_kernel_type = ["cauchy", "diffusion", "rbf", "matern"] - assert ( - kernel in available_kernel_type - ), f"kernel should be one of {available_kernel_type}, but got {kernel}" + assert kernel in available_kernel_type, f"kernel should be one of {available_kernel_type}, but got {kernel}" self.n_steps = n_steps self.n_features = n_features @@ -268,9 +266,7 @@ def _train_model( with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) - results = self.model.forward( - inputs, training=False, n_sampling_times=1 - ) + results = self.model.forward(inputs, training=False, n_sampling_times=1) imputed_data = results["imputed_data"].mean(axis=1) imputation_mse = ( calc_mse( @@ -300,15 +296,11 @@ def _train_model( ) mean_loss = mean_val_loss else: - logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" - ) + logger.info(f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}") mean_loss = mean_train_loss if np.isnan(mean_loss): - logger.warning( - f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors." - ) + logger.warning(f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors.") if mean_loss < self.best_loss: self.best_epoch = epoch @@ -330,9 +322,7 @@ def _train_model( nni.report_final_result(self.best_loss) if self.patience == 0: - logger.info( - "Exceeded the training patience. Terminating the training procedure..." - ) + logger.info("Exceeded the training patience. Terminating the training procedure...") break except KeyboardInterrupt: # if keyboard interrupt, only warning @@ -353,9 +343,7 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info( - f"Finished training. The best model is from epoch#{self.best_epoch}." - ) + logger.info(f"Finished training. The best model is from epoch#{self.best_epoch}.") def fit( self, @@ -364,9 +352,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForGPVAE( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForGPVAE(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -377,9 +363,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForGPVAE( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForGPVAE(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, @@ -430,9 +414,7 @@ def predict( assert n_sampling_times > 0, "n_sampling_times should be greater than 0." self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForGPVAE( - test_set, return_X_ori=False, return_y=False, file_type=file_type - ) + test_set = DatasetForGPVAE(test_set, return_X_ori=False, return_y=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -444,9 +426,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward( - inputs, training=False, n_sampling_times=n_sampling_times - ) + results = self.model.forward(inputs, training=False, n_sampling_times=n_sampling_times) imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) diff --git a/pypots/imputation/grud/core.py b/pypots/imputation/grud/core.py index 08681e14..98f368e0 100644 --- a/pypots/imputation/grud/core.py +++ b/pypots/imputation/grud/core.py @@ -55,9 +55,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: empirical_mean = inputs["empirical_mean"] X_filledLOCF = inputs["X_filledLOCF"] - hidden_states, _ = self.backbone( - X, missing_mask, deltas, empirical_mean, X_filledLOCF - ) + hidden_states, _ = self.backbone(X, missing_mask, deltas, empirical_mean, X_filledLOCF) # project back the original data space reconstruction = self.output_projection(hidden_states) diff --git a/pypots/imputation/grud/data.py b/pypots/imputation/grud/data.py index 084ee738..6bfd829f 100644 --- a/pypots/imputation/grud/data.py +++ b/pypots/imputation/grud/data.py @@ -67,9 +67,7 @@ def __init__( self.X_filledLOCF = locf_torch(X) self.deltas = _parse_delta_torch(missing_mask) - self.empirical_mean = torch.sum(missing_mask * X, dim=[0, 1]) / torch.sum( - missing_mask, dim=[0, 1] - ) + self.empirical_mean = torch.sum(missing_mask * X, dim=[0, 1]) / torch.sum(missing_mask, dim=[0, 1]) # fill nan with 0, in case some features have no observations self.empirical_mean = torch.nan_to_num(self.empirical_mean, 0) @@ -144,9 +142,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: X_filledLOCF = locf_torch(X.unsqueeze(dim=0)).squeeze() X = torch.nan_to_num(X) deltas = _parse_delta_torch(missing_mask) - empirical_mean = torch.sum(missing_mask * X, dim=[0]) / torch.sum( - missing_mask, dim=[0] - ) + empirical_mean = torch.sum(missing_mask * X, dim=[0]) / torch.sum(missing_mask, dim=[0]) sample = [ torch.tensor(idx), diff --git a/pypots/imputation/grud/model.py b/pypots/imputation/grud/model.py index cc08cd6e..269888d0 100644 --- a/pypots/imputation/grud/model.py +++ b/pypots/imputation/grud/model.py @@ -178,9 +178,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForGRUD( - train_set, return_X_ori=False, file_type=file_type - ) + training_set = DatasetForGRUD(train_set, return_X_ori=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, diff --git a/pypots/imputation/imputeformer/core.py b/pypots/imputation/imputeformer/core.py index a7ed10e5..ceb81630 100644 --- a/pypots/imputation/imputeformer/core.py +++ b/pypots/imputation/imputeformer/core.py @@ -59,9 +59,7 @@ def __init__( self.d_ffn = d_ffn self.learnable_embedding = nn.init.xavier_uniform_( - nn.Parameter( - torch.empty(self.in_steps, self.n_nodes, self.learnable_embedding_dim) - ) + nn.Parameter(torch.empty(self.in_steps, self.n_nodes, self.learnable_embedding_dim)) ) self.readout = MLP(self.model_dim, self.model_dim, output_dim, n_layers=2) @@ -109,12 +107,8 @@ def forward(self, inputs: dict, training: bool = True) -> dict: x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim) # Learnable node embedding - node_emb = self.learnable_embedding.expand( - batch_size, *self.learnable_embedding.shape - ) - x = torch.cat( - [x, node_emb], dim=-1 - ) # (batch_size, in_steps, num_nodes, model_dim) + node_emb = self.learnable_embedding.expand(batch_size, *self.learnable_embedding.shape) + x = torch.cat([x, node_emb], dim=-1) # (batch_size, in_steps, num_nodes, model_dim) # Spatial and temporal processing with customized attention layers x = x.permute(0, 2, 1, 3) # [b n s c] @@ -140,9 +134,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/imputeformer/model.py b/pypots/imputation/imputeformer/model.py index 04a267bb..92daf873 100644 --- a/pypots/imputation/imputeformer/model.py +++ b/pypots/imputation/imputeformer/model.py @@ -228,9 +228,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForImputeFormer( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForImputeFormer(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -241,9 +239,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForImputeFormer( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForImputeFormer(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, @@ -318,8 +314,6 @@ def impute( array-like, shape [n_samples, sequence length (time steps), n_features], Imputed data. """ - logger.warning( - "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." - ) + logger.warning("🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead.") results_dict = self.predict(X, file_type=file_type) return results_dict["imputation"] diff --git a/pypots/imputation/informer/core.py b/pypots/imputation/informer/core.py index e9199b02..60647e51 100644 --- a/pypots/imputation/informer/core.py +++ b/pypots/imputation/informer/core.py @@ -93,9 +93,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/informer/model.py b/pypots/imputation/informer/model.py index 040e6e68..07788534 100644 --- a/pypots/imputation/informer/model.py +++ b/pypots/imputation/informer/model.py @@ -203,9 +203,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForInformer( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForInformer(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -216,9 +214,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForInformer( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForInformer(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/itransformer/core.py b/pypots/imputation/itransformer/core.py index 5747f12e..87d429d1 100644 --- a/pypots/imputation/itransformer/core.py +++ b/pypots/imputation/itransformer/core.py @@ -35,9 +35,7 @@ def __init__( self.ORT_weight = ORT_weight self.MIT_weight = MIT_weight - self.saits_embedding = SaitsEmbedding( - n_steps, d_model, with_pos=False, dropout=dropout - ) + self.saits_embedding = SaitsEmbedding(n_steps, d_model, with_pos=False, dropout=dropout) self.encoder = TransformerEncoder( n_layers, d_model, @@ -81,9 +79,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/itransformer/model.py b/pypots/imputation/itransformer/model.py index a1022c90..46774670 100644 --- a/pypots/imputation/itransformer/model.py +++ b/pypots/imputation/itransformer/model.py @@ -154,9 +154,7 @@ def __init__( f"and the result should be equal to d_k, but got d_model={d_model}, n_heads={n_heads}, d_k={d_k}" ) d_model = n_heads * d_k - logger.warning( - f"⚠️ d_model is reset to {d_model} = n_heads ({n_heads}) * d_k ({d_k})" - ) + logger.warning(f"⚠️ d_model is reset to {d_model} = n_heads ({n_heads}) * d_k ({d_k})") self.n_steps = n_steps self.n_features = n_features @@ -232,9 +230,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForiTransformer( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForiTransformer(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -245,9 +241,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForiTransformer( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForiTransformer(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, @@ -322,8 +316,6 @@ def impute( array-like, shape [n_samples, sequence length (time steps), n_features], Imputed data. """ - logger.warning( - "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." - ) + logger.warning("🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead.") results_dict = self.predict(X, file_type=file_type) return results_dict["imputation"] diff --git a/pypots/imputation/koopa/core.py b/pypots/imputation/koopa/core.py index 219c5818..39fd036f 100644 --- a/pypots/imputation/koopa/core.py +++ b/pypots/imputation/koopa/core.py @@ -75,9 +75,7 @@ def forward( # if in training mode, return results with losses if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/koopa/model.py b/pypots/imputation/koopa/model.py index cbd29fc3..60cbc482 100644 --- a/pypots/imputation/koopa/model.py +++ b/pypots/imputation/koopa/model.py @@ -149,9 +149,7 @@ def __init__( self.multistep = multistep self.alpha = alpha - assert ( - math.ceil(n_steps / n_seg_steps) > 1 - ), "n_seg_steps should be smaller than n_steps." + assert math.ceil(n_steps / n_seg_steps) > 1, "n_seg_steps should be smaller than n_steps." self.ORT_weight = ORT_weight self.MIT_weight = MIT_weight @@ -215,9 +213,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForKoopa( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForKoopa(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -228,9 +224,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForKoopa( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForKoopa(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/lerp/model.py b/pypots/imputation/lerp/model.py index ffdd60db..5b0d1a5b 100644 --- a/pypots/imputation/lerp/model.py +++ b/pypots/imputation/lerp/model.py @@ -42,10 +42,7 @@ def fit( Linear interpolation class does not need to run fit(). Please run func ``predict()`` directly. """ - warnings.warn( - "Linear interpolation class has no parameter to train. " - "Please run func `predict()` directly." - ) + warnings.warn("Linear interpolation class has no parameter to train. Please run func `predict()` directly.") def predict( self, diff --git a/pypots/imputation/locf/model.py b/pypots/imputation/locf/model.py index d20ebcfc..f2b9729a 100644 --- a/pypots/imputation/locf/model.py +++ b/pypots/imputation/locf/model.py @@ -115,9 +115,7 @@ def predict( elif isinstance(X, torch.Tensor): imputed_data = locf_torch(X, self.first_step_imputation) else: - raise TypeError( - "X must be type of list/np.ndarray/torch.Tensor, " f"but got {type(X)}" - ) + raise TypeError("X must be type of list/np.ndarray/torch.Tensor, " f"but got {type(X)}") result_dict = { "imputation": imputed_data, diff --git a/pypots/imputation/mean/model.py b/pypots/imputation/mean/model.py index 129f15ec..70731a07 100644 --- a/pypots/imputation/mean/model.py +++ b/pypots/imputation/mean/model.py @@ -38,10 +38,7 @@ def fit( Please run func ``predict()`` directly. """ - warnings.warn( - "Mean imputation class has no parameter to train. " - "Please run func `predict()` directly." - ) + warnings.warn("Mean imputation class has no parameter to train. Please run func `predict()` directly.") def predict( self, @@ -90,17 +87,13 @@ def predict( X_imputed_reshaped = np.copy(X).reshape(-1, n_features) mean_values = np.nanmean(X_imputed_reshaped, axis=0) for i, v in enumerate(mean_values): - X_imputed_reshaped[:, i] = np.nan_to_num( - X_imputed_reshaped[:, i], nan=v - ) + X_imputed_reshaped[:, i] = np.nan_to_num(X_imputed_reshaped[:, i], nan=v) imputed_data = X_imputed_reshaped.reshape(n_samples, n_steps, n_features) elif isinstance(X, torch.Tensor): X_imputed_reshaped = torch.clone(X).reshape(-1, n_features) mean_values = torch.nanmean(X_imputed_reshaped, dim=0).numpy() for i, v in enumerate(mean_values): - X_imputed_reshaped[:, i] = torch.nan_to_num( - X_imputed_reshaped[:, i], nan=v - ) + X_imputed_reshaped[:, i] = torch.nan_to_num(X_imputed_reshaped[:, i], nan=v) imputed_data = X_imputed_reshaped.reshape(n_samples, n_steps, n_features) else: raise ValueError() diff --git a/pypots/imputation/median/model.py b/pypots/imputation/median/model.py index ffa315e4..76c56412 100644 --- a/pypots/imputation/median/model.py +++ b/pypots/imputation/median/model.py @@ -38,10 +38,7 @@ def fit( Please run func ``predict()`` directly. """ - warnings.warn( - "Median imputation class has no parameter to train. " - "Please run func `predict()` directly." - ) + warnings.warn("Median imputation class has no parameter to train. Please run func `predict()` directly.") def predict( self, @@ -90,17 +87,13 @@ def predict( X_imputed_reshaped = np.copy(X).reshape(-1, n_features) median_values = np.nanmedian(X_imputed_reshaped, axis=0) for i, v in enumerate(median_values): - X_imputed_reshaped[:, i] = np.nan_to_num( - X_imputed_reshaped[:, i], nan=v - ) + X_imputed_reshaped[:, i] = np.nan_to_num(X_imputed_reshaped[:, i], nan=v) imputed_data = X_imputed_reshaped.reshape(n_samples, n_steps, n_features) elif isinstance(X, torch.Tensor): X_imputed_reshaped = torch.clone(X).reshape(-1, n_features) median_values = torch.nanmedian(X_imputed_reshaped, dim=0).values.numpy() for i, v in enumerate(median_values): - X_imputed_reshaped[:, i] = torch.nan_to_num( - X_imputed_reshaped[:, i], nan=v - ) + X_imputed_reshaped[:, i] = torch.nan_to_num(X_imputed_reshaped[:, i], nan=v) imputed_data = X_imputed_reshaped.reshape(n_samples, n_steps, n_features) else: diff --git a/pypots/imputation/micn/core.py b/pypots/imputation/micn/core.py index a37cbaf8..11bfa394 100644 --- a/pypots/imputation/micn/core.py +++ b/pypots/imputation/micn/core.py @@ -84,9 +84,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/micn/model.py b/pypots/imputation/micn/model.py index 56069338..edfa8d3d 100644 --- a/pypots/imputation/micn/model.py +++ b/pypots/imputation/micn/model.py @@ -197,9 +197,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForMICN( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForMICN(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -210,9 +208,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForMICN( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForMICN(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/moderntcn/model.py b/pypots/imputation/moderntcn/model.py index 2efb3fed..e408f5eb 100644 --- a/pypots/imputation/moderntcn/model.py +++ b/pypots/imputation/moderntcn/model.py @@ -227,9 +227,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForModernTCN( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForModernTCN(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -240,9 +238,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForModernTCN( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForModernTCN(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/mrnn/model.py b/pypots/imputation/mrnn/model.py index 86bfcc10..40f8dcac 100644 --- a/pypots/imputation/mrnn/model.py +++ b/pypots/imputation/mrnn/model.py @@ -193,9 +193,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForMRNN( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForMRNN(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -206,9 +204,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForMRNN( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForMRNN(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, @@ -230,9 +226,7 @@ def predict( file_type: str = "hdf5", ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForMRNN( - test_set, return_X_ori=False, return_y=False, file_type=file_type - ) + test_set = DatasetForMRNN(test_set, return_X_ori=False, return_y=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, diff --git a/pypots/imputation/nonstationary_transformer/core.py b/pypots/imputation/nonstationary_transformer/core.py index 80a12346..90a50b12 100644 --- a/pypots/imputation/nonstationary_transformer/core.py +++ b/pypots/imputation/nonstationary_transformer/core.py @@ -100,9 +100,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/nonstationary_transformer/model.py b/pypots/imputation/nonstationary_transformer/model.py index 1d662967..814cff3d 100644 --- a/pypots/imputation/nonstationary_transformer/model.py +++ b/pypots/imputation/nonstationary_transformer/model.py @@ -329,9 +329,7 @@ def impute( array-like, shape [n_samples, sequence length (time steps), n_features], Imputed data. """ - logger.warning( - "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." - ) + logger.warning("🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead.") results_dict = self.predict(X, file_type=file_type) return results_dict["imputation"] diff --git a/pypots/imputation/patchtst/core.py b/pypots/imputation/patchtst/core.py index 9a356173..532d43d1 100644 --- a/pypots/imputation/patchtst/core.py +++ b/pypots/imputation/patchtst/core.py @@ -36,9 +36,7 @@ def __init__( padding = stride self.saits_embedding = SaitsEmbedding(n_features * 2, d_model, with_pos=False) - self.patch_embedding = PatchEmbedding( - d_model, patch_len, stride, padding, dropout - ) + self.patch_embedding = PatchEmbedding(d_model, patch_len, stride, padding, dropout) self.encoder = PatchtstEncoder( n_layers, d_model, @@ -64,9 +62,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: input_X = self.saits_embedding(X, missing_mask) # do patch embedding - enc_out = self.patch_embedding( - input_X.permute(0, 2, 1) - ) # [bz * d_model, n_patches, d_model] + enc_out = self.patch_embedding(input_X.permute(0, 2, 1)) # [bz * d_model, n_patches, d_model] # PatchTST encoder processing enc_out, attns = self.encoder(enc_out) @@ -82,9 +78,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/patchtst/model.py b/pypots/imputation/patchtst/model.py index b6aff0ff..81d09fc7 100644 --- a/pypots/imputation/patchtst/model.py +++ b/pypots/imputation/patchtst/model.py @@ -159,9 +159,7 @@ def __init__( f"and the result should be equal to d_k, but got d_model={d_model}, n_heads={n_heads}, d_k={d_k}" ) d_model = n_heads * d_k - logger.warning( - f"⚠️ d_model is reset to {d_model} = n_heads ({n_heads}) * d_k ({d_k})" - ) + logger.warning(f"⚠️ d_model is reset to {d_model} = n_heads ({n_heads}) * d_k ({d_k})") self.n_steps = n_steps self.n_features = n_features @@ -241,9 +239,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForPatchTST( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForPatchTST(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -254,9 +250,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForPatchTST( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForPatchTST(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/pyraformer/core.py b/pypots/imputation/pyraformer/core.py index cc0fdf1c..be65c639 100644 --- a/pypots/imputation/pyraformer/core.py +++ b/pypots/imputation/pyraformer/core.py @@ -75,9 +75,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/pyraformer/model.py b/pypots/imputation/pyraformer/model.py index 5d4e6ac9..576e7c87 100644 --- a/pypots/imputation/pyraformer/model.py +++ b/pypots/imputation/pyraformer/model.py @@ -215,9 +215,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForPyraformer( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForPyraformer(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -228,9 +226,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForPyraformer( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForPyraformer(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/reformer/core.py b/pypots/imputation/reformer/core.py index c1c70fe4..ec55c7ad 100644 --- a/pypots/imputation/reformer/core.py +++ b/pypots/imputation/reformer/core.py @@ -77,9 +77,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/reformer/model.py b/pypots/imputation/reformer/model.py index 47c21664..76b23cb4 100644 --- a/pypots/imputation/reformer/model.py +++ b/pypots/imputation/reformer/model.py @@ -216,9 +216,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForReformer( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForReformer(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -229,9 +227,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForReformer( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForReformer(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/revinscinet/core.py b/pypots/imputation/revinscinet/core.py index 75fe4652..16d199d3 100644 --- a/pypots/imputation/revinscinet/core.py +++ b/pypots/imputation/revinscinet/core.py @@ -82,9 +82,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/revinscinet/model.py b/pypots/imputation/revinscinet/model.py index 65a20a9f..20a78807 100644 --- a/pypots/imputation/revinscinet/model.py +++ b/pypots/imputation/revinscinet/model.py @@ -221,9 +221,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForRevINSCINet( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForRevINSCINet(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -234,9 +232,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForRevINSCINet( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForRevINSCINet(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/saits/core.py b/pypots/imputation/saits/core.py index f5189ab3..a1dec185 100644 --- a/pypots/imputation/saits/core.py +++ b/pypots/imputation/saits/core.py @@ -64,9 +64,7 @@ def forward( X, missing_mask = inputs["X"], inputs["missing_mask"] # determine the attention mask - if (training and self.diagonal_attention_mask) or ( - (not training) and diagonal_attention_mask - ): + if (training and self.diagonal_attention_mask) or ((not training) and diagonal_attention_mask): diagonal_attention_mask = (1 - torch.eye(self.n_steps)).to(X.device) # then broadcast on the batch axis diagonal_attention_mask = diagonal_attention_mask.unsqueeze(0) @@ -109,9 +107,7 @@ def forward( ORT_loss = self.ORT_weight * ORT_loss # calculate loss for the masked imputation task (MIT) - MIT_loss = self.MIT_weight * self.customized_loss_func( - X_tilde_3, X_ori, indicating_mask - ) + MIT_loss = self.MIT_weight * self.customized_loss_func(X_tilde_3, X_ori, indicating_mask) # `loss` is always the item for backward propagating to update the model loss = ORT_loss + MIT_loss diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py index 8c4ce9f4..cecb3cbe 100644 --- a/pypots/imputation/saits/model.py +++ b/pypots/imputation/saits/model.py @@ -163,9 +163,7 @@ def __init__( f"and the result should be equal to d_k, but got d_model={d_model}, n_heads={n_heads}, d_k={d_k}" ) d_model = n_heads * d_k - logger.warning( - f"⚠️ d_model is reset to {d_model} = n_heads ({n_heads}) * d_k ({d_k})" - ) + logger.warning(f"⚠️ d_model is reset to {d_model} = n_heads ({n_heads}) * d_k ({d_k})") self.n_steps = n_steps self.n_features = n_features @@ -245,9 +243,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForSAITS( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForSAITS(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -258,9 +254,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForSAITS( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForSAITS(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, @@ -336,21 +330,13 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward( - inputs, diagonal_attention_mask, training=False - ) + results = self.model.forward(inputs, diagonal_attention_mask, training=False) imputation_collector.append(results["imputed_data"]) if return_latent_vars: - first_DMSA_attn_weights_collector.append( - results["first_DMSA_attn_weights"].cpu().numpy() - ) - second_DMSA_attn_weights_collector.append( - results["second_DMSA_attn_weights"].cpu().numpy() - ) - combining_weights_collector.append( - results["combining_weights"].cpu().numpy() - ) + first_DMSA_attn_weights_collector.append(results["first_DMSA_attn_weights"].cpu().numpy()) + second_DMSA_attn_weights_collector.append(results["second_DMSA_attn_weights"].cpu().numpy()) + combining_weights_collector.append(results["combining_weights"].cpu().numpy()) # Step 3: output collection and return imputation = torch.cat(imputation_collector).cpu().detach().numpy() @@ -360,12 +346,8 @@ def predict( if return_latent_vars: latent_var_collector = { - "first_DMSA_attn_weights": np.concatenate( - first_DMSA_attn_weights_collector - ), - "second_DMSA_attn_weights": np.concatenate( - second_DMSA_attn_weights_collector - ), + "first_DMSA_attn_weights": np.concatenate(first_DMSA_attn_weights_collector), + "second_DMSA_attn_weights": np.concatenate(second_DMSA_attn_weights_collector), "combining_weights": np.concatenate(combining_weights_collector), } result_dict["latent_vars"] = latent_var_collector diff --git a/pypots/imputation/scinet/core.py b/pypots/imputation/scinet/core.py index df857c15..4d2b02a1 100644 --- a/pypots/imputation/scinet/core.py +++ b/pypots/imputation/scinet/core.py @@ -78,9 +78,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/scinet/model.py b/pypots/imputation/scinet/model.py index 525a53c1..86caceb8 100644 --- a/pypots/imputation/scinet/model.py +++ b/pypots/imputation/scinet/model.py @@ -223,9 +223,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForSCINet( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForSCINet(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -236,9 +234,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForSCINet( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForSCINet(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/stemgnn/core.py b/pypots/imputation/stemgnn/core.py index ac730259..d8d51efb 100644 --- a/pypots/imputation/stemgnn/core.py +++ b/pypots/imputation/stemgnn/core.py @@ -71,9 +71,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/stemgnn/model.py b/pypots/imputation/stemgnn/model.py index ea9a109a..743ed3d5 100644 --- a/pypots/imputation/stemgnn/model.py +++ b/pypots/imputation/stemgnn/model.py @@ -197,9 +197,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForStemGNN( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForStemGNN(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -210,9 +208,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForStemGNN( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForStemGNN(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/tcn/core.py b/pypots/imputation/tcn/core.py index 07be25b6..c38390b5 100644 --- a/pypots/imputation/tcn/core.py +++ b/pypots/imputation/tcn/core.py @@ -70,9 +70,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/tcn/model.py b/pypots/imputation/tcn/model.py index b51cceb9..8c01981f 100644 --- a/pypots/imputation/tcn/model.py +++ b/pypots/imputation/tcn/model.py @@ -191,9 +191,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForTCN( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForTCN(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -204,9 +202,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForTCN( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForTCN(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/tefn/model.py b/pypots/imputation/tefn/model.py index 8fdb57bf..ff30eca5 100644 --- a/pypots/imputation/tefn/model.py +++ b/pypots/imputation/tefn/model.py @@ -163,9 +163,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForTEFN( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForTEFN(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -176,9 +174,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForTEFN( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForTEFN(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/tide/core.py b/pypots/imputation/tide/core.py index 4d8f4e47..e826cbeb 100644 --- a/pypots/imputation/tide/core.py +++ b/pypots/imputation/tide/core.py @@ -114,9 +114,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/tide/model.py b/pypots/imputation/tide/model.py index 6e2bb3e1..949b15fe 100644 --- a/pypots/imputation/tide/model.py +++ b/pypots/imputation/tide/model.py @@ -203,9 +203,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForTiDE( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForTiDE(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -216,9 +214,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForTiDE( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForTiDE(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/timemixer/model.py b/pypots/imputation/timemixer/model.py index b79280a0..5e274d7f 100644 --- a/pypots/imputation/timemixer/model.py +++ b/pypots/imputation/timemixer/model.py @@ -228,9 +228,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForTimeMixer( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForTimeMixer(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -241,9 +239,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForTimeMixer( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForTimeMixer(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/timesnet/model.py b/pypots/imputation/timesnet/model.py index 40ca2d87..e3029e93 100644 --- a/pypots/imputation/timesnet/model.py +++ b/pypots/imputation/timesnet/model.py @@ -199,9 +199,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForTimesNet( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForTimesNet(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -212,9 +210,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForTimesNet( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForTimesNet(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/transformer/core.py b/pypots/imputation/transformer/core.py index e769a3aa..3e2a1350 100644 --- a/pypots/imputation/transformer/core.py +++ b/pypots/imputation/transformer/core.py @@ -78,9 +78,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] - loss, ORT_loss, MIT_loss = self.saits_loss_func( - reconstruction, X_ori, missing_mask, indicating_mask - ) + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss results["MIT_loss"] = MIT_loss # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/transformer/model.py b/pypots/imputation/transformer/model.py index f601c33f..33eefee1 100644 --- a/pypots/imputation/transformer/model.py +++ b/pypots/imputation/transformer/model.py @@ -155,9 +155,7 @@ def __init__( f"and the result should be equal to d_k, but got d_model={d_model}, n_heads={n_heads}, d_k={d_k}" ) d_model = n_heads * d_k - logger.warning( - f"⚠️ d_model is reset to {d_model} = n_heads ({n_heads}) * d_k ({d_k})" - ) + logger.warning(f"⚠️ d_model is reset to {d_model} = n_heads ({n_heads}) * d_k ({d_k})") self.n_steps = n_steps self.n_features = n_features @@ -233,9 +231,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForTransformer( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForTransformer(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -246,9 +242,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForTransformer( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForTransformer(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, diff --git a/pypots/imputation/usgan/core.py b/pypots/imputation/usgan/core.py index a0b04ade..8ab19224 100644 --- a/pypots/imputation/usgan/core.py +++ b/pypots/imputation/usgan/core.py @@ -47,9 +47,7 @@ def forward( results = {} if training: if training_object == "discriminator": - imputed_data, discrimination_loss = self.backbone( - inputs, training_object, training - ) + imputed_data, discrimination_loss = self.backbone(inputs, training_object, training) loss = discrimination_loss else: imputed_data, generation_loss = self.backbone( diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index f29eaced..e329fdf0 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -243,21 +243,15 @@ def _train_model( if idx % self.G_steps == 0: self.G_optimizer.zero_grad() - results = self.model.forward( - inputs, training_object="generator" - ) + results = self.model.forward(inputs, training_object="generator") results["loss"].backward() # generation loss self.G_optimizer.step() step_train_loss_G_collector.append(results["loss"].item()) if idx % self.D_steps == 0: self.D_optimizer.zero_grad() - results = self.model.forward( - inputs, training_object="discriminator" - ) - results["loss"].backward( - retain_graph=True - ) # discrimination loss + results = self.model.forward(inputs, training_object="discriminator") + results["loss"].backward(retain_graph=True) # discrimination loss self.D_optimizer.step() step_train_loss_D_collector.append(results["loss"].item()) @@ -272,9 +266,7 @@ def _train_model( "generation_loss": mean_step_train_G_loss, "discrimination_loss": mean_step_train_D_loss, } - self._save_log_into_tb_file( - training_step, "training", loss_results - ) + self._save_log_into_tb_file(training_step, "training", loss_results) mean_epoch_train_D_loss = np.mean(step_train_loss_D_collector) mean_epoch_train_G_loss = np.mean(step_train_loss_G_collector) @@ -320,9 +312,7 @@ def _train_model( mean_loss = mean_epoch_train_G_loss if np.isnan(mean_loss): - logger.warning( - f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors." - ) + logger.warning(f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors.") if mean_loss < self.best_loss: self.best_epoch = epoch @@ -344,9 +334,7 @@ def _train_model( nni.report_final_result(self.best_loss) if self.patience == 0: - logger.info( - "Exceeded the training patience. Terminating the training procedure..." - ) + logger.info("Exceeded the training patience. Terminating the training procedure...") break except KeyboardInterrupt: # if keyboard interrupt, only warning @@ -367,9 +355,7 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info( - f"Finished training. The best model is from epoch#{self.best_epoch}." - ) + logger.info(f"Finished training. The best model is from epoch#{self.best_epoch}.") def fit( self, @@ -378,9 +364,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForUSGAN( - train_set, return_X_ori=False, return_y=False, file_type=file_type - ) + training_set = DatasetForUSGAN(train_set, return_X_ori=False, return_y=False, file_type=file_type) training_loader = DataLoader( training_set, batch_size=self.batch_size, @@ -391,9 +375,7 @@ def fit( if val_set is not None: if not key_in_data_set("X_ori", val_set): raise ValueError("val_set must contain 'X_ori' for model validation.") - val_set = DatasetForUSGAN( - val_set, return_X_ori=True, return_y=False, file_type=file_type - ) + val_set = DatasetForUSGAN(val_set, return_X_ori=True, return_y=False, file_type=file_type) val_loader = DataLoader( val_set, batch_size=self.batch_size, @@ -415,9 +397,7 @@ def predict( file_type: str = "hdf5", ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForUSGAN( - test_set, return_X_ori=False, return_y=False, file_type=file_type - ) + test_set = DatasetForUSGAN(test_set, return_X_ori=False, return_y=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, diff --git a/pypots/nn/modules/autoformer/layers.py b/pypots/nn/modules/autoformer/layers.py index b8daa873..1560d7a1 100644 --- a/pypots/nn/modules/autoformer/layers.py +++ b/pypots/nn/modules/autoformer/layers.py @@ -55,11 +55,7 @@ def time_delay_agg_training(self, values, corr): for i in range(top_k): pattern = torch.roll(tmp_values, -int(index[i]), -1) delays_agg = delays_agg + pattern * ( - tmp_corr[:, i] - .unsqueeze(1) - .unsqueeze(1) - .unsqueeze(1) - .repeat(1, head, channel, length) + tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) ) return delays_agg @@ -91,16 +87,10 @@ def time_delay_agg_inference(self, values, corr): tmp_values = values.repeat(1, 1, 1, 2) delays_agg = torch.zeros_like(values).float() for i in range(top_k): - tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze( - 1 - ).repeat(1, head, channel, length) + tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) delays_agg = delays_agg + pattern * ( - tmp_corr[:, i] - .unsqueeze(1) - .unsqueeze(1) - .unsqueeze(1) - .repeat(1, head, channel, length) + tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) ) return delays_agg @@ -164,13 +154,9 @@ def forward( # time delay agg if self.training: - V = self.time_delay_agg_training( - v.permute(0, 2, 3, 1).contiguous(), corr - ).permute(0, 3, 1, 2) + V = self.time_delay_agg_training(v.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) else: - V = self.time_delay_agg_inference( - v.permute(0, 2, 3, 1).contiguous(), corr - ).permute(0, 3, 1, 2) + V = self.time_delay_agg_inference(v.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) attn = corr.permute(0, 3, 1, 2) output = V.contiguous() @@ -247,12 +233,8 @@ def __init__( d_model // n_heads, d_model // n_heads, ) - self.conv1 = nn.Conv1d( - in_channels=d_model, out_channels=d_ffn, kernel_size=1, bias=False - ) - self.conv2 = nn.Conv1d( - in_channels=d_ffn, out_channels=d_model, kernel_size=1, bias=False - ) + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ffn, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(in_channels=d_ffn, out_channels=d_model, kernel_size=1, bias=False) self.series_decomp1 = SeriesDecompositionBlock(moving_avg) self.series_decomp2 = SeriesDecompositionBlock(moving_avg) self.dropout = nn.Dropout(dropout) @@ -302,12 +284,8 @@ def __init__( d_model // n_heads, d_model // n_heads, ) - self.conv1 = nn.Conv1d( - in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False - ) - self.conv2 = nn.Conv1d( - in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False - ) + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) self.series_decomp1 = SeriesDecompositionBlock(moving_avg) self.series_decomp2 = SeriesDecompositionBlock(moving_avg) self.series_decomp3 = SeriesDecompositionBlock(moving_avg) @@ -326,9 +304,7 @@ def __init__( def forward(self, x, cross, x_mask=None, cross_mask=None): x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0]) x, trend1 = self.series_decomp1(x) - x = x + self.dropout( - self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0] - ) + x = x + self.dropout(self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]) x, trend2 = self.series_decomp2(x) y = x y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) @@ -336,7 +312,5 @@ def forward(self, x, cross, x_mask=None, cross_mask=None): x, trend3 = self.series_decomp3(x + y) residual_trend = trend1 + trend2 + trend3 - residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose( - 1, 2 - ) + residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2) return x, residual_trend diff --git a/pypots/nn/modules/brits/backbone.py b/pypots/nn/modules/brits/backbone.py index 3a8a87d1..eef07cc2 100644 --- a/pypots/nn/modules/brits/backbone.py +++ b/pypots/nn/modules/brits/backbone.py @@ -73,19 +73,13 @@ def __init__( self.rnn_hidden_size = rnn_hidden_size self.rnn_cell = nn.LSTMCell(self.n_features * 2, self.rnn_hidden_size) - self.temp_decay_h = TemporalDecay( - input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False - ) - self.temp_decay_x = TemporalDecay( - input_size=self.n_features, output_size=self.n_features, diag=True - ) + self.temp_decay_h = TemporalDecay(input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False) + self.temp_decay_x = TemporalDecay(input_size=self.n_features, output_size=self.n_features, diag=True) self.hist_reg = nn.Linear(self.rnn_hidden_size, self.n_features) self.feat_reg = FeatureRegression(self.n_features) self.combining_weight = nn.Linear(self.n_features * 2, self.n_features) - def forward( - self, inputs: dict, direction: str - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, inputs: dict, direction: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Parameters ---------- @@ -151,9 +145,7 @@ def forward( estimations.append(c_h.unsqueeze(dim=1)) inputs = torch.cat([c_c, m], dim=1) - hidden_states, cell_states = self.rnn_cell( - inputs, (hidden_states, cell_states) - ) + hidden_states, cell_states = self.rnn_cell(inputs, (hidden_states, cell_states)) # for each iteration, reconstruction_loss increases its value for 3 times reconstruction_loss /= self.n_steps * 3 @@ -204,9 +196,7 @@ def __init__( self.rits_b = BackboneRITS(n_steps, n_features, rnn_hidden_size) @staticmethod - def _get_consistency_loss( - pred_f: torch.Tensor, pred_b: torch.Tensor - ) -> torch.Tensor: + def _get_consistency_loss(pred_f: torch.Tensor, pred_b: torch.Tensor) -> torch.Tensor: """Calculate the consistency loss between the imputation from two RITS models. Parameters @@ -234,9 +224,7 @@ def reverse_tensor(tensor_): if tensor_.dim() <= 1: return tensor_ indices = range(tensor_.size()[1])[::-1] - indices = torch.tensor( - indices, dtype=torch.long, device=tensor_.device, requires_grad=False - ) + indices = torch.tensor(indices, dtype=torch.long, device=tensor_.device, requires_grad=False) return tensor_.index_select(1, indices) collector = [] diff --git a/pypots/nn/modules/crli/backbone.py b/pypots/nn/modules/crli/backbone.py index e463022c..e5c277a2 100644 --- a/pypots/nn/modules/crli/backbone.py +++ b/pypots/nn/modules/crli/backbone.py @@ -24,9 +24,7 @@ def __init__( rnn_cell_type: str = "GRU", ): super().__init__() - self.generator = CrliGenerator( - n_generator_layers, n_features, rnn_hidden_size, rnn_cell_type - ) + self.generator = CrliGenerator(n_generator_layers, n_features, rnn_hidden_size, rnn_cell_type) self.discriminator = CrliDiscriminator(rnn_cell_type, n_features) self.decoder = CrliDecoder( n_steps, rnn_hidden_size * 2, n_features, decoder_fcn_output_dims diff --git a/pypots/nn/modules/crli/layers.py b/pypots/nn/modules/crli/layers.py index d9558c32..b04aa43f 100644 --- a/pypots/nn/modules/crli/layers.py +++ b/pypots/nn/modules/crli/layers.py @@ -20,9 +20,7 @@ def reverse_tensor(tensor_: torch.Tensor) -> torch.Tensor: if tensor_.dim() <= 1: return tensor_ indices = range(tensor_.size()[1])[::-1] - indices = torch.tensor( - indices, dtype=torch.long, device=tensor_.device, requires_grad=False - ) + indices = torch.tensor(indices, dtype=torch.long, device=tensor_.device, requires_grad=False) return tensor_.index_select(1, indices) @@ -50,40 +48,27 @@ def __init__( self.output_layer = nn.Linear(d_hidden, d_input) - def forward( - self, X: torch.Tensor, missing_mask: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, X: torch.Tensor, missing_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: bz, n_steps, _ = X.shape device = X.device hidden_state = torch.zeros((bz, self.d_hidden), device=device) - hidden_state_collector = torch.empty( - (bz, n_steps, self.d_hidden), device=device - ) + hidden_state_collector = torch.empty((bz, n_steps, self.d_hidden), device=device) output_collector = torch.empty((bz, n_steps, self.d_input), device=device) if self.cell_type == "LSTM": - cell_states = [ - torch.zeros((bz, self.d_hidden), device=device) - for _ in range(self.n_layer) - ] + cell_states = [torch.zeros((bz, self.d_hidden), device=device) for _ in range(self.n_layer)] for step in range(n_steps): x = X[:, step, :] estimation = self.output_layer(hidden_state) output_collector[:, step] = estimation - imputed_x = ( - missing_mask[:, step] * x + (1 - missing_mask[:, step]) * estimation - ) + imputed_x = missing_mask[:, step] * x + (1 - missing_mask[:, step]) * estimation for i in range(self.n_layer): if i == 0: - hidden_state, cell_state = self.model[i]( - imputed_x, (hidden_state, cell_states[i]) - ) + hidden_state, cell_state = self.model[i](imputed_x, (hidden_state, cell_states[i])) else: - hidden_state, cell_state = self.model[i]( - hidden_state, (hidden_state, cell_states[i]) - ) + hidden_state, cell_state = self.model[i](hidden_state, (hidden_state, cell_states[i])) hidden_state_collector[:, step, :] = hidden_state @@ -92,9 +77,7 @@ def forward( x = X[:, step, :] estimation = self.output_layer(hidden_state) output_collector[:, step] = estimation - imputed_x = ( - missing_mask[:, step] * x + (1 - missing_mask[:, step]) * estimation - ) + imputed_x = missing_mask[:, step] * x + (1 - missing_mask[:, step]) * estimation for i in range(self.n_layer): if i == 0: hidden_state = self.model[i](imputed_x, hidden_state) @@ -121,16 +104,12 @@ def __init__( self.f_rnn = MultiRNNCell(cell_type, n_layers, n_features, d_hidden) self.b_rnn = MultiRNNCell(cell_type, n_layers, n_features, d_hidden) - def forward( - self, X: torch.Tensor, missing_mask: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, X: torch.Tensor, missing_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: f_outputs, f_final_hidden_state = self.f_rnn(X, missing_mask) b_outputs, b_final_hidden_state = self.b_rnn(X, missing_mask) b_outputs = reverse_tensor(b_outputs) # reverse the output of the backward rnn imputation_latent = (f_outputs + b_outputs) / 2 - fb_final_hidden_states = torch.concat( - [f_final_hidden_state, b_final_hidden_state], dim=-1 - ) + fb_final_hidden_states = torch.concat([f_final_hidden_state, b_final_hidden_state], dim=-1) return imputation_latent, fb_final_hidden_states @@ -184,13 +163,9 @@ def forward( x = imputed_X[:, step, :] for i, rnn_cell in enumerate(self.rnn_cell_module_list): if i == 0: - hidden_state, cell_state = rnn_cell( - x, (hidden_states[i], cell_states[i]) - ) + hidden_state, cell_state = rnn_cell(x, (hidden_states[i], cell_states[i])) else: - hidden_state, cell_state = rnn_cell( - hidden_states[i - 1], (hidden_states[i], cell_states[i]) - ) + hidden_state, cell_state = rnn_cell(hidden_states[i - 1], (hidden_states[i], cell_states[i])) cell_states[i] = cell_state hidden_states[i] = hidden_state @@ -235,9 +210,7 @@ def __init__( self.rnn_cell = nn.GRUCell(fcn_output_dims[-1], fcn_output_dims[-1]) self.output_layer = nn.Linear(fcn_output_dims[-1], d_output) - def forward( - self, generator_fb_hidden_states: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, generator_fb_hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: device = generator_fb_hidden_states.device bz, _ = generator_fb_hidden_states.shape @@ -245,9 +218,7 @@ def forward( for layer in self.fcn: fcn_latent = layer(fcn_latent) hidden_state = fcn_latent - hidden_state_collector = torch.empty( - (bz, self.n_steps, self.fcn_output_dims[-1]), device=device - ) + hidden_state_collector = torch.empty((bz, self.n_steps, self.fcn_output_dims[-1]), device=device) for i in range(self.n_steps): hidden_state = self.rnn_cell(hidden_state, hidden_state) hidden_state_collector[:, i, :] = hidden_state diff --git a/pypots/nn/modules/crossformer/layers.py b/pypots/nn/modules/crossformer/layers.py index 0553a8d7..96320f54 100644 --- a/pypots/nn/modules/crossformer/layers.py +++ b/pypots/nn/modules/crossformer/layers.py @@ -62,12 +62,8 @@ def __init__( self.norm3 = nn.LayerNorm(d_model) self.norm4 = nn.LayerNorm(d_model) - self.MLP1 = nn.Sequential( - nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model) - ) - self.MLP2 = nn.Sequential( - nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model) - ) + self.MLP1 = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)) + self.MLP2 = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)) def forward(self, x): # Cross Time Stage: Directly apply MSA to each dimension @@ -82,29 +78,21 @@ def forward(self, x): # Cross dimension stage: use a small set of learnable vectors to # aggregate and distribute messages to build the D-to-D connection - dim_send = rearrange( - dim_in, "(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model", b=batch - ) + dim_send = rearrange(dim_in, "(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model", b=batch) # dim_send = dim_in.reshape() batch_router = repeat( self.router, "seg_num factor d_model -> (repeat seg_num) factor d_model", repeat=batch, ) - dim_buffer, attn = self.dim_sender( - batch_router, dim_send, dim_send, attn_mask=None - ) - dim_receive, attn = self.dim_receiver( - dim_send, dim_buffer, dim_buffer, attn_mask=None - ) + dim_buffer, attn = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None) + dim_receive, attn = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None) dim_enc = dim_send + self.dropout(dim_receive) dim_enc = self.norm3(dim_enc) dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc)) dim_enc = self.norm4(dim_enc) - final_out = rearrange( - dim_enc, "(b seg_num) ts_d d_model -> b ts_d seg_num d_model", b=batch - ) + final_out = rearrange(dim_enc, "(b seg_num) ts_d d_model -> b ts_d seg_num d_model", b=batch) return final_out @@ -159,9 +147,7 @@ def __init__( for i in range(depth): self.encode_layers.append( - TwoStageAttentionLayer( - seg_num, factor, d_model, n_heads, d_k, d_k, d_ff, dropout - ) + TwoStageAttentionLayer(seg_num, factor, d_model, n_heads, d_k, d_k, d_ff, dropout) ) def forward(self, x, attn_mask=None, tau=None, delta=None): @@ -177,18 +163,14 @@ def forward(self, x, attn_mask=None, tau=None, delta=None): class CrossformerDecoderLayer(nn.Module): - def __init__( - self, self_attention, cross_attention, seg_len, d_model, d_ff=None, dropout=0.1 - ): + def __init__(self, self_attention, cross_attention, seg_len, d_model, d_ff=None, dropout=0.1): super().__init__() self.self_attention = self_attention self.cross_attention = cross_attention self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) - self.MLP1 = nn.Sequential( - nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, d_model) - ) + self.MLP1 = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, d_model)) self.linear_pred = nn.Linear(d_model, seg_len) def forward(self, x, cross): @@ -196,9 +178,7 @@ def forward(self, x, cross): x = self.self_attention(x) x = rearrange(x, "b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model") - cross = rearrange( - cross, "b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model" - ) + cross = rearrange(cross, "b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model") tmp, attn = self.cross_attention( x, cross, @@ -218,8 +198,6 @@ def forward(self, x, cross): b=batch, ) layer_predict = self.linear_pred(dec_output) - layer_predict = rearrange( - layer_predict, "b out_d seg_num seg_len -> b (out_d seg_num) seg_len" - ) + layer_predict = rearrange(layer_predict, "b out_d seg_num seg_len -> b (out_d seg_num) seg_len") return dec_output, layer_predict diff --git a/pypots/nn/modules/csdi/backbone.py b/pypots/nn/modules/csdi/backbone.py index 26051060..3bddc437 100644 --- a/pypots/nn/modules/csdi/backbone.py +++ b/pypots/nn/modules/csdi/backbone.py @@ -56,22 +56,15 @@ def __init__( # parameters for diffusion models if schedule == "quad": - self.beta = ( - np.linspace(beta_start**0.5, beta_end**0.5, self.n_diffusion_steps) - ** 2 - ) + self.beta = np.linspace(beta_start**0.5, beta_end**0.5, self.n_diffusion_steps) ** 2 elif schedule == "linear": self.beta = np.linspace(beta_start, beta_end, self.n_diffusion_steps) else: - raise ValueError( - f"The argument schedule should be 'quad' or 'linear', but got {schedule}" - ) + raise ValueError(f"The argument schedule should be 'quad' or 'linear', but got {schedule}") self.alpha_hat = 1 - self.beta self.alpha = np.cumprod(self.alpha_hat) - self.register_buffer( - "alpha_torch", torch.tensor(self.alpha).float().unsqueeze(1).unsqueeze(1) - ) + self.register_buffer("alpha_torch", torch.tensor(self.alpha).float().unsqueeze(1).unsqueeze(1)) def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask): if self.is_unconditional: @@ -83,20 +76,14 @@ def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask): return total_input - def calc_loss_valid( - self, observed_data, cond_mask, indicating_mask, side_info, is_train - ): + def calc_loss_valid(self, observed_data, cond_mask, indicating_mask, side_info, is_train): loss_sum = 0 for t in range(self.n_diffusion_steps): # calculate loss for all t - loss = self.calc_loss( - observed_data, cond_mask, indicating_mask, side_info, is_train, set_t=t - ) + loss = self.calc_loss(observed_data, cond_mask, indicating_mask, side_info, is_train, set_t=t) loss_sum += loss.detach() return loss_sum / self.n_diffusion_steps - def calc_loss( - self, observed_data, cond_mask, indicating_mask, side_info, is_train, set_t=-1 - ): + def calc_loss(self, observed_data, cond_mask, indicating_mask, side_info, is_train, set_t=-1): B, K, L = observed_data.shape device = observed_data.device if is_train != 1: # for validation @@ -106,9 +93,7 @@ def calc_loss( current_alpha = self.alpha_torch[t] # (B,1,1) noise = torch.randn_like(observed_data) - noisy_data = (current_alpha**0.5) * observed_data + ( - 1.0 - current_alpha - ) ** 0.5 * noise + noisy_data = (current_alpha**0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask) @@ -132,27 +117,20 @@ def forward(self, observed_data, cond_mask, side_info, n_sampling_times): noisy_cond_history = [] for t in range(self.n_diffusion_steps): noise = torch.randn_like(noisy_obs) - noisy_obs = (self.alpha_hat[t] ** 0.5) * noisy_obs + self.beta[ - t - ] ** 0.5 * noise + noisy_obs = (self.alpha_hat[t] ** 0.5) * noisy_obs + self.beta[t] ** 0.5 * noise noisy_cond_history.append(noisy_obs * cond_mask) current_sample = torch.randn_like(observed_data) for t in range(self.n_diffusion_steps - 1, -1, -1): if self.is_unconditional: - diff_input = ( - cond_mask * noisy_cond_history[t] - + (1.0 - cond_mask) * current_sample - ) + diff_input = cond_mask * noisy_cond_history[t] + (1.0 - cond_mask) * current_sample diff_input = diff_input.unsqueeze(1) # (B,1,K,L) else: cond_obs = (cond_mask * observed_data).unsqueeze(1) noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1) diff_input = torch.cat([cond_obs, noisy_target], dim=1) # (B,2,K,L) - predicted = self.diff_model( - diff_input, side_info, torch.tensor([t]).to(device) - ) + predicted = self.diff_model(diff_input, side_info, torch.tensor([t]).to(device)) coeff1 = 1 / self.alpha_hat[t] ** 0.5 coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5 @@ -160,9 +138,7 @@ def forward(self, observed_data, cond_mask, side_info, n_sampling_times): if t > 0: noise = torch.randn_like(current_sample) - sigma = ( - (1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t] - ) ** 0.5 + sigma = ((1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]) ** 0.5 current_sample += sigma * noise imputed_samples[:, i] = current_sample.detach() diff --git a/pypots/nn/modules/csdi/layers.py b/pypots/nn/modules/csdi/layers.py index dfaacf20..2811075d 100644 --- a/pypots/nn/modules/csdi/layers.py +++ b/pypots/nn/modules/csdi/layers.py @@ -13,9 +13,7 @@ def get_torch_trans(heads=8, layers=1, channels=64): - encoder_layer = nn.TransformerEncoderLayer( - d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu" - ) + encoder_layer = nn.TransformerEncoderLayer(d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu") return nn.TransformerEncoder(encoder_layer, num_layers=layers) @@ -41,11 +39,7 @@ def __init__(self, n_diffusion_steps, d_embedding=128, d_projection=None): @staticmethod def _build_embedding(n_steps, d_embedding=64): steps = torch.arange(n_steps).unsqueeze(1) # (T,1) - frequencies = 10.0 ** ( - torch.arange(d_embedding) / (d_embedding - 1) * 4.0 - ).unsqueeze( - 0 - ) # (1,dim) + frequencies = 10.0 ** (torch.arange(d_embedding) / (d_embedding - 1) * 4.0).unsqueeze(0) # (1,dim) table = steps * frequencies # (T,dim) table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) # (T,dim*2) return table @@ -68,9 +62,7 @@ def __init__(self, d_side, n_channels, diffusion_embedding_dim, nheads): self.output_projection = conv1d_with_init(n_channels, 2 * n_channels, 1) self.time_layer = get_torch_trans(heads=nheads, layers=1, channels=n_channels) - self.feature_layer = get_torch_trans( - heads=nheads, layers=1, channels=n_channels - ) + self.feature_layer = get_torch_trans(heads=nheads, layers=1, channels=n_channels) def forward_time(self, y, base_shape): B, channel, K, L = base_shape # bz, 2, n_features, n_steps @@ -95,9 +87,7 @@ def forward(self, x, cond_info, diffusion_emb): base_shape = x.shape x = x.reshape(B, channel, K * L) - diffusion_emb = self.diffusion_projection(diffusion_emb).unsqueeze( - -1 - ) # (B,channel,1) + diffusion_emb = self.diffusion_projection(diffusion_emb).unsqueeze(-1) # (B,channel,1) y = x + diffusion_emb y = self.forward_time(y, base_shape) diff --git a/pypots/nn/modules/dlinear/backbone.py b/pypots/nn/modules/dlinear/backbone.py index b4f2ff3b..581b6ea1 100644 --- a/pypots/nn/modules/dlinear/backbone.py +++ b/pypots/nn/modules/dlinear/backbone.py @@ -32,31 +32,19 @@ def __init__( for i in range(n_features): self.linear_seasonal.append(nn.Linear(n_steps, n_steps)) self.linear_trend.append(nn.Linear(n_steps, n_steps)) - self.linear_seasonal[i].weight = nn.Parameter( - (1 / n_steps) * torch.ones([n_steps, n_steps]) - ) - self.linear_trend[i].weight = nn.Parameter( - (1 / n_steps) * torch.ones([n_steps, n_steps]) - ) + self.linear_seasonal[i].weight = nn.Parameter((1 / n_steps) * torch.ones([n_steps, n_steps])) + self.linear_trend[i].weight = nn.Parameter((1 / n_steps) * torch.ones([n_steps, n_steps])) else: if d_model is None: - raise ValueError( - "The argument d_model is necessary for DLinear in the non-individual mode." - ) + raise ValueError("The argument d_model is necessary for DLinear in the non-individual mode.") self.linear_seasonal = nn.Linear(n_steps, n_steps) self.linear_trend = nn.Linear(n_steps, n_steps) - self.linear_seasonal.weight = nn.Parameter( - (1 / n_steps) * torch.ones([n_steps, n_steps]) - ) - self.linear_trend.weight = nn.Parameter( - (1 / n_steps) * torch.ones([n_steps, n_steps]) - ) + self.linear_seasonal.weight = nn.Parameter((1 / n_steps) * torch.ones([n_steps, n_steps])) + self.linear_trend.weight = nn.Parameter((1 / n_steps) * torch.ones([n_steps, n_steps])) def forward(self, seasonal_init, trend_init): if self.individual: - seasonal_init, trend_init = seasonal_init.permute( - 0, 2, 1 - ), trend_init.permute(0, 2, 1) + seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1) seasonal_output = torch.zeros( [seasonal_init.size(0), seasonal_init.size(1), self.n_steps], dtype=seasonal_init.dtype, @@ -66,17 +54,13 @@ def forward(self, seasonal_init, trend_init): dtype=trend_init.dtype, ).to(trend_init.device) for i in range(self.n_features): - seasonal_output[:, i, :] = self.linear_seasonal[i]( - seasonal_init[:, i, :] - ) + seasonal_output[:, i, :] = self.linear_seasonal[i](seasonal_init[:, i, :]) trend_output[:, i, :] = self.linear_trend[i](trend_init[:, i, :]) seasonal_output = seasonal_output.permute(0, 2, 1) trend_output = trend_output.permute(0, 2, 1) else: - seasonal_init, trend_init = seasonal_init.permute( - 0, 2, 1 - ), trend_init.permute(0, 2, 1) + seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1) seasonal_output = self.linear_seasonal(seasonal_init) trend_output = self.linear_trend(trend_init) diff --git a/pypots/nn/modules/etsformer/layers.py b/pypots/nn/modules/etsformer/layers.py index 1a36ed51..c788796a 100644 --- a/pypots/nn/modules/etsformer/layers.py +++ b/pypots/nn/modules/etsformer/layers.py @@ -84,9 +84,7 @@ def get_exponential_weight(self, T): # \alpha^t for all t = 1, 2, ..., T init_weight = self.weight ** (powers + 1) - return rearrange(init_weight, "h t -> 1 t h 1"), rearrange( - weight, "h t -> 1 t h 1" - ) + return rearrange(init_weight, "h t -> 1 t h 1"), rearrange(weight, "h t -> 1 t h 1") @property def weight(self): @@ -120,9 +118,7 @@ def __init__(self, d_model, n_heads, d_head=None, dropout=0.1): self.es = ExponentialSmoothing(self.d_head, self.n_heads, dropout=dropout) self.out_proj = nn.Linear(self.d_head * self.n_heads, self.d_model) - assert ( - self.d_head * self.n_heads == self.d_model - ), "d_model must be divisible by n_heads" + assert self.d_head * self.n_heads == self.d_model, "d_model must be divisible by n_heads" def forward(self, inputs): """ @@ -169,9 +165,7 @@ def forward(self, x): def extrapolate(self, x_freq, f, t): x_freq = torch.cat([x_freq, x_freq.conj()], dim=1) f = torch.cat([f, -f], dim=1) - t_val = rearrange( - torch.arange(t + self.pred_len, dtype=torch.float), "t -> () () t ()" - ).to(x_freq.device) + t_val = rearrange(torch.arange(t + self.pred_len, dtype=torch.float), "t -> () () t ()").to(x_freq.device) amp = rearrange(x_freq.abs() / t, "b f d -> b f () d") phase = rearrange(x_freq.angle(), "b f d -> b f () d") @@ -181,12 +175,8 @@ def extrapolate(self, x_freq, f, t): return reduce(x_time, "b f t d -> b t d", "sum") def topk_freq(self, x_freq): - values, indices = torch.topk( - x_freq.abs(), self.k, dim=1, largest=True, sorted=True - ) - mesh_a, mesh_b = torch.meshgrid( - torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2)) - ) + values, indices = torch.topk(x_freq.abs(), self.k, dim=1, largest=True, sorted=True) + mesh_a, mesh_b = torch.meshgrid(torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2))) index_tuple = (mesh_a.unsqueeze(1), indices, mesh_b.unsqueeze(1)) x_freq = x_freq[index_tuple] diff --git a/pypots/nn/modules/fedformer/autoencoder.py b/pypots/nn/modules/fedformer/autoencoder.py index 84cae344..be081b1e 100644 --- a/pypots/nn/modules/fedformer/autoencoder.py +++ b/pypots/nn/modules/fedformer/autoencoder.py @@ -49,9 +49,7 @@ def __init__( mode_select_method=mode_select, ) else: - raise ValueError( - f"Unsupported version: {version}. Please choose from ['Wavelets', 'Fourier']." - ) + raise ValueError(f"Unsupported version: {version}. Please choose from ['Wavelets', 'Fourier'].") self.encoder = InformerEncoder( [ @@ -123,9 +121,7 @@ def __init__( num_heads=n_heads, ) else: - raise ValueError( - f"Unsupported version: {version}. Please choose from ['Wavelets', 'Fourier']." - ) + raise ValueError(f"Unsupported version: {version}. Please choose from ['Wavelets', 'Fourier'].") self.decoder = InformerDecoder( [ diff --git a/pypots/nn/modules/fedformer/layers.py b/pypots/nn/modules/fedformer/layers.py index 36522bf9..ab4512b2 100644 --- a/pypots/nn/modules/fedformer/layers.py +++ b/pypots/nn/modules/fedformer/layers.py @@ -43,13 +43,9 @@ def get_phi_psi(k, base): if base == "legendre": for ki in range(k): coeff_ = Poly(legendre(ki, 2 * x - 1), x).all_coeffs() - phi_coeff[ki, : ki + 1] = np.flip( - np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64) - ) + phi_coeff[ki, : ki + 1] = np.flip(np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64)) coeff_ = Poly(legendre(ki, 4 * x - 1), x).all_coeffs() - phi_2x_coeff[ki, : ki + 1] = np.flip( - np.sqrt(2) * np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64) - ) + phi_2x_coeff[ki, : ki + 1] = np.flip(np.sqrt(2) * np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64)) psi1_coeff = np.zeros((k, k)) psi2_coeff = np.zeros((k, k)) @@ -60,12 +56,7 @@ def get_phi_psi(k, base): b = phi_coeff[i, : i + 1] prod_ = np.convolve(a, b) prod_[np.abs(prod_) < 1e-8] = 0 - proj_ = ( - prod_ - * 1 - / (np.arange(len(prod_)) + 1) - * np.power(0.5, 1 + np.arange(len(prod_))) - ).sum() + proj_ = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum() psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :] psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :] for j in range(ki): @@ -73,34 +64,19 @@ def get_phi_psi(k, base): b = psi1_coeff[j, :] prod_ = np.convolve(a, b) prod_[np.abs(prod_) < 1e-8] = 0 - proj_ = ( - prod_ - * 1 - / (np.arange(len(prod_)) + 1) - * np.power(0.5, 1 + np.arange(len(prod_))) - ).sum() + proj_ = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum() psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :] psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :] a = psi1_coeff[ki, :] prod_ = np.convolve(a, a) prod_[np.abs(prod_) < 1e-8] = 0 - norm1 = ( - prod_ - * 1 - / (np.arange(len(prod_)) + 1) - * np.power(0.5, 1 + np.arange(len(prod_))) - ).sum() + norm1 = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum() a = psi2_coeff[ki, :] prod_ = np.convolve(a, a) prod_[np.abs(prod_) < 1e-8] = 0 - norm2 = ( - prod_ - * 1 - / (np.arange(len(prod_)) + 1) - * (1 - np.power(0.5, 1 + np.arange(len(prod_)))) - ).sum() + norm2 = (prod_ * 1 / (np.arange(len(prod_)) + 1) * (1 - np.power(0.5, 1 + np.arange(len(prod_))))).sum() norm_ = np.sqrt(norm1 + norm2) psi1_coeff[ki, :] /= norm_ psi2_coeff[ki, :] /= norm_ @@ -118,15 +94,10 @@ def get_phi_psi(k, base): phi_2x_coeff[ki, : ki + 1] = np.sqrt(2 / np.pi) * np.sqrt(2) else: coeff_ = Poly(chebyshevt(ki, 2 * x - 1), x).all_coeffs() - phi_coeff[ki, : ki + 1] = np.flip( - 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64) - ) + phi_coeff[ki, : ki + 1] = np.flip(2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64)) coeff_ = Poly(chebyshevt(ki, 4 * x - 1), x).all_coeffs() phi_2x_coeff[ki, : ki + 1] = np.flip( - np.sqrt(2) - * 2 - / np.sqrt(np.pi) - * np.array(coeff_).astype(np.float64) + np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64) ) phi = [partial(phi_, phi_coeff[i, :]) for i in range(k)] @@ -198,22 +169,10 @@ def psi(psi1, psi2, i, inp): for ki in range(k): for kpi in range(k): - H0[ki, kpi] = ( - 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum() - ) - G0[ki, kpi] = ( - 1 - / np.sqrt(2) - * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum() - ) - H1[ki, kpi] = ( - 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum() - ) - G1[ki, kpi] = ( - 1 - / np.sqrt(2) - * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum() - ) + H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum() + G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum() + H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum() + G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum() PHI0 = np.eye(k) PHI1 = np.eye(k) @@ -229,27 +188,13 @@ def psi(psi1, psi2, i, inp): for ki in range(k): for kpi in range(k): - H0[ki, kpi] = ( - 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum() - ) - G0[ki, kpi] = ( - 1 - / np.sqrt(2) - * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum() - ) - H1[ki, kpi] = ( - 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum() - ) - G1[ki, kpi] = ( - 1 - / np.sqrt(2) - * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum() - ) + H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum() + G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum() + H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum() + G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum() PHI0[ki, kpi] = (wm * phi[ki](2 * x_m) * phi[kpi](2 * x_m)).sum() * 2 - PHI1[ki, kpi] = ( - wm * phi[ki](2 * x_m - 1) * phi[kpi](2 * x_m - 1) - ).sum() * 2 + PHI1[ki, kpi] = (wm * phi[ki](2 * x_m - 1) * phi[kpi](2 * x_m - 1)).sum() * 2 PHI0[np.abs(PHI0) < 1e-8] = 0 PHI1[np.abs(PHI1) < 1e-8] = 0 @@ -268,12 +213,8 @@ def __init__(self, k, alpha, c=1, nl=1, initializer=None, **kwargs): self.modes1 = alpha self.scale = 1 / (c * k * c * k) - self.weights1 = nn.Parameter( - self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float) - ) - self.weights2 = nn.Parameter( - self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float) - ) + self.weights1 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float)) + self.weights2 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float)) self.weights1.requires_grad = True self.weights2.requires_grad = True self.k = k @@ -286,15 +227,11 @@ def compl_mul1d(self, order, x, weights): x = torch.complex(x, torch.zeros_like(x).to(x.device)) if not torch.is_complex(weights): w_flag = False - weights = torch.complex( - weights, torch.zeros_like(weights).to(weights.device) - ) + weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device)) if x_flag or w_flag: return torch.complex( - torch.einsum(order, x.real, weights.real) - - torch.einsum(order, x.imag, weights.imag), - torch.einsum(order, x.real, weights.imag) - + torch.einsum(order, x.imag, weights.real), + torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag), + torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real), ) else: return torch.einsum(order, x.real, weights.real) @@ -319,9 +256,7 @@ def forward(self, x): class MWT_CZ1d(nn.Module): - def __init__( - self, k=3, alpha=64, L=0, c=1, base="legendre", initializer=None, **kwargs - ): + def __init__(self, k=3, alpha=64, L=0, c=1, base="legendre", initializer=None, **kwargs): super().__init__() self.k = k @@ -484,15 +419,11 @@ def compl_mul1d(self, order, x, weights): x = torch.complex(x, torch.zeros_like(x).to(x.device)) if not torch.is_complex(weights): w_flag = False - weights = torch.complex( - weights, torch.zeros_like(weights).to(weights.device) - ) + weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device)) if x_flag or w_flag: return torch.complex( - torch.einsum(order, x.real, weights.real) - - torch.einsum(order, x.imag, weights.imag), - torch.einsum(order, x.real, weights.imag) - + torch.einsum(order, x.imag, weights.real), + torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag), + torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real), ) else: return torch.einsum(order, x.real, weights.real) @@ -507,16 +438,12 @@ def forward(self, q, k, v, mask): self.index_k_v = list(range(0, min(int(xv.shape[3] // 2), self.modes1))) # Compute Fourier coefficients - xq_ft_ = torch.zeros( - B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat - ) + xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) xq_ft = torch.fft.rfft(xq, dim=-1) for i, j in enumerate(self.index_q): xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] - xk_ft_ = torch.zeros( - B, H, E, len(self.index_k_v), device=xq.device, dtype=torch.cfloat - ) + xk_ft_ = torch.zeros(B, H, E, len(self.index_k_v), device=xq.device, dtype=torch.cfloat) xk_ft = torch.fft.rfft(xk, dim=-1) for i, j in enumerate(self.index_k_v): xk_ft_[:, :, :, i] = xk_ft[:, :, :, j] @@ -527,9 +454,7 @@ def forward(self, q, k, v, mask): xqk_ft = torch.softmax(abs(xqk_ft), dim=-1) xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft)) else: - raise Exception( - "{} actiation function is not implemented".format(self.activation) - ) + raise Exception("{} actiation function is not implemented".format(self.activation)) xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_) xqkvw = xqkv_ft @@ -537,9 +462,7 @@ def forward(self, q, k, v, mask): for i, j in enumerate(self.index_q): out_ft[:, :, :, j] = xqkvw[:, :, :, i] - out = torch.fft.irfft( - out_ft / self.in_channels / self.out_channels, n=xq.size(-1) - ).permute(0, 3, 2, 1) + out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)).permute(0, 3, 2, 1) # size = [B, L, H, E] return (out, None) @@ -701,10 +624,7 @@ def forward( dk, sk = Ud_k[i], Us_k[i] dq, sq = Ud_q[i], Us_q[i] dv, sv = Ud_v[i], Us_v[i] - Ud += [ - self.attn1(dq[0], dk[0], dv[0], attn_mask)[0] - + self.attn2(dq[1], dk[1], dv[1], attn_mask)[0] - ] + Ud += [self.attn1(dq[0], dk[0], dv[0], attn_mask)[0] + self.attn2(dq[1], dk[1], dv[1], attn_mask)[0]] Us += [self.attn3(sq, sk, sv, attn_mask)[0]] v = self.attn4(q, k, v, attn_mask)[0] @@ -759,9 +679,7 @@ def get_frequency_modes(seq_len, modes=64, mode_select_method="random"): # ########## fourier layer ############# class FourierBlock(AttentionOperator): - def __init__( - self, in_channels, out_channels, seq_len, modes=0, mode_select_method="random" - ): + def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_method="random"): super().__init__() # print("fourier enhanced block used!") """ @@ -769,9 +687,7 @@ def __init__( it does FFT, linear transform, and Inverse FFT. """ # get modes on frequency domain - self.index = get_frequency_modes( - seq_len, modes=modes, mode_select_method=mode_select_method - ) + self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method) # print("modes={}, index={}".format(modes, self.index)) self.scale = 1 / (in_channels * out_channels) @@ -809,9 +725,7 @@ def forward( # Perform Fourier neural operations out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat) for wi, i in enumerate(self.index): - out_ft[:, :, :, wi] = self.compl_mul1d( - x_ft[:, :, :, i], self.weights1[:, :, :, wi] - ) + out_ft[:, :, :, wi] = self.compl_mul1d(x_ft[:, :, :, i], self.weights1[:, :, :, wi]) # Return to time domain x = torch.fft.irfft(out_ft, n=x.size(-1)) return x, None @@ -840,12 +754,8 @@ def __init__( self.in_channels = in_channels self.out_channels = out_channels # get modes for queries and keys (& values) on frequency domain - self.index_q = get_frequency_modes( - seq_len_q, modes=modes, mode_select_method=mode_select_method - ) - self.index_kv = get_frequency_modes( - seq_len_kv, modes=modes, mode_select_method=mode_select_method - ) + self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method) + self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method) # print("modes_q={}, index_q={}".format(len(self.index_q), self.index_q)) # print("modes_kv={}, index_kv={}".format(len(self.index_kv), self.index_kv)) @@ -881,15 +791,11 @@ def compl_mul1d(self, order, x, weights): x = torch.complex(x, torch.zeros_like(x).to(x.device)) if not torch.is_complex(weights): w_flag = False - weights = torch.complex( - weights, torch.zeros_like(weights).to(weights.device) - ) + weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device)) if x_flag or w_flag: return torch.complex( - torch.einsum(order, x.real, weights.real) - - torch.einsum(order, x.imag, weights.imag), - torch.einsum(order, x.real, weights.imag) - + torch.einsum(order, x.imag, weights.real), + torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag), + torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real), ) else: return torch.einsum(order, x.real, weights.real) @@ -911,17 +817,13 @@ def forward( # xv = v.permute(0, 2, 3, 1) # Compute Fourier coefficients - xq_ft_ = torch.zeros( - B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat - ) + xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) xq_ft = torch.fft.rfft(xq, dim=-1) for i, j in enumerate(self.index_q): if j >= xq_ft.shape[3]: continue xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] - xk_ft_ = torch.zeros( - B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat - ) + xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat) xk_ft = torch.fft.rfft(xk, dim=-1) for i, j in enumerate(self.index_kv): if j >= xk_ft.shape[3]: @@ -936,22 +838,16 @@ def forward( xqk_ft = torch.softmax(abs(xqk_ft), dim=-1) xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft)) else: - raise Exception( - "{} actiation function is not implemented".format(self.activation) - ) + raise Exception("{} actiation function is not implemented".format(self.activation)) xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_) - xqkvw = self.compl_mul1d( - "bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2) - ) + xqkvw = self.compl_mul1d("bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2)) out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat) for i, j in enumerate(self.index_q): if i >= xqkvw.shape[3] or j >= out_ft.shape[3]: continue out_ft[:, :, :, j] = xqkvw[:, :, :, i] # Return to time domain - out = torch.fft.irfft( - out_ft / self.in_channels / self.out_channels, n=xq.size(-1) - ) + out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)) return out, None @@ -973,8 +869,6 @@ def forward(self, x): moving_avg = func(x) moving_mean.append(moving_avg.unsqueeze(-1)) moving_mean = torch.cat(moving_mean, dim=-1) - moving_mean = torch.sum( - moving_mean * nn.Softmax(-1)(self.layer(x.unsqueeze(-1))), dim=-1 - ) + moving_mean = torch.sum(moving_mean * nn.Softmax(-1)(self.layer(x.unsqueeze(-1))), dim=-1) res = x - moving_mean return res, moving_mean diff --git a/pypots/nn/modules/film/backbone.py b/pypots/nn/modules/film/backbone.py index 85bc6791..4547eaf4 100644 --- a/pypots/nn/modules/film/backbone.py +++ b/pypots/nn/modules/film/backbone.py @@ -34,11 +34,7 @@ def __init__( self.affine_weight = nn.Parameter(torch.ones(1, 1, in_channels)) self.affine_bias = nn.Parameter(torch.zeros(1, 1, in_channels)) self.legts = nn.ModuleList( - [ - HiPPO_LegT(N=n, dt=1.0 / n_pred_steps / i) - for n in window_size - for i in multiscale - ] + [HiPPO_LegT(N=n, dt=1.0 / n_pred_steps / i) for n in window_size for i in multiscale] ) self.spec_conv_1 = nn.ModuleList( [ @@ -65,14 +61,10 @@ def forward(self, X) -> torch.Tensor: x_in_len = self.multiscale[i % len(self.multiscale)] * self.n_pred_steps x_in = x_enc[:, -x_in_len:] legt = self.legts[i] - x_in_c = legt(x_in.transpose(1, 2)).permute([1, 2, 3, 0])[ - :, :, :, jump_dist: - ] + x_in_c = legt(x_in.transpose(1, 2)).permute([1, 2, 3, 0])[:, :, :, jump_dist:] out1 = self.spec_conv_1[i](x_in_c) if self.n_steps >= self.n_pred_steps: - x_dec_c = out1.transpose(2, 3)[ - :, :, self.n_pred_steps - 1 - jump_dist, : - ] + x_dec_c = out1.transpose(2, 3)[:, :, self.n_pred_steps - 1 - jump_dist, :] else: x_dec_c = out1.transpose(2, 3)[:, :, -1, :] x_dec = x_dec_c @ legt.eval_matrix[-self.n_pred_steps :, :].T diff --git a/pypots/nn/modules/film/layers.py b/pypots/nn/modules/film/layers.py index 24976227..30ff574d 100644 --- a/pypots/nn/modules/film/layers.py +++ b/pypots/nn/modules/film/layers.py @@ -93,10 +93,7 @@ def __init__( self.index0 = list(range(0, int(ratio * min(seq_len // 2, modes2)))) self.index1 = list(range(len(self.index0), self.modes2)) np.random.shuffle(self.index1) - self.index1 = self.index1[ - : min(seq_len // 2, self.modes2) - - int(ratio * min(seq_len // 2, modes2)) - ] + self.index1 = self.index1[: min(seq_len // 2, self.modes2) - int(ratio * min(seq_len // 2, modes2))] self.index = self.index0 + self.index1 self.index.sort() elif mode_type == 2: @@ -108,8 +105,7 @@ def __init__( self.scale = 1 / (in_channels * out_channels) self.weights1 = nn.Parameter( - self.scale - * torch.rand(in_channels, out_channels, len(self.index), dtype=torch.cfloat) + self.scale * torch.rand(in_channels, out_channels, len(self.index), dtype=torch.cfloat) ) def forward(self, x): @@ -126,14 +122,10 @@ def forward(self, x): if self.modes1 > 1000: for wi, i in enumerate(self.index): - out_ft[:, :, :, i] = torch.einsum( - "bji,io->bjo", (x_ft[:, :, :, i], self.weights1[:, :, wi]) - ) + out_ft[:, :, :, i] = torch.einsum("bji,io->bjo", (x_ft[:, :, :, i], self.weights1[:, :, wi])) else: a = x_ft[:, :, :, : self.modes2] - out_ft[:, :, :, : self.modes2] = torch.einsum( - "bjix,iox->bjox", a, self.weights1 - ) + out_ft[:, :, :, : self.modes2] = torch.einsum("bjix,iox->bjox", a, self.weights1) x = torch.fft.irfft(out_ft, n=x.size(-1)) return x diff --git a/pypots/nn/modules/frets/backbone.py b/pypots/nn/modules/frets/backbone.py index 2b53af10..26afc396 100644 --- a/pypots/nn/modules/frets/backbone.py +++ b/pypots/nn/modules/frets/backbone.py @@ -33,20 +33,12 @@ def __init__( self.scale = 0.02 # self.embeddings = nn.Parameter(torch.randn(1, self.embed_size)) # original embedding method, deprecate here - self.r1 = nn.Parameter( - self.scale * torch.randn(self.embed_size, self.embed_size) - ) - self.i1 = nn.Parameter( - self.scale * torch.randn(self.embed_size, self.embed_size) - ) + self.r1 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size)) + self.i1 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size)) self.rb1 = nn.Parameter(self.scale * torch.randn(self.embed_size)) self.ib1 = nn.Parameter(self.scale * torch.randn(self.embed_size)) - self.r2 = nn.Parameter( - self.scale * torch.randn(self.embed_size, self.embed_size) - ) - self.i2 = nn.Parameter( - self.scale * torch.randn(self.embed_size, self.embed_size) - ) + self.r2 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size)) + self.i2 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size)) self.rb2 = nn.Parameter(self.scale * torch.randn(self.embed_size)) self.ib2 = nn.Parameter(self.scale * torch.randn(self.embed_size)) @@ -89,24 +81,12 @@ def MLP_channel(self, x, B, N, L): # dimension: FFT along the dimension, r: the real part of weights, i: the imaginary part of weights # rb: the real part of bias, ib: the imaginary part of bias def FreMLP(self, B, nd, dimension, x, r, i, rb, ib): - o1_real = torch.zeros( - [B, nd, dimension // 2 + 1, self.embed_size], device=x.device - ) - o1_imag = torch.zeros( - [B, nd, dimension // 2 + 1, self.embed_size], device=x.device - ) + o1_real = torch.zeros([B, nd, dimension // 2 + 1, self.embed_size], device=x.device) + o1_imag = torch.zeros([B, nd, dimension // 2 + 1, self.embed_size], device=x.device) - o1_real = F.relu( - torch.einsum("bijd,dd->bijd", x.real, r) - - torch.einsum("bijd,dd->bijd", x.imag, i) - + rb - ) + o1_real = F.relu(torch.einsum("bijd,dd->bijd", x.real, r) - torch.einsum("bijd,dd->bijd", x.imag, i) + rb) - o1_imag = F.relu( - torch.einsum("bijd,dd->bijd", x.imag, r) - + torch.einsum("bijd,dd->bijd", x.real, i) - + ib - ) + o1_imag = F.relu(torch.einsum("bijd,dd->bijd", x.imag, r) + torch.einsum("bijd,dd->bijd", x.real, i) + ib) y = torch.stack([o1_real, o1_imag], dim=-1) y = F.softshrink(y, lambd=self.sparsity_threshold) diff --git a/pypots/nn/modules/gpvae/backbone.py b/pypots/nn/modules/gpvae/backbone.py index fe76e0f5..de6284de 100644 --- a/pypots/nn/modules/gpvae/backbone.py +++ b/pypots/nn/modules/gpvae/backbone.py @@ -114,23 +114,13 @@ def _init_prior(self, device="cpu"): kernel_matrices = [] for i in range(self.kernel_scales): if self.kernel == "rbf": - kernel_matrices.append( - rbf_kernel(self.time_length, self.length_scale / 2**i) - ) + kernel_matrices.append(rbf_kernel(self.time_length, self.length_scale / 2**i)) elif self.kernel == "diffusion": - kernel_matrices.append( - diffusion_kernel(self.time_length, self.length_scale / 2**i) - ) + kernel_matrices.append(diffusion_kernel(self.time_length, self.length_scale / 2**i)) elif self.kernel == "matern": - kernel_matrices.append( - matern_kernel(self.time_length, self.length_scale / 2**i) - ) + kernel_matrices.append(matern_kernel(self.time_length, self.length_scale / 2**i)) elif self.kernel == "cauchy": - kernel_matrices.append( - cauchy_kernel( - self.time_length, self.sigma, self.length_scale / 2**i - ) - ) + kernel_matrices.append(cauchy_kernel(self.time_length, self.sigma, self.length_scale / 2**i)) # Combine kernel matrices for each latent dimension tiled_matrices = [] @@ -141,9 +131,7 @@ def _init_prior(self, device="cpu"): else: multiplier = int(np.ceil(self.latent_dim / self.kernel_scales)) total += multiplier - tiled_matrices.append( - torch.unsqueeze(kernel_matrices[i], 0).repeat(multiplier, 1, 1) - ) + tiled_matrices.append(torch.unsqueeze(kernel_matrices[i], 0).repeat(multiplier, 1, 1)) kernel_matrix_tiled = torch.cat(tiled_matrices) assert len(kernel_matrix_tiled) == self.latent_dim prior = torch.distributions.MultivariateNormal( @@ -158,9 +146,7 @@ def impute(self, X, missing_mask, n_sampling_times=1): missing_mask = missing_mask.repeat(n_sampling_times, 1, 1).type(torch.bool) decode_x_mean = self.decode(self.encode(X).mean).mean imputed_data = decode_x_mean * ~missing_mask + X * missing_mask - imputed_data = imputed_data.reshape( - n_sampling_times, n_samples, n_steps, n_features - ).permute(1, 0, 2, 3) + imputed_data = imputed_data.reshape(n_sampling_times, n_samples, n_steps, n_features).permute(1, 0, 2, 3) return imputed_data def forward(self, X, missing_mask): diff --git a/pypots/nn/modules/gpvae/layers.py b/pypots/nn/modules/gpvae/layers.py index 02469e04..b3142d06 100644 --- a/pypots/nn/modules/gpvae/layers.py +++ b/pypots/nn/modules/gpvae/layers.py @@ -22,10 +22,7 @@ def rbf_kernel(T, length_scale): def diffusion_kernel(T, length_scale): - assert length_scale < 0.5, ( - "length_scale has to be smaller than 0.5 for the " - "kernel matrix to be diagonally dominant" - ) + assert length_scale < 0.5, "length_scale has to be smaller than 0.5 for the kernel matrix to be diagonally dominant" sigmas = torch.ones(T, T) * length_scale sigmas_tridiag = torch.diagonal(sigmas, offset=0, dim1=-2, dim2=-1) sigmas_tridiag += torch.diagonal(sigmas, offset=1, dim1=-2, dim2=-1) @@ -39,9 +36,7 @@ def matern_kernel(T, length_scale): xs_in = torch.unsqueeze(xs, 0) xs_out = torch.unsqueeze(xs, 1) distance_matrix = torch.abs(xs_in - xs_out) - distance_matrix_scaled = distance_matrix / torch.sqrt(length_scale).type( - torch.float32 - ) + distance_matrix_scaled = distance_matrix / torch.sqrt(length_scale).type(torch.float32) kernel_matrix = torch.exp(-distance_matrix_scaled) return kernel_matrix @@ -81,13 +76,9 @@ def make_nn(input_size, output_size, hidden_sizes): layers = [] for i in range(len(hidden_sizes)): if i == 0: - layers.append( - nn.Linear(in_features=input_size, out_features=hidden_sizes[i]) - ) + layers.append(nn.Linear(in_features=input_size, out_features=hidden_sizes[i])) else: - layers.append( - nn.Linear(in_features=hidden_sizes[i - 1], out_features=hidden_sizes[i]) - ) + layers.append(nn.Linear(in_features=hidden_sizes[i - 1], out_features=hidden_sizes[i])) layers.append(nn.ReLU()) layers.append(nn.Linear(in_features=hidden_sizes[-1], out_features=output_size)) return nn.Sequential(*layers) @@ -137,9 +128,7 @@ def make_cnn(input_size, output_size, hidden_sizes, kernel_size=3): """ padding = kernel_size // 2 - cnn_layer = CustomConv1d( - input_size, hidden_sizes[0], kernel_size=kernel_size, padding=padding - ) + cnn_layer = CustomConv1d(input_size, hidden_sizes[0], kernel_size=kernel_size, padding=padding) layers = [cnn_layer] for i, h in zip(hidden_sizes, hidden_sizes[1:]): @@ -193,9 +182,7 @@ def forward(self, x): dense_shape = [batch_size, self.z_size, time_length, time_length] idxs_1 = np.repeat(np.arange(batch_size), self.z_size * (2 * time_length - 1)) - idxs_2 = np.tile( - np.repeat(np.arange(self.z_size), (2 * time_length - 1)), batch_size - ) + idxs_2 = np.tile(np.repeat(np.arange(self.z_size), (2 * time_length - 1)), batch_size) idxs_3 = np.tile( np.concatenate([np.arange(time_length), np.arange(time_length - 1)]), batch_size * self.z_size, @@ -222,16 +209,12 @@ def forward(self, x): ) prec_tril = prec_tril + eye cov_tril = torch.linalg.solve_triangular(prec_tril, eye, upper=True) - cov_tril = torch.where( - torch.isfinite(cov_tril), cov_tril, torch.zeros_like(cov_tril) - ).to(mapped.device) + cov_tril = torch.where(torch.isfinite(cov_tril), cov_tril, torch.zeros_like(cov_tril)).to(mapped.device) num_dim = len(cov_tril.shape) cov_tril_lower = torch.transpose(cov_tril, num_dim - 1, num_dim - 2) - z_dist = torch.distributions.MultivariateNormal( - loc=mapped_mean, scale_tril=cov_tril_lower - ) + z_dist = torch.distributions.MultivariateNormal(loc=mapped_mean, scale_tril=cov_tril_lower) return z_dist diff --git a/pypots/nn/modules/grud/backbone.py b/pypots/nn/modules/grud/backbone.py index dde5fcfb..88eea8d8 100644 --- a/pypots/nn/modules/grud/backbone.py +++ b/pypots/nn/modules/grud/backbone.py @@ -26,19 +26,11 @@ def __init__( self.rnn_hidden_size = rnn_hidden_size # create models - self.rnn_cell = nn.GRUCell( - self.n_features * 2 + self.rnn_hidden_size, self.rnn_hidden_size - ) - self.temp_decay_h = TemporalDecay( - input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False - ) - self.temp_decay_x = TemporalDecay( - input_size=self.n_features, output_size=self.n_features, diag=True - ) - - def forward( - self, X, missing_mask, deltas, empirical_mean, X_filledLOCF - ) -> Tuple[torch.Tensor, ...]: + self.rnn_cell = nn.GRUCell(self.n_features * 2 + self.rnn_hidden_size, self.rnn_hidden_size) + self.temp_decay_h = TemporalDecay(input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False) + self.temp_decay_x = TemporalDecay(input_size=self.n_features, output_size=self.n_features, diag=True) + + def forward(self, X, missing_mask, deltas, empirical_mean, X_filledLOCF) -> Tuple[torch.Tensor, ...]: """Forward processing of GRU-D. Parameters diff --git a/pypots/nn/modules/imputeformer/attention.py b/pypots/nn/modules/imputeformer/attention.py index b9e982c8..9a9dc551 100644 --- a/pypots/nn/modules/imputeformer/attention.py +++ b/pypots/nn/modules/imputeformer/attention.py @@ -55,13 +55,9 @@ def forward(self, query, key, value): key = torch.cat(torch.split(key, self.head_dim, dim=-1), dim=0) value = torch.cat(torch.split(value, self.head_dim, dim=-1), dim=0) - key = key.transpose( - -1, -2 - ) # (num_heads * batch_size, ..., head_dim, src_length) + key = key.transpose(-1, -2) # (num_heads * batch_size, ..., head_dim, src_length) - attn_score = ( - query @ key - ) / self.head_dim**0.5 # (num_heads * batch_size, ..., tgt_length, src_length) + attn_score = (query @ key) / self.head_dim**0.5 # (num_heads * batch_size, ..., tgt_length, src_length) if self.mask: mask = torch.ones( @@ -105,9 +101,7 @@ def __init__( self.dropout = nn.Dropout(dropout) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) - self.MLP = nn.Sequential( - nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model) - ) + self.MLP = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)) self.seq_len = seq_len def forward(self, x): @@ -121,12 +115,8 @@ def forward(self, x): # projector = repeat(self.projector, 'dim_proj d_model -> repeat seq_len dim_proj d_model', # repeat=batch, seq_len=self.seq_len) # [b, s, c, d] - message_out = self.out_attn( - projector, x, x - ) # [b, s, c, d] <-> [b s n d] -> [b s c d] - message_in = self.in_attn( - x, projector, message_out - ) # [b s n d] <-> [b, s, c, d] -> [b s n d] + message_out = self.out_attn(projector, x, x) # [b, s, c, d] <-> [b s n d] -> [b s c d] + message_in = self.in_attn(x, projector, message_out) # [b s n d] <-> [b, s, c, d] -> [b s n d] message = x + self.dropout(message_in) message = self.norm1(message) message = message + self.dropout(self.MLP(message)) diff --git a/pypots/nn/modules/imputeformer/mlp.py b/pypots/nn/modules/imputeformer/mlp.py index eb8d6288..12ef62a3 100644 --- a/pypots/nn/modules/imputeformer/mlp.py +++ b/pypots/nn/modules/imputeformer/mlp.py @@ -28,9 +28,7 @@ class MLP(nn.Module): Simple Multi-layer Perceptron encoder with optional linear readout. """ - def __init__( - self, input_size, hidden_size, output_size=None, n_layers=1, dropout=0.0 - ): + def __init__(self, input_size, hidden_size, output_size=None, n_layers=1, dropout=0.0): super(MLP, self).__init__() layers = [ diff --git a/pypots/nn/modules/informer/autoencoder.py b/pypots/nn/modules/informer/autoencoder.py index e2fecd73..aaa8bb44 100644 --- a/pypots/nn/modules/informer/autoencoder.py +++ b/pypots/nn/modules/informer/autoencoder.py @@ -12,9 +12,7 @@ class InformerEncoder(nn.Module): def __init__(self, attn_layers, conv_layers=None, norm_layer=None): super().__init__() self.attn_layers = nn.ModuleList(attn_layers) - self.conv_layers = ( - nn.ModuleList(conv_layers) if conv_layers is not None else None - ) + self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None self.norm = norm_layer def forward(self, x, attn_mask=None): diff --git a/pypots/nn/modules/informer/layers.py b/pypots/nn/modules/informer/layers.py index d7f92dc3..e63aabe8 100644 --- a/pypots/nn/modules/informer/layers.py +++ b/pypots/nn/modules/informer/layers.py @@ -21,9 +21,7 @@ class ProbMask: def __init__(self, B, H, L, index, scores, device="cpu"): _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) - indicator = _mask_ex[ - torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, : - ].to(device) + indicator = _mask_ex[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :].to(device) self._mask = indicator.view(scores.shape).to(device) @property @@ -76,22 +74,16 @@ def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) # calculate the sampled Q_K K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) - index_sample = torch.randint( - L_K, (L_Q, sample_k) - ) # real U = U_part(factor*ln(L_k))*L_q + index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] - Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze( - -2 - ) + Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2) # find the Top_k query with sparisty measurement M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) M_top = M.topk(n_top, sorted=False)[1] # use the reduced Q to calculate Q_K - Q_reduce = Q[ - torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, : - ] # factor*ln(L_q) + Q_reduce = Q[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :] # factor*ln(L_q) Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k return Q_K, M_top @@ -116,14 +108,12 @@ def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) - context_in[ - torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, : - ] = torch.matmul(attn, V).type_as(context_in) + context_in[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = torch.matmul( + attn, V + ).type_as(context_in) attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device) - attns[ - torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, : - ] = attn + attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn return context_in, attns def forward( @@ -159,9 +149,7 @@ def forward( # get the context context = self._get_initial_context(v, L_Q) # update the context with selected top_k queries - context, attn = self._update_context( - context, v, scores_top, index, L_Q, attn_mask - ) + context, attn = self._update_context(context, v, scores_top, index, L_Q, attn_mask) return context.transpose(2, 1).contiguous(), attn @@ -212,16 +200,10 @@ def __init__( self.activation = F.relu if activation == "relu" else F.gelu def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): - x = x + self.dropout( - self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0] - ) + x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]) x = self.norm1(x) - x = x + self.dropout( - self.cross_attention( - x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta - )[0] - ) + x = x + self.dropout(self.cross_attention(x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta)[0]) y = x = self.norm2(x) y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) diff --git a/pypots/nn/modules/koopa/layers.py b/pypots/nn/modules/koopa/layers.py index a14a4882..9cc8ff36 100644 --- a/pypots/nn/modules/koopa/layers.py +++ b/pypots/nn/modules/koopa/layers.py @@ -99,12 +99,7 @@ def one_step_forward(self, z, return_rec=False, return_K=False): self.K = torch.linalg.lstsq(x, y).solution # B E E if torch.isnan(self.K).any(): print("Encounter K with nan, replace K by identity matrix") - self.K = ( - torch.eye(self.K.shape[1]) - .to(self.K.device) - .unsqueeze(0) - .repeat(B, 1, 1) - ) + self.K = torch.eye(self.K.shape[1]).to(self.K.device).unsqueeze(0).repeat(B, 1, 1) z_pred = torch.bmm(z[:, -1:], self.K) if return_rec: @@ -148,12 +143,7 @@ def forward(self, z, pred_len=1): if torch.isnan(self.K).any(): print("Encounter K with nan, replace K by identity matrix") - self.K = ( - torch.eye(self.K.shape[1]) - .to(self.K.device) - .unsqueeze(0) - .repeat(B, 1, 1) - ) + self.K = torch.eye(self.K.shape[1]).to(self.K.device).unsqueeze(0).repeat(B, 1, 1) z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1) # B L E @@ -161,23 +151,13 @@ def forward(self, z, pred_len=1): self.K_step = torch.linalg.matrix_power(self.K, pred_len) if torch.isnan(self.K_step).any(): print("Encounter multistep K with nan, replace it by identity matrix") - self.K_step = ( - torch.eye(self.K_step.shape[1]) - .to(self.K_step.device) - .unsqueeze(0) - .repeat(B, 1, 1) - ) + self.K_step = torch.eye(self.K_step.shape[1]).to(self.K_step.device).unsqueeze(0).repeat(B, 1, 1) z_pred = torch.bmm(z[:, -pred_len:, :], self.K_step) else: self.K_step = torch.linalg.matrix_power(self.K, input_len) if torch.isnan(self.K_step).any(): print("Encounter multistep K with nan, replace it by identity matrix") - self.K_step = ( - torch.eye(self.K_step.shape[1]) - .to(self.K_step.device) - .unsqueeze(0) - .repeat(B, 1, 1) - ) + self.K_step = torch.eye(self.K_step.shape[1]).to(self.K_step.device).unsqueeze(0).repeat(B, 1, 1) temp_z_pred, all_pred = z, [] for _ in range(math.ceil(pred_len / input_len)): temp_z_pred = torch.bmm(temp_z_pred, self.K_step) @@ -247,9 +227,7 @@ class TimeInvKP(nn.Module): Utilize lookback and forecast window snapshots to predict the future of time-invariant term """ - def __init__( - self, input_len=96, pred_len=96, dynamic_dim=128, encoder=None, decoder=None - ): + def __init__(self, input_len=96, pred_len=96, dynamic_dim=128, encoder=None, decoder=None): super().__init__() self.dynamic_dim = dynamic_dim self.input_len = input_len diff --git a/pypots/nn/modules/micn/layers.py b/pypots/nn/modules/micn/layers.py index 8189d72e..a67fc424 100644 --- a/pypots/nn/modules/micn/layers.py +++ b/pypots/nn/modules/micn/layers.py @@ -69,9 +69,7 @@ def __init__( ] ) - self.decomp = nn.ModuleList( - [SeriesDecompositionBlock(k) for k in decomp_kernel] - ) + self.decomp = nn.ModuleList([SeriesDecompositionBlock(k) for k in decomp_kernel]) self.merge = torch.nn.Conv2d( in_channels=feature_size, out_channels=feature_size, @@ -79,12 +77,8 @@ def __init__( ) # feedforward network - self.conv1 = nn.Conv1d( - in_channels=feature_size, out_channels=feature_size * 4, kernel_size=1 - ) - self.conv2 = nn.Conv1d( - in_channels=feature_size * 4, out_channels=feature_size, kernel_size=1 - ) + self.conv1 = nn.Conv1d(in_channels=feature_size, out_channels=feature_size * 4, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=feature_size * 4, out_channels=feature_size, kernel_size=1) self.norm1 = nn.LayerNorm(feature_size) self.norm2 = nn.LayerNorm(feature_size) @@ -101,9 +95,7 @@ def conv_trans_conv(self, input, conv1d, conv1d_trans, isometric): x = x1 # isometric convolution - zeros = torch.zeros( - (x.shape[0], x.shape[1], x.shape[2] - 1), device=input.device - ) + zeros = torch.zeros((x.shape[0], x.shape[1], x.shape[2] - 1), device=input.device) x = torch.cat((zeros, x), dim=-1) x = self.drop(self.act(isometric(x))) x = self.norm((x + x1).permute(0, 2, 1)).permute(0, 2, 1) @@ -120,9 +112,7 @@ def forward(self, src): multi = [] for i in range(len(self.conv_kernel)): src_out, trend1 = self.decomp[i](src) - src_out = self.conv_trans_conv( - src_out, self.conv[i], self.conv_trans[i], self.isometric_conv[i] - ) + src_out = self.conv_trans_conv(src_out, self.conv[i], self.conv_trans[i], self.isometric_conv[i]) multi.append(src_out) # merge diff --git a/pypots/nn/modules/moderntcn/backbone.py b/pypots/nn/modules/moderntcn/backbone.py index a9e3b388..bf34d92a 100644 --- a/pypots/nn/modules/moderntcn/backbone.py +++ b/pypots/nn/modules/moderntcn/backbone.py @@ -134,13 +134,9 @@ def __init__( ) else: if patch_num % pow(downsampling_ratio, (self.num_stage - 1)) == 0: - self.head_nf = ( - d_model * patch_num // pow(downsampling_ratio, (self.num_stage - 1)) - ) + self.head_nf = d_model * patch_num // pow(downsampling_ratio, (self.num_stage - 1)) else: - self.head_nf = d_model * ( - patch_num // pow(downsampling_ratio, (self.num_stage - 1)) + 1 - ) + self.head_nf = d_model * (patch_num // pow(downsampling_ratio, (self.num_stage - 1)) + 1) self.head = FlattenHead( self.head_nf, diff --git a/pypots/nn/modules/moderntcn/layers.py b/pypots/nn/modules/moderntcn/layers.py index b7c21058..66676848 100644 --- a/pypots/nn/modules/moderntcn/layers.py +++ b/pypots/nn/modules/moderntcn/layers.py @@ -9,9 +9,7 @@ from torch import nn -def get_conv1d( - in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias -): +def get_conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias): return nn.Conv1d( in_channels=in_channels, out_channels=out_channels, diff --git a/pypots/nn/modules/mrnn/backbone.py b/pypots/nn/modules/mrnn/backbone.py index 2478da28..0f7f2fbc 100644 --- a/pypots/nn/modules/mrnn/backbone.py +++ b/pypots/nn/modules/mrnn/backbone.py @@ -36,21 +36,15 @@ def gene_hidden_states(self, inputs, feature_idx): device = X_f.device batch_size = X_f.size()[0] - f_hidden_state_0 = torch.zeros( - (1, batch_size, self.rnn_hidden_size), device=device - ) - b_hidden_state_0 = torch.zeros( - (1, batch_size, self.rnn_hidden_size), device=device - ) + f_hidden_state_0 = torch.zeros((1, batch_size, self.rnn_hidden_size), device=device) + b_hidden_state_0 = torch.zeros((1, batch_size, self.rnn_hidden_size), device=device) f_input = torch.cat([X_f, M_f, D_f], dim=2) b_input = torch.cat([X_b, M_b, D_b], dim=2) hidden_states_f, _ = self.f_rnn(f_input, f_hidden_state_0) hidden_states_b, _ = self.b_rnn(b_input, b_hidden_state_0) hidden_states_b = torch.flip(hidden_states_b, dims=[1]) - feature_estimation = self.concated_hidden_project( - torch.cat([hidden_states_f, hidden_states_b], dim=2) - ) + feature_estimation = self.concated_hidden_project(torch.cat([hidden_states_f, hidden_states_b], dim=2)) return feature_estimation, hidden_states_f, hidden_states_b @@ -60,9 +54,7 @@ def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso feature_collector = [] for f in range(self.n_features): - feat_estimation, hid_states_f, hid_states_b = self.gene_hidden_states( - inputs, f - ) + feat_estimation, hid_states_f, hid_states_b = self.gene_hidden_states(inputs, f) feature_collector.append(feat_estimation) RNN_estimation = torch.concat(feature_collector, dim=2) diff --git a/pypots/nn/modules/nonstationary_transformer/autoencoder.py b/pypots/nn/modules/nonstationary_transformer/autoencoder.py index fcd7863f..3006e45d 100644 --- a/pypots/nn/modules/nonstationary_transformer/autoencoder.py +++ b/pypots/nn/modules/nonstationary_transformer/autoencoder.py @@ -109,9 +109,7 @@ def forward( # triangular causal mask bz, n_steps, _ = x.shape mask_shape = [bz, n_steps, n_steps] - src_mask = torch.triu( - torch.ones(mask_shape, dtype=torch.bool), diagonal=1 - ).to(x.device) + src_mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(x.device) for layer in self.enc_layer_stack: enc_output, attn_weights = layer(enc_output, src_mask, **kwargs) diff --git a/pypots/nn/modules/nonstationary_transformer/layers.py b/pypots/nn/modules/nonstationary_transformer/layers.py index 8464bc9e..347554ac 100644 --- a/pypots/nn/modules/nonstationary_transformer/layers.py +++ b/pypots/nn/modules/nonstationary_transformer/layers.py @@ -40,9 +40,7 @@ def forward( tau, delta = kwargs["tau"], kwargs["delta"] tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x 1 - delta = ( - 0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1) - ) # B x 1 x 1 x S + delta = 0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x S # De-stationary Attention, rescaling pre-softmax score with learned de-stationary factors scores = torch.einsum("blhe,bshe->bhls", q, k) * tau + delta diff --git a/pypots/nn/modules/patchtst/autoencoder.py b/pypots/nn/modules/patchtst/autoencoder.py index 8263817d..07c8c55a 100644 --- a/pypots/nn/modules/patchtst/autoencoder.py +++ b/pypots/nn/modules/patchtst/autoencoder.py @@ -42,9 +42,7 @@ def forward(self, x, attn_mask=None): enc_out, attns = self.encoder(x, attn_mask) - enc_out = enc_out.reshape( - -1, self.d_model, enc_out.shape[-2], enc_out.shape[-1] - ) + enc_out = enc_out.reshape(-1, self.d_model, enc_out.shape[-2], enc_out.shape[-1]) # [bz, d_model, d_model, n_patches] -> [bz, d_model, n_patches, d_model] enc_out = enc_out.permute(0, 1, 3, 2) return enc_out, attns diff --git a/pypots/nn/modules/patchtst/layers.py b/pypots/nn/modules/patchtst/layers.py index 3990954b..ed2ac651 100644 --- a/pypots/nn/modules/patchtst/layers.py +++ b/pypots/nn/modules/patchtst/layers.py @@ -60,9 +60,7 @@ def forward(self, x): x: [bs x nvars x d_model x num_patch] output: [bs x output_dim] """ - x = x[ - :, :, :, -1 - ] # only consider the last item in the sequence, x: bs x nvars x d_model + x = x[:, :, :, -1] # only consider the last item in the sequence, x: bs x nvars x d_model x = self.flatten(x) # x: bs x nvars * d_model x = self.dropout(x) y = self.linear(x) # y: bs x output_dim @@ -83,9 +81,7 @@ def forward(self, x): x: [bs x nvars x d_model x num_patch] output: [bs x n_classes] """ - x = x[ - :, :, :, -1 - ] # only consider the last item in the sequence, x: bs x nvars x d_model + x = x[:, :, :, -1] # only consider the last item in the sequence, x: bs x nvars x d_model x = self.flatten(x) # x: bs x nvars * d_model x = self.dropout(x) y = self.linear(x) # y: bs x n_classes diff --git a/pypots/nn/modules/pyraformer/layers.py b/pypots/nn/modules/pyraformer/layers.py index 0fc61e90..a6fe7598 100644 --- a/pypots/nn/modules/pyraformer/layers.py +++ b/pypots/nn/modules/pyraformer/layers.py @@ -36,15 +36,11 @@ def get_mask(input_size, window_size, inner_size): for layer_idx in range(1, len(all_size)): start = sum(all_size[:layer_idx]) for i in range(start, start + all_size[layer_idx]): - left_side = (start - all_size[layer_idx - 1]) + (i - start) * window_size[ - layer_idx - 1 - ] + left_side = (start - all_size[layer_idx - 1]) + (i - start) * window_size[layer_idx - 1] if i == (start + all_size[layer_idx] - 1): right_side = start else: - right_side = (start - all_size[layer_idx - 1]) + ( - i - start + 1 - ) * window_size[layer_idx - 1] + right_side = (start - all_size[layer_idx - 1]) + (i - start + 1) * window_size[layer_idx - 1] mask[i, left_side:right_side] = 1 mask[left_side:right_side, i] = 1 @@ -64,9 +60,7 @@ def refer_points(all_sizes, window_size): for j in range(1, len(all_sizes)): start = sum(all_sizes[:j]) inner_layer_idx = former_index - (start - all_sizes[j - 1]) - former_index = start + min( - inner_layer_idx // window_size[j - 1], all_sizes[j] - 1 - ) + former_index = start + min(inner_layer_idx // window_size[j - 1], all_sizes[j] - 1) indexes[i][j] = former_index indexes = indexes.unsqueeze(0).unsqueeze(3) diff --git a/pypots/nn/modules/raindrop/backbone.py b/pypots/nn/modules/raindrop/backbone.py index 06f74d06..82c0c323 100644 --- a/pypots/nn/modules/raindrop/backbone.py +++ b/pypots/nn/modules/raindrop/backbone.py @@ -2,7 +2,6 @@ """ - # Created by Wenjie Du # License: BSD-3-Clause @@ -70,15 +69,11 @@ def __init__( if self.sensor_wise_mask: dim_check = n_features * (self.d_ob + d_pe) assert dim_check % n_heads == 0, "dim_check must be divisible by n_heads" - encoder_layers = TransformerEncoderLayer( - n_features * (self.d_ob + d_pe), n_heads, d_ffn, dropout - ) + encoder_layers = TransformerEncoderLayer(n_features * (self.d_ob + d_pe), n_heads, d_ffn, dropout) else: dim_check = d_model + d_pe assert dim_check % n_heads == 0, "dim_check must be divisible by n_heads" - encoder_layers = TransformerEncoderLayer( - d_model + d_pe, n_heads, d_ffn, dropout - ) + encoder_layers = TransformerEncoderLayer(d_model + d_pe, n_heads, d_ffn, dropout) self.transformer_encoder = TransformerEncoder(encoder_layers, n_layers) self.R_u = nn.Parameter(torch.Tensor(1, self.n_features * self.d_ob)) @@ -163,9 +158,7 @@ def forward( edge_index = torch.nonzero(adj).T edge_weights = adj[edge_index[0], edge_index[1]] - output = torch.zeros( - [max_len, batch_size, self.n_features * self.d_ob], device=device - ) + output = torch.zeros([max_len, batch_size, self.n_features * self.d_ob], device=device) alpha_all = torch.zeros([edge_index.shape[1], batch_size], device=device) @@ -174,9 +167,7 @@ def forward( step_data = x[:, unit, :] p_t = pe[:, unit, :] - step_data = step_data.reshape( - [max_len, self.n_features, self.d_ob] - ).permute(1, 0, 2) + step_data = step_data.reshape([max_len, self.n_features, self.d_ob]).permute(1, 0, 2) step_data = step_data.reshape(self.n_features, max_len * self.d_ob) step_data, attention_weights = self.ob_propagation( diff --git a/pypots/nn/modules/raindrop/layers.py b/pypots/nn/modules/raindrop/layers.py index 0c56a5ba..14180de6 100644 --- a/pypots/nn/modules/raindrop/layers.py +++ b/pypots/nn/modules/raindrop/layers.py @@ -36,9 +36,7 @@ class PositionalEncoding(nn.Module): def __init__(self, d_pe: int, max_len: int = 500): super().__init__() - assert ( - d_pe % 2 == 0 - ), "d_pe should be even, otherwise the output dims will be not equal to d_pe" + assert d_pe % 2 == 0, "d_pe should be even, otherwise the output dims will be not equal to d_pe" self.max_len = max_len self._num_timescales = d_pe // 2 @@ -58,12 +56,8 @@ def forward(self, time_vectors: torch.Tensor) -> torch.Tensor: timescales = self.max_len ** np.linspace(0, 1, self._num_timescales) times = time_vectors.unsqueeze(2) - scaled_time = times / torch.from_numpy(timescales[None, None, :]).to( - time_vectors.device - ) - pe = torch.cat( - [torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1 - ) # T x B x d_model + scaled_time = times / torch.from_numpy(timescales[None, None, :]).to(time_vectors.device) + pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1) # T x B x d_model pe = pe.type(torch.FloatTensor) return pe @@ -126,9 +120,7 @@ def __init__( self.bias = Parameter(torch.Tensor(heads * out_channels)) self.n_nodes = n_nodes - self.nodewise_weights = Parameter( - torch.Tensor(self.n_nodes, heads * out_channels) - ) + self.nodewise_weights = Parameter(torch.Tensor(self.n_nodes, heads * out_channels)) self.increase_dim = Linear(in_channels[1], heads * out_channels * 8) self.map_weights = Parameter(torch.Tensor(self.n_nodes, heads * 16)) @@ -183,9 +175,7 @@ def forward( if isinstance(x, Tensor): x: PairTensor = (x, x) - out = self.propagate( - edge_index, x=x, edge_weights=edge_weights, edge_attr=edge_attr, size=None - ) + out = self.propagate(edge_index, x=x, edge_weights=edge_weights, edge_attr=edge_attr, size=None) alpha = self._alpha self._alpha = None @@ -301,9 +291,7 @@ def message( target_nodes = self.edge_index[1] w1 = self.nodewise_weights[source_nodes].unsqueeze(-1) w2 = self.nodewise_weights[target_nodes].unsqueeze(1) - out = torch.bmm( - x_i.view(-1, self.heads, self.out_channels), torch.bmm(w1, w2) - ) + out = torch.bmm(x_i.view(-1, self.heads, self.out_channels), torch.bmm(w1, w2)) if use_beta: out = out * gamma.view(-1, self.heads, out.shape[-1]) else: @@ -328,11 +316,7 @@ def aggregate( :meth:`__init__` by the :obj:`aggr` argument. """ index = self.index - return scatter( - inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr - ) + return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr) def __repr__(self): - return "{}({}, {}, heads={})".format( - self.__class__.__name__, self.in_channels, self.out_channels, self.heads - ) + return "{}({}, {}, heads={})".format(self.__class__.__name__, self.in_channels, self.out_channels, self.heads) diff --git a/pypots/nn/modules/reformer/local_attention.py b/pypots/nn/modules/reformer/local_attention.py index 37f1e029..a617b9ba 100644 --- a/pypots/nn/modules/reformer/local_attention.py +++ b/pypots/nn/modules/reformer/local_attention.py @@ -75,9 +75,7 @@ def look_around(x, backward=1, forward=0, pad_value=-1, dim=2): t = x.shape[1] dims = (len(x.shape) - dim) * (0, 0) padded_x = F.pad(x, (*dims, backward, forward), value=pad_value) - tensors = [ - padded_x[:, ind : (ind + t), ...] for ind in range(forward + backward + 1) - ] + tensors = [padded_x[:, ind : (ind + t), ...] for ind in range(forward + backward + 1)] return torch.cat(tensors, dim=dim) @@ -92,9 +90,7 @@ def __init__(self, dim, scale_base=None, use_xpos=False): self.use_xpos = use_xpos self.scale_base = scale_base - assert not ( - use_xpos and not exists(scale_base) - ), "scale base must be defined if using xpos" + assert not (use_xpos and not exists(scale_base)), "scale base must be defined if using xpos" scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) self.register_buffer("scale", scale, persistent=False) @@ -171,9 +167,7 @@ def __init__( scale_base=default(xpos_scale_base, window_size // 2), ) - def forward( - self, q, k, v, mask=None, input_mask=None, attn_bias=None, window_size=None - ): + def forward(self, q, k, v, mask=None, input_mask=None, attn_bias=None, window_size=None): mask = default(mask, input_mask) @@ -181,15 +175,7 @@ def forward( exists(window_size) and not self.use_xpos ), "cannot perform window size extrapolation if xpos is not turned on" - ( - autopad, - pad_value, - window_size, - causal, - look_backward, - look_forward, - shared_qk, - ) = ( + (autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk) = ( self.autopad, -1, default(window_size, self.window_size), @@ -206,9 +192,7 @@ def forward( if autopad: orig_seq_len = q.shape[1] - (needed_pad, q), (_, k), (_, v) = map( - lambda t: pad_to_multiple(t, self.window_size, dim=-2), (q, k, v) - ) + (needed_pad, q), (_, k), (_, v) = map(lambda t: pad_to_multiple(t, self.window_size, dim=-2), (q, k, v)) b, n, dim_head, device = *q.shape, q.device @@ -228,15 +212,11 @@ def forward( # bucketing - bq, bk, bv = map( - lambda t: rearrange(t, "b (w n) d -> b w n d", w=windows), (q, k, v) - ) + bq, bk, bv = map(lambda t: rearrange(t, "b (w n) d -> b w n d", w=windows), (q, k, v)) bq = bq * scale - look_around_kwargs = dict( - backward=look_backward, forward=look_forward, pad_value=pad_value - ) + look_around_kwargs = dict(backward=look_backward, forward=look_forward, pad_value=pad_value) bk = look_around(bk, **look_around_kwargs) bv = look_around(bv, **look_around_kwargs) @@ -290,9 +270,7 @@ def forward( max_backward_window_size = self.window_size * self.look_backward max_forward_window_size = self.window_size * self.look_forward window_mask = ( - ((bq_k - max_forward_window_size) > bq_t) - | (bq_t > (bq_k + max_backward_window_size)) - | pad_mask + ((bq_k - max_forward_window_size) > bq_t) | (bq_t > (bq_k + max_backward_window_size)) | pad_mask ) sim = sim.masked_fill(window_mask, mask_value) else: diff --git a/pypots/nn/modules/reformer/lsh_attention.py b/pypots/nn/modules/reformer/lsh_attention.py index 40d82076..af2bb2e9 100644 --- a/pypots/nn/modules/reformer/lsh_attention.py +++ b/pypots/nn/modules/reformer/lsh_attention.py @@ -53,12 +53,8 @@ def batched_index_select(values, indices): def process_inputs_chunk(fn, chunks=1, dim=0): def inner_fn(*args, **kwargs): keys, values, len_args = kwargs.keys(), kwargs.values(), len(args) - chunked_args = list( - zip(*map(lambda x: x.chunk(chunks, dim=dim), list(args) + list(values))) - ) - all_args = map( - lambda x: (x[:len_args], dict(zip(keys, x[len_args:]))), chunked_args - ) + chunked_args = list(zip(*map(lambda x: x.chunk(chunks, dim=dim), list(args) + list(values)))) + all_args = map(lambda x: (x[:len_args], dict(zip(keys, x[len_args:]))), chunked_args) outputs = [fn(*c_args, **c_kwargs) for c_args, c_kwargs in all_args] return tuple(map(lambda x: torch.cat(x, dim=dim), zip(*outputs))) @@ -101,9 +97,7 @@ def cached_fn(*args, **kwargs): def cache_method_decorator(cache_attr, cache_namespace, reexecute=False): def inner_fn(fn): @wraps(fn) - def wrapper( - self, *args, key_namespace=None, fetch=False, set_cache=True, **kwargs - ): + def wrapper(self, *args, key_namespace=None, fetch=False, set_cache=True, **kwargs): namespace_str = str(default(key_namespace, "")) _cache = getattr(self, cache_attr) _keyname = f"{cache_namespace}:{namespace_str}" @@ -150,9 +144,7 @@ def __init__(self, causal=False, dropout=0.0): self.causal = causal self.dropout = nn.Dropout(dropout) - def forward( - self, qk, v, query_len=None, input_mask=None, input_attn_mask=None, **kwargs - ): + def forward(self, qk, v, query_len=None, input_mask=None, input_attn_mask=None, **kwargs): b, seq_len, dim = qk.shape query_len = default(query_len, seq_len) t = query_len @@ -175,9 +167,7 @@ def forward( # Mask for post qk attention logits of the input sequence if input_attn_mask is not None: - input_attn_mask = F.pad( - input_attn_mask, (0, seq_len - input_attn_mask.shape[-1]), value=True - ) + input_attn_mask = F.pad(input_attn_mask, (0, seq_len - input_attn_mask.shape[-1]), value=True) dot.masked_fill_(~input_attn_mask, masked_value) if self.causal: @@ -213,10 +203,9 @@ def __init__( self.dropout = nn.Dropout(dropout) self.dropout_for_hash = nn.Dropout(drop_for_hash_rate) - assert rehash_each_round or allow_duplicate_attention, ( - "The setting {allow_duplicate_attention=False, rehash_each_round=False}" - " is not implemented." - ) + assert ( + rehash_each_round or allow_duplicate_attention + ), "The setting {allow_duplicate_attention=False, rehash_each_round=False} is not implemented." self.causal = causal self.bucket_size = bucket_size @@ -253,9 +242,7 @@ def hash_vectors(self, n_buckets, vecs): rot_size // 2, ) - random_rotations = torch.randn( - rotations_shape, dtype=vecs.dtype, device=device - ).expand(batch_size, -1, -1, -1) + random_rotations = torch.randn(rotations_shape, dtype=vecs.dtype, device=device).expand(batch_size, -1, -1, -1) dropped_vecs = self.dropout_for_hash(vecs) rotated_vecs = torch.einsum("btf,bfhi->bhti", dropped_vecs, random_rotations) @@ -323,11 +310,7 @@ def forward( total_hashes = self.n_hashes - ticker = ( - torch.arange(total_hashes * seqlen, device=device) - .unsqueeze(0) - .expand_as(buckets) - ) + ticker = torch.arange(total_hashes * seqlen, device=device).unsqueeze(0).expand_as(buckets) buckets_and_t = seqlen * buckets + (ticker % seqlen) buckets_and_t = buckets_and_t.detach() @@ -396,9 +379,7 @@ def look_one_back(x): # Input mask for padding in variable lengthed sequences if input_mask is not None: - input_mask = F.pad( - input_mask, (0, seqlen - input_mask.shape[1]), value=True - ) + input_mask = F.pad(input_mask, (0, seqlen - input_mask.shape[1]), value=True) mq = input_mask.gather(1, st).reshape((batch_size, chunk_size, -1)) mkv = look_one_back(mq) mask = mq[:, :, :, None] * mkv[:, :, None, :] @@ -420,9 +401,7 @@ def look_one_back(x): # Mask out attention to other hash buckets. if not self._attend_across_buckets: - bq_buckets = bkv_buckets = torch.reshape( - sbuckets_and_t // seqlen, (batch_size, chunk_size, -1) - ) + bq_buckets = bkv_buckets = torch.reshape(sbuckets_and_t // seqlen, (batch_size, chunk_size, -1)) bkv_buckets = look_one_back(bkv_buckets) bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :, None, :] dots.masked_fill_(bucket_mask, masked_value) @@ -448,9 +427,7 @@ def look_one_back(x): ).permute((0, 2, 1)) slocs = batched_index_select(locs, st) - b_locs = torch.reshape( - slocs, (batch_size, chunk_size, -1, 2 * total_hashes) - ) + b_locs = torch.reshape(slocs, (batch_size, chunk_size, -1, 2 * total_hashes)) b_locs1 = b_locs[:, :, :, None, :total_hashes] @@ -501,14 +478,10 @@ def look_one_back(x): if self._return_attn: attn_unsort = (bq_t * seqlen)[:, :, :, None] + bkv_t[:, :, None, :] attn_unsort = attn_unsort.view(batch_size * total_hashes, -1).long() - unsorted_dots = torch.zeros( - batch_size * total_hashes, seqlen * seqlen, device=device - ) + unsorted_dots = torch.zeros(batch_size * total_hashes, seqlen * seqlen, device=device) unsorted_dots.scatter_add_(1, attn_unsort, dots.view_as(attn_unsort)) del attn_unsort - unsorted_dots = unsorted_dots.reshape( - batch_size, total_hashes, seqlen, seqlen - ) + unsorted_dots = unsorted_dots.reshape(batch_size, total_hashes, seqlen, seqlen) attn = torch.sum(unsorted_dots[:, :, 0:query_len, :] * probs, dim=1) # return output, attention matrix, and bucket distribution @@ -539,12 +512,8 @@ def __init__( **kwargs, ): super().__init__() - assert ( - dim_head or (dim % heads) == 0 - ), "dimensions must be divisible by number of heads" - assert ( - n_local_attn_heads < heads - ), "local attention heads must be less than number of heads" + assert dim_head or (dim % heads) == 0, "dimensions must be divisible by number of heads" + assert n_local_attn_heads < heads, "local attention heads must be less than number of heads" dim_head = default(dim_head, dim // heads) dim_heads = dim_head * heads @@ -580,11 +549,7 @@ def __init__( self.full_attn_thres = default(full_attn_thres, bucket_size) self.num_mem_kv = num_mem_kv - self.mem_kv = ( - nn.Parameter(torch.randn(1, num_mem_kv, dim, requires_grad=True)) - if num_mem_kv > 0 - else None - ) + self.mem_kv = nn.Parameter(torch.randn(1, num_mem_kv, dim, requires_grad=True)) if num_mem_kv > 0 else None self.n_local_attn_heads = n_local_attn_heads self.local_attn = LocalAttention( @@ -657,16 +622,12 @@ def split_heads(v): masks["input_mask"] = mask if input_attn_mask is not None: - input_attn_mask = merge_batch_and_heads( - expand_dim(1, lsh_h, input_attn_mask) - ) + input_attn_mask = merge_batch_and_heads(expand_dim(1, lsh_h, input_attn_mask)) masks["input_attn_mask"] = input_attn_mask attn_fn = self.lsh_attn if not use_full_attn else self.full_attn partial_attn_fn = partial(attn_fn, query_len=t, pos_emb=pos_emb, **kwargs) - attn_fn_in_chunks = process_inputs_chunk( - partial_attn_fn, chunks=self.attn_chunks - ) + attn_fn_in_chunks = process_inputs_chunk(partial_attn_fn, chunks=self.attn_chunks) out, attn, buckets = attn_fn_in_chunks(qk, v, **masks) diff --git a/pypots/nn/modules/revin/layers.py b/pypots/nn/modules/revin/layers.py index 21719830..9264a485 100644 --- a/pypots/nn/modules/revin/layers.py +++ b/pypots/nn/modules/revin/layers.py @@ -60,14 +60,10 @@ def _normalize(self, x, missing_mask=None): if missing_mask is None: # original implementation mean = torch.mean(x, dim=dim2reduce, keepdim=True) - stdev = torch.sqrt( - torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps - ) + stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps) else: # pypots implementation for POTS data - missing_sum = ( - torch.sum(missing_mask == 1, dim=dim2reduce, keepdim=True) + self.eps - ) + missing_sum = torch.sum(missing_mask == 1, dim=dim2reduce, keepdim=True) + self.eps mean = torch.sum(x, dim=dim2reduce, keepdim=True) / missing_sum x_enc = x.masked_fill(missing_mask == 0, 0) variance = torch.sum(x_enc * x_enc, dim=dim2reduce, keepdim=True) + self.eps diff --git a/pypots/nn/modules/saits/backbone.py b/pypots/nn/modules/saits/backbone.py index 0b0911c6..592b2d45 100644 --- a/pypots/nn/modules/saits/backbone.py +++ b/pypots/nn/modules/saits/backbone.py @@ -90,14 +90,10 @@ def __init__( # for delta decay factor self.weight_combine = nn.Linear(n_features + n_steps, n_features) - def forward( - self, X, missing_mask, attn_mask: Optional = None - ) -> Tuple[torch.Tensor, ...]: + def forward(self, X, missing_mask, attn_mask: Optional = None) -> Tuple[torch.Tensor, ...]: # first DMSA block - enc_output = self.embedding_1( - X, missing_mask - ) # namely, term e in the math equation + enc_output = self.embedding_1(X, missing_mask) # namely, term e in the math equation first_DMSA_attn_weights = None for encoder_layer in self.layer_stack_for_first_block: enc_output, first_DMSA_attn_weights = encoder_layer(enc_output, attn_mask) @@ -105,9 +101,7 @@ def forward( X_prime = missing_mask * X + (1 - missing_mask) * X_tilde_1 # second DMSA block - enc_output = self.embedding_2( - X_prime, missing_mask - ) # namely term alpha in math algo + enc_output = self.embedding_2(X_prime, missing_mask) # namely term alpha in math algo second_DMSA_attn_weights = None for encoder_layer in self.layer_stack_for_second_block: enc_output, second_DMSA_attn_weights = encoder_layer(enc_output, attn_mask) @@ -115,9 +109,7 @@ def forward( # attention-weighted combine copy_second_DMSA_weights = second_DMSA_attn_weights.clone() - copy_second_DMSA_weights = copy_second_DMSA_weights.squeeze( - dim=1 - ) # namely term A_hat in Eq. + copy_second_DMSA_weights = copy_second_DMSA_weights.squeeze(dim=1) # namely term A_hat in Eq. if len(copy_second_DMSA_weights.shape) == 4: # if having more than 1 head, then average attention weights from all heads copy_second_DMSA_weights = torch.transpose(copy_second_DMSA_weights, 1, 3) @@ -126,9 +118,7 @@ def forward( # namely term eta combining_weights = torch.sigmoid( - self.weight_combine( - torch.cat([missing_mask, copy_second_DMSA_weights], dim=2) - ) + self.weight_combine(torch.cat([missing_mask, copy_second_DMSA_weights], dim=2)) ) # combine X_tilde_1 and X_tilde_2 X_tilde_3 = (1 - combining_weights) * X_tilde_2 + combining_weights * X_tilde_1 diff --git a/pypots/nn/modules/saits/embedding.py b/pypots/nn/modules/saits/embedding.py index 51f97b05..53385af9 100644 --- a/pypots/nn/modules/saits/embedding.py +++ b/pypots/nn/modules/saits/embedding.py @@ -47,9 +47,7 @@ def __init__( self.dropout_rate = dropout self.embedding_layer = nn.Linear(d_in, d_out) - self.position_enc = ( - PositionalEncoding(d_out, n_positions=n_max_steps) if with_pos else None - ) + self.position_enc = PositionalEncoding(d_out, n_positions=n_max_steps) if with_pos else None self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None def forward(self, X, missing_mask=None): diff --git a/pypots/nn/modules/saits/loss.py b/pypots/nn/modules/saits/loss.py index d7bcc786..0052dce2 100644 --- a/pypots/nn/modules/saits/loss.py +++ b/pypots/nn/modules/saits/loss.py @@ -27,13 +27,9 @@ def __init__( def forward(self, reconstruction, X_ori, missing_mask, indicating_mask): # calculate loss for the observed reconstruction task (ORT) - ORT_loss = self.ORT_weight * self.loss_calc_func( - reconstruction, X_ori, missing_mask - ) + ORT_loss = self.ORT_weight * self.loss_calc_func(reconstruction, X_ori, missing_mask) # calculate loss for the masked imputation task (MIT) - MIT_loss = self.MIT_weight * self.loss_calc_func( - reconstruction, X_ori, indicating_mask - ) + MIT_loss = self.MIT_weight * self.loss_calc_func(reconstruction, X_ori, indicating_mask) # calculate the loss to back propagate for model updating loss = ORT_loss + MIT_loss return loss, ORT_loss, MIT_loss diff --git a/pypots/nn/modules/scinet/backbone.py b/pypots/nn/modules/scinet/backbone.py index 8b80b931..06423b1b 100644 --- a/pypots/nn/modules/scinet/backbone.py +++ b/pypots/nn/modules/scinet/backbone.py @@ -85,9 +85,7 @@ def __init__( m.bias.data.zero_() elif isinstance(m, nn.Linear): m.bias.data.zero_() - self.projection1 = nn.Conv1d( - self.n_in_steps, self.n_out_steps, kernel_size=1, stride=1, bias=False - ) + self.projection1 = nn.Conv1d(self.n_in_steps, self.n_out_steps, kernel_size=1, stride=1, bias=False) self.div_projection = nn.ModuleList() self.overlap_len = self.n_in_steps // 4 self.div_len = self.n_in_steps // 6 @@ -97,23 +95,16 @@ def __init__( for layer_idx in range(self.n_decoder_layers - 1): div_projection = nn.ModuleList() for i in range(6): - lens = ( - min(i * self.div_len + self.overlap_len, self.n_in_steps) - - i * self.div_len - ) + lens = min(i * self.div_len + self.overlap_len, self.n_in_steps) - i * self.div_len div_projection.append(nn.Linear(lens, self.div_len)) self.div_projection.append(div_projection) if self.single_step_output_One: # only output the N_th timestep. if self.stacks == 2: if self.concat_len: - self.projection2 = nn.Conv1d( - self.concat_len + self.n_out_steps, 1, kernel_size=1, bias=False - ) + self.projection2 = nn.Conv1d(self.concat_len + self.n_out_steps, 1, kernel_size=1, bias=False) else: - self.projection2 = nn.Conv1d( - self.n_in_steps + self.n_out_steps, 1, kernel_size=1, bias=False - ) + self.projection2 = nn.Conv1d(self.n_in_steps + self.n_out_steps, 1, kernel_size=1, bias=False) else: # output the N timesteps. if self.stacks == 2: if self.concat_len: @@ -140,9 +131,7 @@ def __init__( max_timescale = 10000.0 min_timescale = 1.0 - log_timescale_increment = math.log( - float(max_timescale) / float(min_timescale) - ) / max(num_timescales - 1, 1) + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1) # temp = torch.arange(num_timescales, dtype=torch.float32) inv_timescales = min_timescale * torch.exp( torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment @@ -157,9 +146,7 @@ def get_position_encoding(self, x): # temp1 = position.unsqueeze(1) # 5 1 # temp2 = self.inv_timescales.unsqueeze(0) # 1 256 scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0) # 5 256 - signal = torch.cat( - [torch.sin(scaled_time), torch.cos(scaled_time)], dim=1 - ) # [T, C] + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) # [T, C] signal = F.pad(signal, (0, 0, 0, self.pe_hidden_size % 2)) signal = signal.view(1, max_length, self.pe_hidden_size) @@ -187,14 +174,9 @@ def forward(self, x): div_x = x[ :, :, - i - * self.div_len : min( - i * self.div_len + self.overlap_len, self.n_in_steps - ), + i * self.div_len : min(i * self.div_len + self.overlap_len, self.n_in_steps), ] - output[:, :, i * self.div_len : (i + 1) * self.div_len] = div_layer( - div_x - ) + output[:, :, i * self.div_len : (i + 1) * self.div_len] = div_layer(div_x) x = output x = self.projection1(x) x = x.permute(0, 2, 1) diff --git a/pypots/nn/modules/scinet/layers.py b/pypots/nn/modules/scinet/layers.py index a3bea256..058b7445 100644 --- a/pypots/nn/modules/scinet/layers.py +++ b/pypots/nn/modules/scinet/layers.py @@ -43,15 +43,11 @@ def __init__( self.hidden_size = hidden_size self.groups = groups if self.kernel_size % 2 == 0: - pad_l = ( - self.dilation * (self.kernel_size - 2) // 2 + 1 - ) # by default: stride==1 + pad_l = self.dilation * (self.kernel_size - 2) // 2 + 1 # by default: stride==1 pad_r = self.dilation * (self.kernel_size) // 2 + 1 # by default: stride==1 else: - pad_l = ( - self.dilation * (self.kernel_size - 1) // 2 + 1 - ) # we fix the kernel size of the second layer as 3. + pad_l = self.dilation * (self.kernel_size - 1) // 2 + 1 # we fix the kernel size of the second layer as 3. pad_r = self.dilation * (self.kernel_size - 1) // 2 + 1 self.splitting = splitting self.split = Splitting() @@ -213,15 +209,11 @@ def __init__(self, in_planes, kernel_size, dropout, groups, hidden_size, INN): def forward(self, x): (x_even_update, x_odd_update) = self.interact(x) - return x_even_update.permute(0, 2, 1), x_odd_update.permute( - 0, 2, 1 - ) # even: B, T, D odd: B, T, D + return x_even_update.permute(0, 2, 1), x_odd_update.permute(0, 2, 1) # even: B, T, D odd: B, T, D class SCINet_Tree(nn.Module): - def __init__( - self, in_planes, current_level, kernel_size, dropout, groups, hidden_size, INN - ): + def __init__(self, in_planes, current_level, kernel_size, dropout, groups, hidden_size, INN): super().__init__() self.current_level = current_level @@ -275,15 +267,11 @@ def forward(self, x): if self.current_level == 0: return self.zip_up_the_pants(x_even_update, x_odd_update) else: - return self.zip_up_the_pants( - self.SCINet_Tree_even(x_even_update), self.SCINet_Tree_odd(x_odd_update) - ) + return self.zip_up_the_pants(self.SCINet_Tree_even(x_even_update), self.SCINet_Tree_odd(x_odd_update)) class EncoderTree(nn.Module): - def __init__( - self, in_planes, num_levels, kernel_size, dropout, groups, hidden_size, INN - ): + def __init__(self, in_planes, num_levels, kernel_size, dropout, groups, hidden_size, INN): super().__init__() self.levels = num_levels self.SCINet_Tree = SCINet_Tree( diff --git a/pypots/nn/modules/stemgnn/backbone.py b/pypots/nn/modules/stemgnn/backbone.py index 9a9a1b2a..83899580 100644 --- a/pypots/nn/modules/stemgnn/backbone.py +++ b/pypots/nn/modules/stemgnn/backbone.py @@ -39,12 +39,7 @@ def __init__( self.multi_layer = multi_layer self.stock_block = nn.ModuleList() self.stock_block.extend( - [ - StockBlockLayer( - self.time_step, self.unit, self.multi_layer, stack_cnt=i - ) - for i in range(self.stack_cnt) - ] + [StockBlockLayer(self.time_step, self.unit, self.multi_layer, stack_cnt=i) for i in range(self.stack_cnt)] ) self.fc = nn.Sequential( nn.Linear(int(self.time_step), int(self.time_step)), @@ -64,9 +59,7 @@ def get_laplacian(graph, normalize): """ if normalize: D = torch.diag(torch.sum(graph, dim=-1) ** (-1 / 2)) - L = torch.eye( - graph.size(0), device=graph.device, dtype=graph.dtype - ) - torch.mm(torch.mm(D, graph), D) + L = torch.eye(graph.size(0), device=graph.device, dtype=graph.dtype) - torch.mm(torch.mm(D, graph), D) else: D = torch.diag(torch.sum(graph, dim=-1)) L = D - graph @@ -81,19 +74,11 @@ def cheb_polynomial(laplacian): """ N = laplacian.size(0) # [N, N] laplacian = laplacian.unsqueeze(0) - first_laplacian = torch.zeros( - [1, N, N], device=laplacian.device, dtype=torch.float - ) + first_laplacian = torch.zeros([1, N, N], device=laplacian.device, dtype=torch.float) second_laplacian = laplacian - third_laplacian = ( - 2 * torch.matmul(laplacian, second_laplacian) - ) - first_laplacian - forth_laplacian = ( - 2 * torch.matmul(laplacian, third_laplacian) - second_laplacian - ) - multi_order_laplacian = torch.cat( - [first_laplacian, second_laplacian, third_laplacian, forth_laplacian], dim=0 - ) + third_laplacian = (2 * torch.matmul(laplacian, second_laplacian)) - first_laplacian + forth_laplacian = 2 * torch.matmul(laplacian, third_laplacian) - second_laplacian + multi_order_laplacian = torch.cat([first_laplacian, second_laplacian, third_laplacian, forth_laplacian], dim=0) return multi_order_laplacian def latent_correlation_layer(self, x): @@ -106,9 +91,7 @@ def latent_correlation_layer(self, x): attention = 0.5 * (attention + attention.T) degree_l = torch.diag(degree) diagonal_degree_hat = torch.diag(1 / (torch.sqrt(degree) + 1e-7)) - laplacian = torch.matmul( - diagonal_degree_hat, torch.matmul(degree_l - attention, diagonal_degree_hat) - ) + laplacian = torch.matmul(diagonal_degree_hat, torch.matmul(degree_l - attention, diagonal_degree_hat)) mul_L = self.cheb_polynomial(laplacian) return mul_L, attention diff --git a/pypots/nn/modules/stemgnn/layers.py b/pypots/nn/modules/stemgnn/layers.py index 7144a70d..8eae0da6 100644 --- a/pypots/nn/modules/stemgnn/layers.py +++ b/pypots/nn/modules/stemgnn/layers.py @@ -27,14 +27,10 @@ def __init__(self, time_step, unit, multi_layer, stack_cnt=0): self.stack_cnt = stack_cnt self.multi = multi_layer self.weight = nn.Parameter( - torch.Tensor( - 1, 3 + 1, 1, self.time_step * self.multi, self.multi * self.time_step - ) + torch.Tensor(1, 3 + 1, 1, self.time_step * self.multi, self.multi * self.time_step) ) # [K+1, 1, in_c, out_c] nn.init.xavier_normal_(self.weight) - self.forecast = nn.Linear( - self.time_step * self.multi, self.time_step * self.multi - ) + self.forecast = nn.Linear(self.time_step * self.multi, self.time_step * self.multi) self.forecast_result = nn.Linear(self.time_step * self.multi, self.time_step) if self.stack_cnt == 0: self.backcast = nn.Linear(self.time_step * self.multi, self.time_step) @@ -44,12 +40,8 @@ def __init__(self, time_step, unit, multi_layer, stack_cnt=0): self.output_channel = 4 * self.multi for i in range(3): if i == 0: - self.GLUs.append( - GLU(self.time_step * 4, self.time_step * self.output_channel) - ) - self.GLUs.append( - GLU(self.time_step * 4, self.time_step * self.output_channel) - ) + self.GLUs.append(GLU(self.time_step * 4, self.time_step * self.output_channel)) + self.GLUs.append(GLU(self.time_step * 4, self.time_step * self.output_channel)) elif i == 1: self.GLUs.append( GLU( @@ -81,27 +73,13 @@ def spe_seq_cell(self, input): batch_size, k, input_channel, node_cnt, time_step = input.size() input = input.view(batch_size, -1, node_cnt, time_step) # ffted = torch.fft.rfft(input, 1, onesided=False) # original old version, onesided doesn't work in new torch - ffted = torch.view_as_real( - torch.fft.fft(input, dim=1) - ) # WDU: replace the above line with this line - real = ( - ffted[..., 0] - .permute(0, 2, 1, 3) - .contiguous() - .reshape(batch_size, node_cnt, -1) - ) - img = ( - ffted[..., 1] - .permute(0, 2, 1, 3) - .contiguous() - .reshape(batch_size, node_cnt, -1) - ) + ffted = torch.view_as_real(torch.fft.fft(input, dim=1)) # WDU: replace the above line with this line + real = ffted[..., 0].permute(0, 2, 1, 3).contiguous().reshape(batch_size, node_cnt, -1) + img = ffted[..., 1].permute(0, 2, 1, 3).contiguous().reshape(batch_size, node_cnt, -1) for i in range(3): real = self.GLUs[i * 2](real) img = self.GLUs[2 * i + 1](img) - real = ( - real.reshape(batch_size, node_cnt, 4, -1).permute(0, 2, 1, 3).contiguous() - ) + real = real.reshape(batch_size, node_cnt, 4, -1).permute(0, 2, 1, 3).contiguous() img = img.reshape(batch_size, node_cnt, 4, -1).permute(0, 2, 1, 3).contiguous() time_step_as_inner = torch.cat([real.unsqueeze(-1), img.unsqueeze(-1)], dim=-1) # iffted = torch.fft.irfft(time_step_as_inner, 1, onesided=False) # onesided doesn't work in new torch diff --git a/pypots/nn/modules/tcn/layers.py b/pypots/nn/modules/tcn/layers.py index be640396..3233ad61 100644 --- a/pypots/nn/modules/tcn/layers.py +++ b/pypots/nn/modules/tcn/layers.py @@ -69,9 +69,7 @@ def __init__( self.relu2, self.dropout2, ) - self.downsample = ( - nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None - ) + self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None self.relu = nn.ReLU() self.init_weights() diff --git a/pypots/nn/modules/tide/autoencoder.py b/pypots/nn/modules/tide/autoencoder.py index dfc7efb7..3fe80eb9 100644 --- a/pypots/nn/modules/tide/autoencoder.py +++ b/pypots/nn/modules/tide/autoencoder.py @@ -67,9 +67,7 @@ def forward(self, X, dynamic): enc_in = torch.cat([X.reshape(bz, -1), feature.reshape(bz, -1)], dim=-1) hidden = self.encoder(enc_in) - decoded = self.decoder(hidden).reshape( - hidden.shape[0], self.n_steps, self.n_features - ) + decoded = self.decoder(hidden).reshape(hidden.shape[0], self.n_steps, self.n_features) temporal_decoder_input = torch.cat([feature, decoded], dim=-1) prediction = self.temporal_decoder(temporal_decoder_input) prediction += self.residual_proj(X) @@ -96,10 +94,7 @@ def __init__( self.encoder_layers = nn.Sequential( ResBlock(d_flatten, self.res_hidden, self.d_hidden, dropout), - *( - [ResBlock(self.d_hidden, self.res_hidden, self.d_hidden, dropout)] - * (self.n_layers - 1) - ), + *([ResBlock(self.d_hidden, self.res_hidden, self.d_hidden, dropout)] * (self.n_layers - 1)), ) def forward(self, X): @@ -146,7 +141,5 @@ def forward( self, X, ): - dec_out = self.decoder_layers(X).reshape( - X.shape[0], self.n_pred_steps, self.n_pred_features - ) + dec_out = self.decoder_layers(X).reshape(X.shape[0], self.n_pred_steps, self.n_pred_features) return dec_out diff --git a/pypots/nn/modules/timemixer/backbone.py b/pypots/nn/modules/timemixer/backbone.py index ad9238c2..1b134437 100644 --- a/pypots/nn/modules/timemixer/backbone.py +++ b/pypots/nn/modules/timemixer/backbone.py @@ -73,17 +73,11 @@ def __init__( self.preprocess = SeriesDecompositionBlock(moving_avg) if self.channel_independence == 1: - self.enc_embedding = DataEmbedding( - 1, d_model, embed, freq, dropout, with_pos=False - ) + self.enc_embedding = DataEmbedding(1, d_model, embed, freq, dropout, with_pos=False) else: - self.enc_embedding = DataEmbedding( - n_features, d_model, embed, freq, dropout, with_pos=False - ) + self.enc_embedding = DataEmbedding(n_features, d_model, embed, freq, dropout, with_pos=False) - self.normalize_layers = torch.nn.ModuleList( - [RevIN(n_features) for _ in range(downsampling_layers + 1)] - ) + self.normalize_layers = torch.nn.ModuleList([RevIN(n_features) for _ in range(downsampling_layers + 1)]) if task_name == "long_term_forecast" or task_name == "short_term_forecast": self.predict_layers = torch.nn.ModuleList( @@ -152,9 +146,7 @@ def pre_enc(self, x_list): def __multi_scale_process_inputs(self, x_enc, x_mark_enc): if self.downsampling_method == "max": - down_pool = torch.nn.MaxPool1d( - self.downsampling_window, return_indices=False - ) + down_pool = torch.nn.MaxPool1d(self.downsampling_window, return_indices=False) elif self.downsampling_method == "avg": down_pool = torch.nn.AvgPool1d(self.downsampling_window) elif self.downsampling_method == "conv": @@ -188,12 +180,8 @@ def __multi_scale_process_inputs(self, x_enc, x_mark_enc): x_enc_ori = x_enc_sampling if x_mark_enc_mark_ori is not None: - x_mark_sampling_list.append( - x_mark_enc_mark_ori[:, :: self.downsampling_window, :] - ) - x_mark_enc_mark_ori = x_mark_enc_mark_ori[ - :, :: self.downsampling_window, : - ] + x_mark_sampling_list.append(x_mark_enc_mark_ori[:, :: self.downsampling_window, :]) + x_mark_enc_mark_ori = x_mark_enc_mark_ori[:, :: self.downsampling_window, :] x_enc = x_enc_sampling_list if x_mark_enc_mark_ori is not None: @@ -264,28 +252,18 @@ def future_multi_mixing(self, B, enc_out_list, x_list): if self.channel_independence == 1: x_list = x_list[0] for i, enc_out in zip(range(len(x_list)), enc_out_list): - dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute( - 0, 2, 1 - ) # align temporal dimension + dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(0, 2, 1) # align temporal dimension if self.use_future_temporal_feature: dec_out = dec_out + self.x_mark_dec dec_out = self.projection_layer(dec_out) else: dec_out = self.projection_layer(dec_out) - dec_out = ( - dec_out.reshape(B, self.c_out, self.n_pred_steps) - .permute(0, 2, 1) - .contiguous() - ) + dec_out = dec_out.reshape(B, self.c_out, self.n_pred_steps).permute(0, 2, 1).contiguous() dec_out_list.append(dec_out) else: - for i, enc_out, out_res in zip( - range(len(x_list[0])), enc_out_list, x_list[1] - ): - dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute( - 0, 2, 1 - ) # align temporal dimension + for i, enc_out, out_res in zip(range(len(x_list[0])), enc_out_list, x_list[1]): + dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(0, 2, 1) # align temporal dimension dec_out = self.out_projection(dec_out, i, out_res) dec_out_list.append(dec_out) @@ -385,8 +363,6 @@ def imputation(self, x_enc, x_mark_enc): enc_out_list = self.pdm_blocks[i](enc_out_list) dec_out = self.projection_layer(enc_out_list[0]) - dec_out = ( - dec_out.reshape(B, self.n_pred_features, -1).permute(0, 2, 1).contiguous() - ) + dec_out = dec_out.reshape(B, self.n_pred_features, -1).permute(0, 2, 1).contiguous() return dec_out diff --git a/pypots/nn/modules/timemixer/layers.py b/pypots/nn/modules/timemixer/layers.py index 9f6e4d5e..6306acc8 100644 --- a/pypots/nn/modules/timemixer/layers.py +++ b/pypots/nn/modules/timemixer/layers.py @@ -211,9 +211,7 @@ def forward(self, x_list): out_trend_list = self.mixing_multi_scale_trend(trend_list) out_list = [] - for ori, out_season, out_trend, length in zip( - x_list, out_season_list, out_trend_list, length_list - ): + for ori, out_season, out_trend, length in zip(x_list, out_season_list, out_trend_list, length_list): out = out_season + out_trend if self.channel_independence: out = ori + self.out_cross_layer(out) diff --git a/pypots/nn/modules/timesnet/backbone.py b/pypots/nn/modules/timesnet/backbone.py index 2f591e8e..5eb6ec04 100644 --- a/pypots/nn/modules/timesnet/backbone.py +++ b/pypots/nn/modules/timesnet/backbone.py @@ -1,6 +1,7 @@ """ """ + import torch import torch.nn as nn @@ -29,10 +30,7 @@ def __init__( self.n_pred_steps = n_pred_steps self.model = nn.ModuleList( - [ - TimesBlock(n_steps, n_pred_steps, top_k, d_model, d_ffn, n_kernels) - for _ in range(n_layers) - ] + [TimesBlock(n_steps, n_pred_steps, top_k, d_model, d_ffn, n_kernels) for _ in range(n_layers)] ) self.layer_norm = nn.LayerNorm(d_model) diff --git a/pypots/nn/modules/timesnet/layers.py b/pypots/nn/modules/timesnet/layers.py index a1130910..3fa46432 100644 --- a/pypots/nn/modules/timesnet/layers.py +++ b/pypots/nn/modules/timesnet/layers.py @@ -31,9 +31,7 @@ def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True): self.num_kernels = num_kernels kernels = [] for i in range(self.num_kernels): - kernels.append( - nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i) - ) + kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i)) self.kernels = nn.ModuleList(kernels) if init_weight: self._initialize_weights() @@ -77,19 +75,13 @@ def forward(self, x): # padding if (self.seq_len + self.pred_len) % period != 0: length = (((self.seq_len + self.pred_len) // period) + 1) * period - padding = torch.zeros( - [x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]] - ).to(x.device) + padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device) out = torch.cat([x, padding], dim=1) else: length = self.seq_len + self.pred_len out = x # reshape - out = ( - out.reshape(B, length // period, period, N) - .permute(0, 3, 1, 2) - .contiguous() - ) + out = out.reshape(B, length // period, period, N).permute(0, 3, 1, 2).contiguous() # 2D conv: from 1d Variation to 2d Variation out = self.conv(out) # reshape back diff --git a/pypots/nn/modules/transformer/embedding.py b/pypots/nn/modules/transformer/embedding.py index d021210e..64572064 100644 --- a/pypots/nn/modules/transformer/embedding.py +++ b/pypots/nn/modules/transformer/embedding.py @@ -33,10 +33,7 @@ def __init__(self, d_hid: int, n_positions: int = 1000): super().__init__() pe = torch.zeros(n_positions, d_hid, requires_grad=False).float() position = torch.arange(0, n_positions).float().unsqueeze(1) - div_term = ( - torch.arange(0, d_hid, 2).float() - * -(torch.log(torch.tensor(10000)) / d_hid) - ).exp() + div_term = (torch.arange(0, d_hid, 2).float() * -(torch.log(torch.tensor(10000)) / d_hid)).exp() pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) @@ -87,9 +84,7 @@ def __init__(self, c_in, d_model): ) for m in self.modules(): if isinstance(m, nn.Conv1d): - nn.init.kaiming_normal_( - m.weight, mode="fan_in", nonlinearity="leaky_relu" - ) + nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="leaky_relu") def forward(self, x): x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) @@ -104,9 +99,7 @@ def __init__(self, c_in, d_model): w.require_grad = False position = torch.arange(0, c_in).float().unsqueeze(1) - div_term = ( - torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) - ).exp() + div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() w[:, 0::2] = torch.sin(position * div_term) w[:, 1::2] = torch.cos(position * div_term) @@ -138,9 +131,7 @@ def __init__(self, d_model, embed_type="fixed", freq="h"): def forward(self, x): x = x.long() - minute_x = ( - self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0 - ) + minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0 hour_x = self.hour_embed(x[:, :, 3]) weekday_x = self.weekday_embed(x[:, :, 2]) day_x = self.day_embed(x[:, :, 1]) @@ -178,9 +169,7 @@ def __init__( self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) if with_pos: - self.position_embedding = PositionalEncoding( - d_hid=d_model, n_positions=n_max_steps - ) + self.position_embedding = PositionalEncoding(d_hid=d_model, n_positions=n_max_steps) self.temporal_embedding = ( TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) if embed_type != "timeF" diff --git a/pypots/nn/modules/usgan/backbone.py b/pypots/nn/modules/usgan/backbone.py index 42d7f430..e0b95106 100644 --- a/pypots/nn/modules/usgan/backbone.py +++ b/pypots/nn/modules/usgan/backbone.py @@ -61,12 +61,8 @@ def forward( forward_missing_mask = inputs["forward"]["missing_mask"] if training_object == "discriminator": - discrimination = self.discriminator( - imputed_data.detach(), forward_missing_mask - ) - l_D = F.binary_cross_entropy_with_logits( - discrimination, forward_missing_mask - ) + discrimination = self.discriminator(imputed_data.detach(), forward_missing_mask) + l_D = F.binary_cross_entropy_with_logits(discrimination, forward_missing_mask) discrimination_loss = l_D return imputed_data, discrimination_loss else: @@ -77,9 +73,9 @@ def forward( weight=1 - forward_missing_mask, ) reconstruction = (f_reconstruction + b_reconstruction) / 2 - reconstruction_loss = calc_mse( - forward_X, reconstruction, forward_missing_mask - ) + 0.1 * calc_mse(f_reconstruction, b_reconstruction) + reconstruction_loss = calc_mse(forward_X, reconstruction, forward_missing_mask) + 0.1 * calc_mse( + f_reconstruction, b_reconstruction + ) loss_gene = l_G + self.lambda_mse * reconstruction_loss generation_loss = loss_gene return imputed_data, generation_loss diff --git a/pypots/nn/modules/usgan/layers.py b/pypots/nn/modules/usgan/layers.py index 675a3e58..6ead8fc4 100644 --- a/pypots/nn/modules/usgan/layers.py +++ b/pypots/nn/modules/usgan/layers.py @@ -40,9 +40,7 @@ def __init__( ): super().__init__() self.hint_rate = hint_rate - self.biRNN = nn.GRU( - n_features * 2, rnn_hidden_size, bidirectional=True, batch_first=True - ) + self.biRNN = nn.GRU(n_features * 2, rnn_hidden_size, bidirectional=True, batch_first=True) self.dropout = nn.Dropout(dropout_rate) self.read_out = nn.Linear(rnn_hidden_size * 2, n_features) @@ -69,10 +67,7 @@ def forward( """ device = imputed_X.device - hint = ( - torch.rand_like(missing_mask, dtype=torch.float, device=device) - < self.hint_rate - ) + hint = torch.rand_like(missing_mask, dtype=torch.float, device=device) < self.hint_rate hint = hint.int() h = hint * missing_mask + (1 - hint) * 0.5 x_in = torch.cat([imputed_X, h], dim=-1) diff --git a/pypots/nn/modules/vader/backbone.py b/pypots/nn/modules/vader/backbone.py index 7c2d6639..0117db7c 100644 --- a/pypots/nn/modules/vader/backbone.py +++ b/pypots/nn/modules/vader/backbone.py @@ -62,16 +62,10 @@ def __init__( self.implicit_imputation_layer = ImplicitImputation(d_input) self.encoder = PeepholeLSTMCell(d_input, d_rnn_hidden) self.decoder = PeepholeLSTMCell(d_input, d_rnn_hidden) - self.ae_encode_layers = nn.Sequential( - nn.Linear(d_rnn_hidden, d_rnn_hidden), nn.Softplus() - ) - self.ae_decode_layers = nn.Sequential( - nn.Linear(d_mu_stddev, d_rnn_hidden), nn.Softplus() - ) + self.ae_encode_layers = nn.Sequential(nn.Linear(d_rnn_hidden, d_rnn_hidden), nn.Softplus()) + self.ae_decode_layers = nn.Sequential(nn.Linear(d_mu_stddev, d_rnn_hidden), nn.Softplus()) self.mu_layer = nn.Linear(d_rnn_hidden, d_mu_stddev) # layer for mean - self.stddev_layer = nn.Linear( - d_rnn_hidden, d_mu_stddev - ) # layer for standard variance + self.stddev_layer = nn.Linear(d_rnn_hidden, d_mu_stddev) # layer for standard variance self.rnn_transform_layer = nn.Linear(d_rnn_hidden, d_input) self.gmm_layer = GMMLayer(d_mu_stddev, n_clusters) @@ -93,12 +87,8 @@ def encode( X_imputed = self.implicit_imputation_layer(X, missing_mask) - hidden_state = torch.zeros( - (batch_size, self.d_rnn_hidden), dtype=X.dtype, device=X.device - ) - cell_state = torch.zeros( - (batch_size, self.d_rnn_hidden), dtype=X.dtype, device=X.device - ) + hidden_state = torch.zeros((batch_size, self.d_rnn_hidden), dtype=X.dtype, device=X.device) + cell_state = torch.zeros((batch_size, self.d_rnn_hidden), dtype=X.dtype, device=X.device) # cell_state_collector = torch.empty((batch_size, self.n_steps, self.d_rnn_hidden), # dtype=X.dtype, device=X.device) for i in range(self.n_steps): @@ -117,9 +107,7 @@ def decode(self, z: torch.Tensor) -> torch.Tensor: hidden_state = self.ae_decode_layers(hidden_state) cell_state = torch.zeros(hidden_state.size(), dtype=z.dtype, device=z.device) - inputs = torch.zeros( - (z.size(0), self.n_steps, self.d_input), dtype=z.dtype, device=z.device - ) + inputs = torch.zeros((z.size(0), self.n_steps, self.d_input), dtype=z.dtype, device=z.device) hidden_state_collector = torch.empty( (z.size(0), self.n_steps, self.d_rnn_hidden), dtype=z.dtype, device=z.device @@ -133,16 +121,10 @@ def decode(self, z: torch.Tensor) -> torch.Tensor: return reconstruction def forward( - self, X: torch.Tensor, missing_mask: torch.Tensor - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - ]: + self, + X: torch.Tensor, + missing_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: z, mu_tilde, stddev_tilde = self.encode(X, missing_mask) X_reconstructed = self.decode(z) mu_c, var_c, phi_c = self.gmm_layer() diff --git a/pypots/nn/modules/vader/layers.py b/pypots/nn/modules/vader/layers.py index a3e53b87..50df59b1 100644 --- a/pypots/nn/modules/vader/layers.py +++ b/pypots/nn/modules/vader/layers.py @@ -50,9 +50,7 @@ def forward( hx: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: if hx is None: - zeros = torch.zeros( - X.size(0), self.hidden_size, dtype=X.dtype, device=X.device - ) + zeros = torch.zeros(X.size(0), self.hidden_size, dtype=X.dtype, device=X.device) hx = (zeros, zeros) h, c = hx diff --git a/pypots/optim/base.py b/pypots/optim/base.py index 9059ba03..b64ba012 100644 --- a/pypots/optim/base.py +++ b/pypots/optim/base.py @@ -12,6 +12,7 @@ 2). provide additional functionalities, such as learning rate scheduling, etc.; """ + # Created by Wenjie Du # License: BSD-3-Clause diff --git a/pypots/optim/lr_scheduler/constant_lrs.py b/pypots/optim/lr_scheduler/constant_lrs.py index 12123ffe..4a5cf77f 100644 --- a/pypots/optim/lr_scheduler/constant_lrs.py +++ b/pypots/optim/lr_scheduler/constant_lrs.py @@ -50,9 +50,7 @@ class ConstantLR(LRScheduler): def __init__(self, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False): super().__init__(last_epoch, verbose) if factor > 1.0 or factor < 0: - raise ValueError( - "Constant multiplicative factor expected to be between 0 and 1." - ) + raise ValueError("Constant multiplicative factor expected to be between 0 and 1.") self.factor = factor self.total_iters = total_iters @@ -60,8 +58,7 @@ def __init__(self, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False): def get_lr(self): if not self._get_lr_called_within_step: logger.warning( - "⚠️ To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", + "⚠️ To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", ) if self.last_epoch == 0: @@ -71,14 +68,10 @@ def get_lr(self): return [group["lr"] for group in self.optimizer.param_groups] if self.last_epoch == self.total_iters: - return [ - group["lr"] * (1.0 / self.factor) - for group in self.optimizer.param_groups - ] + return [group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups] def _get_closed_form_lr(self): return [ - base_lr - * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) + base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) for base_lr in self.base_lrs ] diff --git a/pypots/optim/lr_scheduler/exponential_lrs.py b/pypots/optim/lr_scheduler/exponential_lrs.py index 722b3867..416301de 100644 --- a/pypots/optim/lr_scheduler/exponential_lrs.py +++ b/pypots/optim/lr_scheduler/exponential_lrs.py @@ -43,8 +43,7 @@ def __init__(self, gamma, last_epoch=-1, verbose=False): def get_lr(self): if not self._get_lr_called_within_step: logger.warning( - "⚠️ To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", + "⚠️ To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", ) if self.last_epoch == 0: diff --git a/pypots/optim/lr_scheduler/lambda_lrs.py b/pypots/optim/lr_scheduler/lambda_lrs.py index bc0891b6..9d5e0ca3 100644 --- a/pypots/optim/lr_scheduler/lambda_lrs.py +++ b/pypots/optim/lr_scheduler/lambda_lrs.py @@ -51,16 +51,12 @@ def __init__( self.lr_lambdas = None def init_scheduler(self, optimizer): - if not isinstance(self.lr_lambda, list) and not isinstance( - self.lr_lambda, tuple - ): + if not isinstance(self.lr_lambda, list) and not isinstance(self.lr_lambda, tuple): self.lr_lambdas = [self.lr_lambda] * len(optimizer.param_groups) else: if len(self.lr_lambda) != len(optimizer.param_groups): raise ValueError( - "Expected {} lr_lambdas, but got {}".format( - len(optimizer.param_groups), len(self.lr_lambda) - ) + "Expected {} lr_lambdas, but got {}".format(len(optimizer.param_groups), len(self.lr_lambda)) ) self.lr_lambdas = list(self.lr_lambda) @@ -68,12 +64,6 @@ def init_scheduler(self, optimizer): def get_lr(self): if not self._get_lr_called_within_step: - logger.warning( - "⚠️ To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`." - ) - - return [ - base_lr * lmbda(self.last_epoch) - for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs) - ] + logger.warning("⚠️ To get the last learning rate computed by the scheduler, please use `get_last_lr()`.") + + return [base_lr * lmbda(self.last_epoch) for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] diff --git a/pypots/optim/lr_scheduler/linear_lrs.py b/pypots/optim/lr_scheduler/linear_lrs.py index a91d6693..be8b79de 100644 --- a/pypots/optim/lr_scheduler/linear_lrs.py +++ b/pypots/optim/lr_scheduler/linear_lrs.py @@ -61,14 +61,10 @@ def __init__( ): super().__init__(last_epoch, verbose) if start_factor > 1.0 or start_factor < 0: - raise ValueError( - "Starting multiplicative factor expected to be between 0 and 1." - ) + raise ValueError("Starting multiplicative factor expected to be between 0 and 1.") if end_factor > 1.0 or end_factor < 0: - raise ValueError( - "Ending multiplicative factor expected to be between 0 and 1." - ) + raise ValueError("Ending multiplicative factor expected to be between 0 and 1.") self.start_factor = start_factor self.end_factor = end_factor @@ -77,14 +73,11 @@ def __init__( def get_lr(self): if not self._get_lr_called_within_step: logger.warning( - "⚠️ To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", + "⚠️ To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", ) if self.last_epoch == 0: - return [ - group["lr"] * self.start_factor for group in self.optimizer.param_groups - ] + return [group["lr"] * self.start_factor for group in self.optimizer.param_groups] if self.last_epoch > self.total_iters: return [group["lr"] for group in self.optimizer.param_groups] @@ -94,10 +87,7 @@ def get_lr(self): * ( 1.0 + (self.end_factor - self.start_factor) - / ( - self.total_iters * self.start_factor - + (self.last_epoch - 1) * (self.end_factor - self.start_factor) - ) + / (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor)) ) for group in self.optimizer.param_groups ] @@ -107,9 +97,7 @@ def _get_closed_form_lr(self): base_lr * ( self.start_factor - + (self.end_factor - self.start_factor) - * min(self.total_iters, self.last_epoch) - / self.total_iters + + (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters ) for base_lr in self.base_lrs ] diff --git a/pypots/optim/lr_scheduler/multiplicative_lrs.py b/pypots/optim/lr_scheduler/multiplicative_lrs.py index bd753554..58500e04 100644 --- a/pypots/optim/lr_scheduler/multiplicative_lrs.py +++ b/pypots/optim/lr_scheduler/multiplicative_lrs.py @@ -46,16 +46,12 @@ def __init__(self, lr_lambda, last_epoch=-1, verbose=False): self.lr_lambdas = None def init_scheduler(self, optimizer): - if not isinstance(self.lr_lambda, list) and not isinstance( - self.lr_lambda, tuple - ): + if not isinstance(self.lr_lambda, list) and not isinstance(self.lr_lambda, tuple): self.lr_lambdas = [self.lr_lambda] * len(optimizer.param_groups) else: if len(self.lr_lambda) != len(optimizer.param_groups): raise ValueError( - "Expected {} lr_lambdas, but got {}".format( - len(optimizer.param_groups), len(self.lr_lambda) - ) + "Expected {} lr_lambdas, but got {}".format(len(optimizer.param_groups), len(self.lr_lambda)) ) self.lr_lambdas = list(self.lr_lambda) @@ -64,8 +60,7 @@ def init_scheduler(self, optimizer): def get_lr(self): if not self._get_lr_called_within_step: logger.warning( - "⚠️ To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", + "⚠️ To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", ) if self.last_epoch > 0: diff --git a/pypots/optim/lr_scheduler/multistep_lrs.py b/pypots/optim/lr_scheduler/multistep_lrs.py index 4a841172..7c06871c 100644 --- a/pypots/optim/lr_scheduler/multistep_lrs.py +++ b/pypots/optim/lr_scheduler/multistep_lrs.py @@ -56,20 +56,13 @@ def __init__(self, milestones, gamma=0.1, last_epoch=-1, verbose=False): def get_lr(self): if not self._get_lr_called_within_step: logger.warning( - "⚠️ To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", + "⚠️ To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", ) if self.last_epoch not in self.milestones: return [group["lr"] for group in self.optimizer.param_groups] - return [ - group["lr"] * self.gamma ** self.milestones[self.last_epoch] - for group in self.optimizer.param_groups - ] + return [group["lr"] * self.gamma ** self.milestones[self.last_epoch] for group in self.optimizer.param_groups] def _get_closed_form_lr(self): milestones = list(sorted(self.milestones.elements())) - return [ - base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) - for base_lr in self.base_lrs - ] + return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) for base_lr in self.base_lrs] diff --git a/pypots/optim/lr_scheduler/step_lrs.py b/pypots/optim/lr_scheduler/step_lrs.py index b1a9a440..2f469b81 100644 --- a/pypots/optim/lr_scheduler/step_lrs.py +++ b/pypots/optim/lr_scheduler/step_lrs.py @@ -55,8 +55,7 @@ def __init__(self, step_size, gamma=0.1, last_epoch=-1, verbose=False): def get_lr(self): if not self._get_lr_called_within_step: logger.warning( - "⚠️ To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", + "⚠️ To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", ) if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): @@ -64,7 +63,4 @@ def get_lr(self): return [group["lr"] * self.gamma for group in self.optimizer.param_groups] def _get_closed_form_lr(self): - return [ - base_lr * self.gamma ** (self.last_epoch // self.step_size) - for base_lr in self.base_lrs - ] + return [base_lr * self.gamma ** (self.last_epoch // self.step_size) for base_lr in self.base_lrs] diff --git a/pypots/utils/metrics/classification.py b/pypots/utils/metrics/classification.py index 3cc2af8d..5653a378 100644 --- a/pypots/utils/metrics/classification.py +++ b/pypots/utils/metrics/classification.py @@ -68,12 +68,8 @@ def calc_binary_classification_metrics( else: raise f"targets dimensions should be 1 or 2, but got targets.shape: {targets.shape}" - if len(prob_predictions.shape) == 1 or ( - len(prob_predictions.shape) == 2 and prob_predictions.shape[1] == 1 - ): - prob_predictions = np.asarray( - prob_predictions - ).flatten() # turn the array shape into [n_samples] + if len(prob_predictions.shape) == 1 or (len(prob_predictions.shape) == 2 and prob_predictions.shape[1] == 1): + prob_predictions = np.asarray(prob_predictions).flatten() # turn the array shape into [n_samples] binary_predictions = prob_predictions prediction_categories = (prob_predictions >= 0.5).astype(int) binary_prediction_categories = prediction_categories @@ -93,12 +89,8 @@ def calc_binary_classification_metrics( binary_targets = np.copy(targets) binary_targets[~mask] = mask_val - precision, recall, f1 = calc_precision_recall_f1( - binary_prediction_categories, binary_targets, pos_label - ) - pr_auc, precisions, recalls, _ = calc_pr_auc( - binary_predictions, binary_targets, pos_label - ) + precision, recall, f1 = calc_precision_recall_f1(binary_prediction_categories, binary_targets, pos_label) + pr_auc, precisions, recalls, _ = calc_pr_auc(binary_predictions, binary_targets, pos_label) ROC_AUC, fprs, tprs, _ = calc_roc_auc(binary_predictions, binary_targets, pos_label) PR_AUC = metrics.auc(recalls, precisions) classification_metrics = { @@ -147,9 +139,7 @@ def calc_precision_recall_f1( The F1 score of model predictions. """ - precision, recall, f1, _ = metrics.precision_recall_fscore_support( - targets, prob_predictions, pos_label=pos_label - ) + precision, recall, f1, _ = metrics.precision_recall_fscore_support(targets, prob_predictions, pos_label=pos_label) precision, recall, f1 = precision[pos_label], recall[pos_label], f1[pos_label] return precision, recall, f1 @@ -188,9 +178,7 @@ def calc_pr_auc( """ - precisions, recalls, thresholds = metrics.precision_recall_curve( - targets, prob_predictions, pos_label=pos_label - ) + precisions, recalls, thresholds = metrics.precision_recall_curve(targets, prob_predictions, pos_label=pos_label) pr_auc = metrics.auc(recalls, precisions) return pr_auc, precisions, recalls, thresholds @@ -228,9 +216,7 @@ def calc_roc_auc( Increasing thresholds on the decision function used to compute FPR and TPR. """ - fprs, tprs, thresholds = metrics.roc_curve( - y_true=targets, y_score=prob_predictions, pos_label=pos_label - ) + fprs, tprs, thresholds = metrics.roc_curve(y_true=targets, y_score=prob_predictions, pos_label=pos_label) roc_auc = metrics.auc(fprs, tprs) return roc_auc, fprs, tprs, thresholds diff --git a/pypots/utils/metrics/clustering.py b/pypots/utils/metrics/clustering.py index 87a18cbf..bcff0de3 100644 --- a/pypots/utils/metrics/clustering.py +++ b/pypots/utils/metrics/clustering.py @@ -146,9 +146,7 @@ def calc_cluster_purity( """ contingency_matrix = metrics.cluster.contingency_matrix(targets, class_predictions) - cluster_purity = np.sum(np.amax(contingency_matrix, axis=0)) / np.sum( - contingency_matrix - ) + cluster_purity = np.sum(np.amax(contingency_matrix, axis=0)) / np.sum(contingency_matrix) return cluster_purity @@ -271,9 +269,7 @@ def calc_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: return davies_bouldin_score -def calc_internal_cluster_validation_metrics( - X: np.ndarray, predicted_labels: np.ndarray -) -> dict: +def calc_internal_cluster_validation_metrics(X: np.ndarray, predicted_labels: np.ndarray) -> dict: """Computer all internal cluster validation metrics available in PyPOTS and return as a dictionary. Parameters diff --git a/pypots/utils/metrics/error.py b/pypots/utils/metrics/error.py index d5e105d5..251eea1e 100644 --- a/pypots/utils/metrics/error.py +++ b/pypots/utils/metrics/error.py @@ -31,12 +31,8 @@ def _check_inputs( prediction_shape == target_shape ), f"shape of `predictions` and `targets` must match, but got {prediction_shape} and {target_shape}" # check NaN - assert not lib.isnan( - predictions - ).any(), "`predictions` mustn't contain NaN values, but detected NaN in it" - assert not lib.isnan( - targets - ).any(), "`targets` mustn't contain NaN values, but detected NaN in it" + assert not lib.isnan(predictions).any(), "`predictions` mustn't contain NaN values, but detected NaN in it" + assert not lib.isnan(targets).any(), "`targets` mustn't contain NaN values, but detected NaN in it" if masks is not None: # check type @@ -51,9 +47,7 @@ def _check_inputs( f"but got `mask`: {mask_shape} that is different from `targets`: {target_shape}" ) # check NaN - assert not lib.isnan( - masks - ).any(), "`masks` mustn't contain NaN values, but detected NaN in it" + assert not lib.isnan(masks).any(), "`masks` mustn't contain NaN values, but detected NaN in it" return lib @@ -104,9 +98,7 @@ def calc_mae( lib = _check_inputs(predictions, targets, masks) if masks is not None: - return lib.sum(lib.abs(predictions - targets) * masks) / ( - lib.sum(masks) + 1e-12 - ) + return lib.sum(lib.abs(predictions - targets) * masks) / (lib.sum(masks) + 1e-12) else: return lib.mean(lib.abs(predictions - targets)) @@ -157,9 +149,7 @@ def calc_mse( lib = _check_inputs(predictions, targets, masks) if masks is not None: - return lib.sum(lib.square(predictions - targets) * masks) / ( - lib.sum(masks) + 1e-12 - ) + return lib.sum(lib.square(predictions - targets) * masks) / (lib.sum(masks) + 1e-12) else: return lib.mean(lib.square(predictions - targets)) @@ -259,20 +249,14 @@ def calc_mre( lib = _check_inputs(predictions, targets, masks) if masks is not None: - return lib.sum(lib.abs(predictions - targets) * masks) / ( - lib.sum(lib.abs(targets * masks)) + 1e-12 - ) + return lib.sum(lib.abs(predictions - targets) * masks) / (lib.sum(lib.abs(targets * masks)) + 1e-12) else: - return lib.sum(lib.abs(predictions - targets)) / ( - lib.sum(lib.abs(targets)) + 1e-12 - ) + return lib.sum(lib.abs(predictions - targets)) / (lib.sum(lib.abs(targets)) + 1e-12) def calc_quantile_loss(predictions, targets, q: float, eval_points) -> float: quantile_loss = 2 * torch.sum( - torch.abs( - (predictions - targets) * eval_points * ((targets <= predictions) * 1.0 - q) - ) + torch.abs((predictions - targets) * eval_points * ((targets <= predictions) * 1.0 - q)) ) return quantile_loss diff --git a/pypots/utils/visual/clustering.py b/pypots/utils/visual/clustering.py index 97092895..be540c03 100644 --- a/pypots/utils/visual/clustering.py +++ b/pypots/utils/visual/clustering.py @@ -14,9 +14,7 @@ import scipy.stats as st -def get_cluster_members( - test_data: np.ndarray, class_predictions: np.ndarray -) -> Dict[int, np.ndarray]: +def get_cluster_members(test_data: np.ndarray, class_predictions: np.ndarray) -> Dict[int, np.ndarray]: """ Subset time series array using predicted cluster membership. @@ -79,18 +77,12 @@ def clusters_for_plotting( for i in cluster_members: # i iterates clusters dict_to_plot[i] = {} # one dict per cluster for j in cluster_members[i]: # j iterates members of each cluster - temp = pd.DataFrame(j).to_dict( - orient="list" - ) # dict of member's time series as lists (one per var) + temp = pd.DataFrame(j).to_dict(orient="list") # dict of member's time series as lists (one per var) for key in temp: # key is a time series var if key not in dict_to_plot[i]: - dict_to_plot[i][key] = [ - temp[key] - ] # create entry in cluster dict for each time series var + dict_to_plot[i][key] = [temp[key]] # create entry in cluster dict for each time series var else: - dict_to_plot[i][key].append( - temp[key] - ) # add cluster member's time series by var key + dict_to_plot[i][key].append(temp[key]) # add cluster member's time series by var key return dict_to_plot @@ -189,28 +181,19 @@ def get_cluster_means(dict_to_plot: Dict[int, dict]) -> Dict[int, dict]: if j not in cluster_means: cluster_means[j] = {} - cluster_means[j][ - i - ] = ( - {} - ) # clusters nested within vars (reverse structure to clusters_for_plotting) + cluster_means[j][i] = {} # clusters nested within vars (reverse structure to clusters_for_plotting) cluster_means[j][i]["mean"] = list( pd.DataFrame(dict_to_plot[i][j]).mean(axis=0, skipna=True) ) # cluster mean array of time series var # CI calculation, from https://stackoverflow.com/a/34474255 - ( - cluster_means[j][i]["CI_low"], - cluster_means[j][i]["CI_high"], - ) = st.t.interval( + (cluster_means[j][i]["CI_low"], cluster_means[j][i]["CI_high"]) = st.t.interval( 0.95, len(dict_to_plot[i][j]) - 1, # degrees of freedom loc=cluster_means[j][i]["mean"], scale=pd.DataFrame(dict_to_plot[i][j]).sem(axis=0, skipna=True), ) - cluster_means[j][i]["n"] = len( - dict_to_plot[i][j] - ) # save cluster size for downstream tasks/plotting + cluster_means[j][i]["n"] = len(dict_to_plot[i][j]) # save cluster size for downstream tasks/plotting return cluster_means @@ -224,9 +207,7 @@ def plot_cluster_means(cluster_means: Dict[int, dict]) -> None: cluster_means : Output from get_cluster_means function. """ - colors = plt.rcParams["axes.prop_cycle"].by_key()[ - "color" - ] # to keep cluster colors consistent + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] # to keep cluster colors consistent for i in cluster_means: # iterate time series vars y = cluster_means[i] @@ -267,9 +248,7 @@ def plot_cluster_means(cluster_means: Dict[int, dict]) -> None: plt.xticks(x) # add dashed line label to legend - line_dashed = mlines.Line2D( - [], [], color="gray", linestyle="--", linewidth=1.5, label="95% CI" - ) + line_dashed = mlines.Line2D([], [], color="gray", linestyle="--", linewidth=1.5, label="95% CI") handles, labels = plt.legend().axes.get_legend_handles_labels() handles.append(line_dashed) new_lgd = plt.legend(handles=handles) diff --git a/pypots/utils/visual/data.py b/pypots/utils/visual/data.py index ca5a5f6e..8bd90a2c 100644 --- a/pypots/utils/visual/data.py +++ b/pypots/utils/visual/data.py @@ -53,16 +53,12 @@ def plot_data( """ vals_shape = X.shape - assert ( - len(vals_shape) == 3 - ), "vals_obs should be a 3D array of shape (n_samples, n_steps, n_features)" + assert len(vals_shape) == 3, "vals_obs should be a 3D array of shape (n_samples, n_steps, n_features)" n_samples, n_steps, n_features = vals_shape if sample_idx is None: sample_idx = np.random.randint(low=0, high=n_samples) - logger.warning( - f"⚠️ No sample index is specified, a random sample {sample_idx} is selected for visualization." - ) + logger.warning(f"⚠️ No sample index is specified, a random sample {sample_idx} is selected for visualization.") if fig_size is None: fig_size = [24, 36] @@ -71,9 +67,7 @@ def plot_data( K = np.min([n_features, n_k]) L = n_steps plt.rcParams["font.size"] = 16 - fig, axes = plt.subplots( - nrows=n_rows, ncols=n_cols, figsize=(fig_size[0], fig_size[1]) - ) + fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(fig_size[0], fig_size[1])) for k in range(K): df = pd.DataFrame({"x": np.arange(0, L), "val": X_imputed[sample_idx, :, k]}) @@ -89,9 +83,7 @@ def plot_data( if row == -1: plt.setp(axes[-1, col], xlabel="time") - logger.info( - "Plotting finished. Please invoke matplotlib.pyplot.show() to display the plot." - ) + logger.info("Plotting finished. Please invoke matplotlib.pyplot.show() to display the plot.") def plot_missingness( @@ -170,6 +162,4 @@ def plot_missingness( axes[1].set_ylabel("Frequency", fontsize=7) axes[1].tick_params(axis="both", labelsize=7) - logger.info( - "Plotting finished. Please invoke matplotlib.pyplot.show() to display the plot." - ) + logger.info("Plotting finished. Please invoke matplotlib.pyplot.show() to display the plot.") diff --git a/pyproject.toml b/pyproject.toml index 75c7d043..54c3a609 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,9 @@ readme = { file = "README.md", content-type = "text/markdown" } dependencies = { file = "requirements/requirements.txt" } optional-dependencies.dev = { file = "requirements/requirements_dev.txt" } +[tool.black] +line-length = 120 + [tool.flake8] # People may argue that coding style is personal. This may be true if the project is personal and one works like a # hermit, but to PyPOTS and its community, the answer is NO. @@ -82,8 +85,10 @@ optional-dependencies.dev = { file = "requirements/requirements_dev.txt" } # who prefer the default setting can keep using 88 or 79 while coding. Please ensure your code lines not exceeding 120. max-line-length = 120 # why ignore E203? Refer to https://github.com/PyCQA/pycodestyle/issues/373 +# why ignore E231? Bad trailing comma, conflict with Black extend-ignore = """ - E203 + E203, + E231, """ # ignore some errors that are not important in template files exclude = [