Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply line-length=120 to refactor code, update dependencies and pre-commit config #512

Merged
merged 11 commits into from
Sep 13, 2024
11 changes: 7 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
repos:
# hooks for checking files
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml

# hooks for linting code
- repo: https://github.com/psf/black
rev: 22.10.0
rev: 24.8.0
hooks:
- id: black
args: [
--line-length=120, # refer to pyproject.toml
]

- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 7.1.1
hooks:
- id: flake8
args: [
--max-line-length=120, # refer to pyproject.toml
--extend-ignore=E203, # why ignore E203? Refer to https://github.com/PyCQA/pycodestyle/issues/373
--extend-ignore=E203,E231
]
245 changes: 119 additions & 126 deletions README.md

Large diffs are not rendered by default.

194 changes: 99 additions & 95 deletions README_zh.md

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@
html_context["READTHEDOCS"] = True

html_favicon = (
"https://raw.githubusercontent.com/"
"PyPOTS/pypots.github.io/main/static/figs/pypots_logos/PyPOTS/logo_FFBG.svg"
"https://raw.githubusercontent.com/PyPOTS/pypots.github.io/main/static/figs/pypots_logos/PyPOTS/logo_FFBG.svg"
)

html_sidebars = {
Expand Down
24 changes: 6 additions & 18 deletions pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ def _setup_device(self, device: Union[None, str, torch.device, list]) -> None:
self.device = device
elif isinstance(device, list):
if len(device) == 0:
raise ValueError(
"The list of devices should have at least 1 device, but got 0."
)
raise ValueError("The list of devices should have at least 1 device, but got 0.")
elif len(device) == 1:
return self._setup_device(device[0])
# parallely training on multiple CUDA devices
Expand Down Expand Up @@ -179,18 +177,14 @@ def _setup_path(self, saving_path) -> None:
logger.info(f"Model files will be saved to {self.saving_path}")
logger.info(f"Tensorboard file will be saved to {tb_saving_path}")
else:
logger.warning(
"‼️ saving_path not given. Model files and tensorboard file will not be saved."
)
logger.warning("‼️ saving_path not given. Model files and tensorboard file will not be saved.")

def _send_model_to_given_device(self) -> None:
if isinstance(self.device, list):
# parallely training on multiple devices
self.model = torch.nn.DataParallel(self.model, device_ids=self.device)
self.model = self.model.cuda()
logger.info(
f"Model has been allocated to the given multiple devices: {self.device}"
)
logger.info(f"Model has been allocated to the given multiple devices: {self.device}")
else:
self.model = self.model.to(self.device)

Expand Down Expand Up @@ -291,9 +285,7 @@ def save(

if os.path.exists(saving_path):
if overwrite:
logger.warning(
f"‼️ File {saving_path} exists. Argument `overwrite` is True. Overwriting now..."
)
logger.warning(f"‼️ File {saving_path} exists. Argument `overwrite` is True. Overwriting now...")
else:
logger.error(
f"❌ File {saving_path} exists. Saving operation aborted. "
Expand All @@ -309,9 +301,7 @@ def save(
torch.save(self.model, saving_path)
logger.info(f"Saved the model to {saving_path}")
except Exception as e:
raise RuntimeError(
f'Failed to save the model to "{saving_path}" because of the below error! \n{e}'
)
raise RuntimeError(f'Failed to save the model to "{saving_path}" because of the below error! \n{e}')

def load(self, path: str) -> None:
"""Load the saved model from a disk file.
Expand Down Expand Up @@ -519,9 +509,7 @@ def __init__(

def _print_model_size(self) -> None:
"""Print the number of trainable parameters in the initialized NN model."""
self.num_params = sum(
p.numel() for p in self.model.parameters() if p.requires_grad
)
self.num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
logger.info(
f"{self.__class__.__name__} initialized with the given hyperparameters, "
f"the number of trainable parameters: {self.num_params:,}"
Expand Down
20 changes: 5 additions & 15 deletions pypots/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,7 @@ def _train_model(
for idx, data in enumerate(val_loader):
inputs = self._assemble_input_for_validating(data)
results = self.model.forward(inputs)
epoch_val_loss_collector.append(
results["loss"].sum().item()
)
epoch_val_loss_collector.append(results["loss"].sum().item())

mean_val_loss = np.mean(epoch_val_loss_collector)

Expand All @@ -333,15 +331,11 @@ def _train_model(
)
mean_loss = mean_val_loss
else:
logger.info(
f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}"
)
logger.info(f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}")
mean_loss = mean_train_loss

if np.isnan(mean_loss):
logger.warning(
f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
)
logger.warning(f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors.")

if mean_loss < self.best_loss:
self.best_epoch = epoch
Expand All @@ -363,9 +357,7 @@ def _train_model(
nni.report_final_result(self.best_loss)

if self.patience == 0:
logger.info(
"Exceeded the training patience. Terminating the training procedure..."
)
logger.info("Exceeded the training patience. Terminating the training procedure...")
break

except KeyboardInterrupt: # if keyboard interrupt, only warning
Expand All @@ -386,9 +378,7 @@ def _train_model(
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is Nan after training.")

logger.info(
f"Finished training. The best model is from epoch#{self.best_epoch}."
)
logger.info(f"Finished training. The best model is from epoch#{self.best_epoch}.")

@abstractmethod
def fit(
Expand Down
8 changes: 2 additions & 6 deletions pypots/classification/grud/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,15 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
empirical_mean = inputs["empirical_mean"]
X_filledLOCF = inputs["X_filledLOCF"]

_, hidden_state = self.model(
X, missing_mask, deltas, empirical_mean, X_filledLOCF
)
_, hidden_state = self.model(X, missing_mask, deltas, empirical_mean, X_filledLOCF)

logits = self.classifier(hidden_state)
classification_pred = torch.softmax(logits, dim=1)
results = {"classification_pred": classification_pred}

# if in training mode, return results with losses
if training:
classification_loss = F.nll_loss(
torch.log(classification_pred), inputs["label"]
)
classification_loss = F.nll_loss(torch.log(classification_pred), inputs["label"])
results["loss"] = classification_loss

return results
10 changes: 4 additions & 6 deletions pypots/classification/grud/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def __init__(
self.X_filledLOCF = locf_torch(self.X)
self.X = torch.nan_to_num(self.X)
self.deltas = _parse_delta_torch(self.missing_mask)
self.empirical_mean = torch.sum(
self.missing_mask * self.X, dim=[0, 1]
) / torch.sum(self.missing_mask, dim=[0, 1])
self.empirical_mean = torch.sum(self.missing_mask * self.X, dim=[0, 1]) / torch.sum(
self.missing_mask, dim=[0, 1]
)
# fill nan with 0, in case some features have no observations
self.empirical_mean = torch.nan_to_num(self.empirical_mean, 0)

Expand Down Expand Up @@ -134,9 +134,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
X_filledLOCF = locf_torch(X.unsqueeze(dim=0)).squeeze()
X = torch.nan_to_num(X)
deltas = _parse_delta_torch(missing_mask)
empirical_mean = torch.sum(missing_mask * X, dim=[0]) / torch.sum(
missing_mask, dim=[0]
)
empirical_mean = torch.sum(missing_mask * X, dim=[0]) / torch.sum(missing_mask, dim=[0])

sample = [
torch.tensor(idx),
Expand Down
21 changes: 5 additions & 16 deletions pypots/classification/raindrop/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
and takes over the forward progress of the algorithm.
"""


# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

Expand Down Expand Up @@ -84,21 +83,13 @@ def forward(self, inputs, training=True):
lengths2 = lengths.unsqueeze(1).to(device)
mask2 = mask.permute(1, 0).unsqueeze(2).long()
if self.sensor_wise_mask:
output = torch.zeros(
[batch_size, self.n_features, self.d_ob + 16], device=device
)
output = torch.zeros([batch_size, self.n_features, self.d_ob + 16], device=device)
extended_missing_mask = missing_mask.view(-1, batch_size, self.n_features)
for se in range(self.n_features):
representation = representation.view(
-1, batch_size, self.n_features, (self.d_ob + 16)
)
representation = representation.view(-1, batch_size, self.n_features, (self.d_ob + 16))
out = representation[:, :, se, :]
l_ = torch.sum(extended_missing_mask[:, :, se], dim=0).unsqueeze(
1
) # length
out_sensor = torch.sum(
out * (1 - extended_missing_mask[:, :, se].unsqueeze(-1)), dim=0
) / (l_ + 1)
l_ = torch.sum(extended_missing_mask[:, :, se], dim=0).unsqueeze(1) # length
out_sensor = torch.sum(out * (1 - extended_missing_mask[:, :, se].unsqueeze(-1)), dim=0) / (l_ + 1)
output[:, se, :] = out_sensor
output = output.view([-1, self.n_features * (self.d_ob + 16)])
elif self.aggregation == "mean":
Expand All @@ -116,9 +107,7 @@ def forward(self, inputs, training=True):

# if in training mode, return results with losses
if training:
classification_loss = F.nll_loss(
torch.log(classification_pred), inputs["label"]
)
classification_loss = F.nll_loss(torch.log(classification_pred), inputs["label"])
results["loss"] = classification_loss

return results
1 change: 0 additions & 1 deletion pypots/classification/raindrop/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

"""


# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

Expand Down
17 changes: 5 additions & 12 deletions pypots/cli/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,9 @@ def checkup(self):
)

if self._cleanup:
assert not self._run_tests and not self._lint_code, (
"Argument `--cleanup` should be used alone. "
"Try `pypots-cli dev --cleanup`"
)
assert (
not self._run_tests and not self._lint_code
), "Argument `--cleanup` should be used alone. Try `pypots-cli dev --cleanup`"

def run(self):
"""Execute the given command."""
Expand All @@ -149,14 +148,8 @@ def run(self):
elif self._build:
self.execute_command("python -m build")
elif self._run_tests:
pytest_command = (
f"pytest -k {self._k}" if self._k is not None else "pytest"
)
command_to_run_test = (
f"coverage run -m {pytest_command}"
if self._show_coverage
else pytest_command
)
pytest_command = f"pytest -k {self._k}" if self._k is not None else "pytest"
command_to_run_test = f"coverage run -m {pytest_command}" if self._show_coverage else pytest_command
self.execute_command(command_to_run_test)
if self._show_coverage and os.path.exists(".coverage"):
self.execute_command("coverage report -m")
Expand Down
31 changes: 9 additions & 22 deletions pypots/cli/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def doc_command_factory(args: Namespace):


def purge_temp_files():
logger.info(
f"Directories _build and {CLONED_LATEST_PYPOTS} will be deleted if exist"
)
logger.info(f"Directories _build and {CLONED_LATEST_PYPOTS} will be deleted if exist")
shutil.rmtree("docs/_build", ignore_errors=True)
shutil.rmtree(CLONED_LATEST_PYPOTS, ignore_errors=True)

Expand Down Expand Up @@ -148,10 +146,9 @@ def checkup(self):
self.check_if_under_root_dir(strict=True)

if self._cleanup:
assert not self._gene_rst and not self._gene_html and not self._view_doc, (
"Argument `--cleanup` should be used alone. "
"Try `pypots-cli doc --cleanup`"
)
assert (
not self._gene_rst and not self._gene_html and not self._view_doc
), "Argument `--cleanup` should be used alone. Try `pypots-cli doc --cleanup`"

def run(self):
"""Execute the given command."""
Expand All @@ -166,9 +163,7 @@ def run(self):

if self._gene_rst:
if os.path.exists(CLONED_LATEST_PYPOTS):
logger.info(
f"Directory {CLONED_LATEST_PYPOTS} exists, deleting it..."
)
logger.info(f"Directory {CLONED_LATEST_PYPOTS} exists, deleting it...")
shutil.rmtree(CLONED_LATEST_PYPOTS, ignore_errors=True)

# Download the latest code from GitHub
Expand All @@ -185,18 +180,12 @@ def run(self):
for f_ in files_to_move:
shutil.move(os.path.join(code_dir, f_), destination_dir)
# delete code in tests because we don't need its doc
shutil.rmtree(
f"{CLONED_LATEST_PYPOTS}/pypots/tests", ignore_errors=True
)
shutil.rmtree(f"{CLONED_LATEST_PYPOTS}/pypots/tests", ignore_errors=True)

# Generate the docs according to the cloned code
logger.info("Generating rst files...")
os.environ[
"SPHINX_APIDOC_OPTIONS"
] = "members,undoc-members,show-inheritance,inherited-members"
self.execute_command(
f"sphinx-apidoc {CLONED_LATEST_PYPOTS} -o {CLONED_LATEST_PYPOTS}/rst"
)
os.environ["SPHINX_APIDOC_OPTIONS"] = "members,undoc-members,show-inheritance,inherited-members"
self.execute_command(f"sphinx-apidoc {CLONED_LATEST_PYPOTS} -o {CLONED_LATEST_PYPOTS}/rst")

# Only save the files we need.
logger.info("Updating the old documentation...")
Expand All @@ -217,9 +206,7 @@ def run(self):
"docs/_build/html"
), "docs/_build/html does not exists, please run `pypots-cli doc --gene_html` first"
logger.info(f"Deploying HTML to http://127.0.0.1:{self._port}...")
self.execute_command(
f"python -m http.server {self._port} -d docs/_build/html -b 127.0.0.1"
)
self.execute_command(f"python -m http.server {self._port} -d docs/_build/html -b 127.0.0.1")

except ImportError:
raise ImportError(IMPORT_ERROR_MESSAGE)
Expand Down
8 changes: 2 additions & 6 deletions pypots/cli/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,14 @@ def run(self):
# run checks first
self.checkup()

logger.info(
f"Installing the dependencies in scope `{self._install}` for you..."
)
logger.info(f"Installing the dependencies in scope `{self._install}` for you...")

if self._tool == "conda":
assert (
self.execute_command("which conda").returncode == 0
), "Conda not installed, cannot set --tool=conda, please check your conda."

self.execute_command(
"conda install pyg pytorch-scatter pytorch-sparse -c pyg"
)
self.execute_command("conda install pyg pytorch-scatter pytorch-sparse -c pyg")

else: # self._tool == "pip"
torch_version = torch.__version__
Expand Down
4 changes: 1 addition & 3 deletions pypots/cli/pypots_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@


def main():
parser = ArgumentParser(
"PyPOTS Command-Line-Interface tool", usage="pypots-cli <command> [<args>]"
)
parser = ArgumentParser("PyPOTS Command-Line-Interface tool", usage="pypots-cli <command> [<args>]")
commands_parser = parser.add_subparsers(help="pypots-cli command helpers")

# Register commands here
Expand Down
Loading
Loading