diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml
index 6650a37..a500d25 100644
--- a/.github/workflows/linting.yml
+++ b/.github/workflows/linting.yml
@@ -45,9 +45,3 @@ jobs:
       - name: Run isort
        run: |
          isort --profile black kronfluence
-
-  actionlint:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v4
-      - uses: reviewdog/action-actionlint@v1
\ No newline at end of file
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
new file mode 100644
index 0000000..b6690e0
--- /dev/null
+++ b/.github/workflows/ruff.yml
@@ -0,0 +1,9 @@
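+# Lint with Ruff on every push and pull request; Ruff reads its configuration from pyproject.toml.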
+name: Ruff
+on: [push, pull_request]
+jobs:
+  ruff:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: chartboost/ruff-action@v1
\ No newline at end of file
diff --git a/examples/cifar/pipeline.py b/examples/cifar/pipeline.py
index c46932c..04f3ab0 100644
--- a/examples/cifar/pipeline.py
+++ b/examples/cifar/pipeline.py
@@ -1,14 +1,14 @@
 import copy
 import math
 from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
-import torch.nn as nn
+from torch import nn
 import torchvision
 
 
-class Mul(torch.nn.Module):
+class Mul(nn.Module):
     def __init__(self, weight: float) -> None:
         super().__init__()
         self.weight = weight
@@ -17,12 +17,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * self.weight
 
 
-class Flatten(torch.nn.Module):
+class Flatten(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.view(x.size(0), -1)
 
 
-class Residual(torch.nn.Module):
+class Residual(nn.Module):
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
@@ -71,105 +71,37 @@ def conv_bn(
     return model
 
 
-# def get_hyperparameters(data_name: str) -> Dict[str, float]:
-#     wd = 0.001
-#     if data_name == "cifar2":
-#         lr = 0.5
-#         epochs = 100
-#     elif data_name == "cifar10":
-#         lr = 0.4
-#         epochs = 25
-#     else:
-#         raise NotImplementedError()
-#     return {"lr": lr, "wd": wd, "epochs": epochs}
-
-
 def get_cifar10_dataset(
     split: str,
+    do_corrupt: bool,
     indices: List[int] = None,
-    data_path: str = "data/",
+    data_dir: str = "data/",
 ):
     assert split in ["train", "eval_train", "valid"]
 
     normalize = torchvision.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261))
     if split in ["train", "eval_train"]:
-        transforms = torchvision.transforms.Compose(
+        transform_config = torchvision.transforms.Compose(
             [
                 torchvision.transforms.RandomCrop(32, padding=4),
                 torchvision.transforms.RandomHorizontalFlip(),
                 torchvision.transforms.ToTensor(),
                 normalize,
             ]
         )
     else:
-        transforms = torchvision.transforms.Compose(
-            [
-                torchvision.transforms.ToTensor(),
-                torchvision.transforms.Normalize(mean=MEAN, std=STD),
-            ]
-        )
-
-    if split == "train":
-        transform_config = [
-            torchvision.transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 4.0 / 3.0)),
-            torchvision.transforms.RandomHorizontalFlip(),
-        ]
-        transform_config.extend([torchvision.transforms.ToTensor(), normalize])
-        transform_config = torchvision.transforms.Compose(transform_config)
-    else:
         transform_config = torchvision.transforms.Compose(
             [
-                torchvision.transforms.Resize(size=256),
-                torchvision.transforms.CenterCrop(size=224),
                 torchvision.transforms.ToTensor(),
                 normalize,
             ]
         )
 
-    folder = "train" if split in ["train", "eval_train"] else "val"
-    dataset = torchvision.datasets.ImageFolder(
-        root=os.path.join(data_path, folder),
-        transform=transform_config,
-    )
-
-    if indices is not None:
-        dataset = torch.utils.data.Subset(dataset, indices)
-
-    return dataset
-
-
-def get_cifar10_dataloader(
-    batch_size: int,
-    split: str = "train",
-    indices: List[int] = None,
-    do_corrupt: bool = False,
-    num_workers: int = 4,
-) -> torch.utils.data.DataLoader:
-    MEAN = (0.4914, 0.4822, 0.4465)
-    STD = (0.247, 0.243, 0.261)
-
-    if split in ["train", "eval_train"]:
-        transforms = torchvision.transforms.Compose(
-            [
-                torchvision.transforms.RandomCrop(32, padding=4),
-                torchvision.transforms.RandomHorizontalFlip(),
-                torchvision.transforms.ToTensor(),
-                torchvision.transforms.Normalize(mean=MEAN, std=STD),
-            ]
-        )
-    else:
-        transforms = torchvision.transforms.Compose(
-            [
-                torchvision.transforms.ToTensor(),
-                torchvision.transforms.Normalize(mean=MEAN, std=STD),
-            ]
-        )
-
     dataset = torchvision.datasets.CIFAR10(
-        root="/tmp/cifar/",
+        root=data_dir,
         download=True,
         train=split in ["train", "eval_train", "eval_train_with_aug"],
-        transform=transforms,
+        transform=transform_config,
     )
 
     if do_corrupt:
@@ -198,11 +130,5 @@ def get_cifar10_dataloader(
     if indices is not None:
         dataset = torch.utils.data.Subset(dataset, indices)
 
-    return torch.utils.data.DataLoader(
-        dataset=dataset,
-        shuffle=split == "train",
-        batch_size=batch_size,
-        num_workers=num_workers,
-        drop_last=split == "train",
-        pin_memory=True,
-    )
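+    # The dataset is returned directly; callers now construct their own DataLoader.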
+    return dataset
diff --git a/examples/glue/train.py b/examples/glue/train.py
index 0a845e9..c023214 100644
--- a/examples/glue/train.py
+++ b/examples/glue/train.py
@@ -10,7 +10,6 @@
 from transformers import default_data_collator
 
 from examples.glue.pipeline import construct_bert, get_glue_dataset
-from examples.mnist.pipeline import construct_mnist_mlp, get_mnist_dataset
 
 
 def parse_args():
@@ -87,100 +86,20 @@ def main():
     args = parse_args()
 
     logging.basicConfig(level=logging.INFO)
-    logger = logging.getLogger()
 
     if args.seed is not None:
         set_seed(args.seed)
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    train_dataset = get_glue_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
-    train_dataloader = DataLoader(
-        dataset=train_dataset,
-        batch_size=args.train_batch_size,
-        shuffle=True,
-        collate_fn=default_data_collator,
-        drop_last=True,
-    )
-    model = construct_bert(args.data_name).to(device=device)
-    # optimizer = torch.optim.SGD(
-    #     model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay
-    # )
-    #
-    # logger.info("Start training the model.")
-    # model.train()
-    # for epoch in range(args.num_train_epochs):
-    #
-    #     total_loss = 0
-    #
-    #     with tqdm(train_dataloader, unit="batch") as tepoch:
-    #
-    #         for batch in tepoch:
-    #             tepoch.set_description(f"Epoch {epoch}")
-    #             inputs, labels = batch
-    #             inputs, labels = inputs.to(device), labels.to(device)
-    #             logits = model(inputs)
-    #             loss = F.cross_entropy(logits, labels)
-    #             total_loss += loss.detach().float()
-    #             loss.backward()
-    #             optimizer.step()
-    #             optimizer.zero_grad()
-    #             tepoch.set_postfix(loss=total_loss.item() / len(train_dataloader))
-    #
-    # logger.info("Start evaluating the model.")
-    # model.eval()
-    # train_eval_dataset = get_mnist_dataset(
-    #     split="eval_train", data_path=args.dataset_dir
-    # )
-    # train_eval_dataloader = DataLoader(
-    #     dataset=train_eval_dataset,
-    #     batch_size=args.eval_batch_size,
-    #     shuffle=False,
-    #     drop_last=False,
-    # )
-    # eval_dataset = get_mnist_dataset(split="valid", data_path=args.dataset_dir)
-    # eval_dataloader = DataLoader(
-    #     dataset=eval_dataset,
-    #     batch_size=args.eval_batch_size,
-    #     shuffle=False,
-    #     drop_last=False,
-    # )
-    #
-    # total_loss = 0
-    # correct = 0
-    # for batch in train_eval_dataloader:
-    #     with torch.no_grad():
-    #         inputs, labels = batch
-    #         inputs, labels = inputs.to(device), labels.to(device)
-    #         logits = model(inputs)
-    #         loss = F.cross_entropy(logits, labels)
-    #         preds = logits.argmax(dim=1, keepdim=True)
-    #         correct += preds.eq(labels.view_as(preds)).sum().item()
-    #         total_loss += loss.detach().float()
-    #
-    # logger.info(
-    #     f"Train loss: {total_loss.item() / len(train_eval_dataloader.dataset)} | "
-    #     f"Train Accuracy: {100 * correct / len(train_eval_dataloader.dataset)}"
-    # )
-    #
-    # total_loss = 0
-    # correct = 0
-    # for batch in eval_dataloader:
-    #     with torch.no_grad():
-    #         inputs, labels = batch
-    #         inputs, labels = inputs.to(device), labels.to(device)
-    #         logits = model(inputs)
-    #         loss = F.cross_entropy(logits, labels)
-    #         preds = logits.argmax(dim=1, keepdim=True)
-    #         correct += preds.eq(labels.view_as(preds)).sum().item()
-    #         total_loss += loss.detach().float()
-    #
-    # logger.info(
-    #     f"Train loss: {total_loss.item() / len(eval_dataloader.dataset)} | "
-    #     f"Train Accuracy: {100 * correct / len(eval_dataloader.dataset)}"
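+    # Training is disabled for now; the commented-out setup below is kept for reference.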
+    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # train_dataset = get_glue_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
+    # train_dataloader = DataLoader(
+    #     dataset=train_dataset,
+    #     batch_size=args.train_batch_size,
+    #     shuffle=True,
+    #     collate_fn=default_data_collator,
+    #     drop_last=True,
     #     )
-    #
-    # if args.checkpoint_dir is not None:
-    #     torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, "model.pth"))
 
 
 if __name__ == "__main__":
diff --git a/examples/imagenet/analyze.py b/examples/imagenet/analyze.py
index ede619c..99c0f00 100644
--- a/examples/imagenet/analyze.py
+++ b/examples/imagenet/analyze.py
@@ -98,10 +98,9 @@ def main():
     args = parse_args()
 
     logging.basicConfig(level=logging.INFO)
-    logger = logging.getLogger()
 
     train_dataset = get_imagenet_dataset(split="eval_train", data_path=args.dataset_dir)
-    eval_dataset = get_imagenet_dataset(split="valid", data_path=args.dataset_dir)
+    # eval_dataset = get_imagenet_dataset(split="valid", data_path=args.dataset_dir)
 
     model = construct_resnet50()
 
@@ -125,26 +124,7 @@ def main():
         factor_args=factor_args,
         per_device_batch_size=1024,
         overwrite_output_dir=True,
-        dataloader_num_workers=2,
-        dataloader_pin_memory=True,
     )
-    # analyzer.perform_eigendecomposition(
-    #     factor_name=args.factor_strategy,
-    #     factor_args=factor_args,
-    #     overwrite_output_dir=True,
-    # )
-    # analyzer.fit_lambda(train_dataset, per_device_batch_size=None)
-    #
-    # score_name = "full_pairwise"
-    # analyzer.compute_pairwise_scores(
-    #     score_name=score_name,
-    #     query_dataset=eval_dataset,
-    #     per_device_query_batch_size=len(eval_dataset),
-    #     train_dataset=train_dataset,
-    #     per_device_train_batch_size=len(train_dataset),
-    # )
-    # scores = analyzer.load_pairwise_scores(score_name=score_name)
-    # print(scores.shape)
 
 
 if __name__ == "__main__":
diff --git a/examples/imagenet/ddp_analyze.py b/examples/imagenet/ddp_analyze.py
index 3bedb9c..7106eef 100644
--- a/examples/imagenet/ddp_analyze.py
+++ b/examples/imagenet/ddp_analyze.py
@@ -1,21 +1,18 @@
 import argparse
 import logging
-import math
 import os
-from typing import Dict, Tuple
+from typing import Tuple
 
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from analyzer import Analyzer, prepare_model
 from arguments import FactorArguments
-from module.utils import wrap_tracked_modules
 from task import Task
 from torch import nn
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 from examples.imagenet.pipeline import construct_resnet50, get_imagenet_dataset
-from examples.mnist.pipeline import construct_mnist_mlp, get_mnist_dataset
 
 BATCH_DTYPE = Tuple[torch.Tensor, torch.Tensor]
 LOCAL_RANK = int(os.environ["LOCAL_RANK"])
@@ -107,10 +104,9 @@ def main():
     args = parse_args()
 
     logging.basicConfig(level=logging.INFO)
-    logger = logging.getLogger()
 
     train_dataset = get_imagenet_dataset(split="eval_train", data_path=args.dataset_dir)
-    eval_dataset = get_imagenet_dataset(split="valid", data_path=args.dataset_dir)
+    # eval_dataset = get_imagenet_dataset(split="valid", data_path=args.dataset_dir)
 
     dist.init_process_group("nccl", rank=WORLD_RANK, world_size=WORLD_SIZE)
     device = torch.device("cuda:{}".format(LOCAL_RANK))
@@ -144,8 +140,6 @@ def main():
         factor_args=factor_args,
         per_device_batch_size=None,
         overwrite_output_dir=True,
-        dataloader_num_workers=2,
-        dataloader_pin_memory=True,
     )
     # analyzer.perform_eigendecomposition(
     #     factor_name=args.factor_strategy,
diff --git a/examples/imagenet/pipeline.py b/examples/imagenet/pipeline.py
index 119973b..c61976e 100644
--- a/examples/imagenet/pipeline.py
+++ b/examples/imagenet/pipeline.py
@@ -48,8 +48,3 @@ def get_imagenet_dataset(
         dataset = torch.utils.data.Subset(dataset, indices)
 
     return dataset
-
-
-if __name__ == "__main__":
-    model = construct_resnet50()
-    print(model)
diff --git a/examples/uci/analyze.py b/examples/uci/analyze.py
index 648a4b6..15916b8 100644
--- a/examples/uci/analyze.py
+++ b/examples/uci/analyze.py
@@ -98,7 +98,6 @@ def main():
     args = parse_args()
 
     logging.basicConfig(level=logging.INFO)
-    logger = logging.getLogger()
 
     train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
     eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", data_path=args.dataset_dir)
diff --git a/pyproject.toml b/pyproject.toml
index d0b22bb..3cdf4aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,12 +3,18 @@ profile = "black"
 
 [tool.ruff]
 line-length = 120
+target-version = "py39"
+
+[tool.ruff.lint]
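+# F401 (unused import) is ignored because the examples keep imports used only by commented-out code.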
+ignore = ["F401"]
 
 [tool.ruff.format]
 quote-style = "double"
 skip-magic-trailing-comma = false
 line-ending = "auto"
 docstring-code-format = true
+docstring-code-line-length = "dynamic"
 
 [tool.pylint.format]
 max-line-length = "120"
diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py
index d4026bf..32601da 100644
--- a/tests/test_analyzer.py
+++ b/tests/test_analyzer.py
@@ -80,8 +80,8 @@ def test_default_factor_arguments() -> None:
     factor_args = FactorArguments()
 
     assert factor_args.strategy == "ekfac"
-    assert factor_args.use_empirical_fisher == False
-    assert factor_args.immediate_gradient_removal == False
+    assert factor_args.use_empirical_fisher is False
+    assert factor_args.immediate_gradient_removal is False
 
     assert factor_args.covariance_max_examples == 100_000
     assert factor_args.covariance_data_partition_size == 1
@@ -94,8 +94,8 @@ def test_default_factor_arguments() -> None:
     assert factor_args.lambda_max_examples == 100_000
     assert factor_args.lambda_data_partition_size == 1
     assert factor_args.lambda_module_partition_size == 1
-    assert factor_args.lambda_iterative_aggregate == False
-    assert factor_args.cached_activation_cpu_offload == False
+    assert factor_args.lambda_iterative_aggregate is False
+    assert factor_args.cached_activation_cpu_offload is False
     assert factor_args.lambda_dtype == torch.float32
 
 
@@ -103,16 +103,16 @@ def test_default_score_arguments() -> None:
     factor_args = ScoreArguments()
 
     assert factor_args.damping is None
-    assert factor_args.immediate_gradient_removal == False
+    assert factor_args.immediate_gradient_removal is False
 
     assert factor_args.data_partition_size == 1
     assert factor_args.module_partition_size == 1
-    assert factor_args.per_module_score == False
+    assert factor_args.per_module_score is False
 
     assert factor_args.query_gradient_rank is None
     assert factor_args.query_gradient_svd_dtype == torch.float64
 
     assert factor_args.score_dtype == torch.float32
-    assert factor_args.cached_activation_cpu_offload == False
+    assert factor_args.cached_activation_cpu_offload is False
     assert factor_args.per_sample_gradient_dtype == torch.float32
     assert factor_args.precondition_dtype == torch.float32