diff --git a/README.md b/README.md index 2c24267..7e7e913 100644 --- a/README.md +++ b/README.md @@ -19,38 +19,37 @@ arXiv

-Torch-Pruning (TP) is a versatile library for Structural Pruning with the following features: -* **General-purpose Pruning Toolkit:** TP enables structural pruning for a wide range of neural networks, including *[Large Language Models (LLMs)](https://github.com/horseee/LLM-Pruner), [Diffusion Models](https://github.com/VainF/Diff-Pruning), [Vision Transformers](benchmarks/prunability), [Yolov7](benchmarks/prunability/readme.md#3-yolo-v7), [yolov8](benchmarks/prunability/readme.md#2-yolo-v8), FasterRCNN, SSD, KeypointRCNN, MaskRCNN, ResNe(X)t, ConvNext, DenseNet, ConvNext, RegNet, FCN, DeepLab, etc*. Different from [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) that zeroizes parameters through masking, Torch-Pruning deploys a (non-deep) graph algorithm called **DepGraph** to remove parameters and channels physically. -* **Reproducible [Performance Benchmark](benchmarks) and [Prunability Benchmark](benchmarks/prunability):** Currently, TP is able to prune approximately **81/85=95.3%** of the models from Torchvision 0.13.1. Try this [Colab Demo](https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing) for quick start. -* **[Tutorials and Documents](https://github.com/VainF/Torch-Pruning/wiki) are available at the GitHub Wiki page**. - +Torch-Pruning (TP) is a library for structural pruning that enables the following features: + +* **General-purpose Pruning Toolkit:** TP enables structural pruning for a wide range of deep neural networks, including *[Large Language Models (LLMs)](https://github.com/horseee/LLM-Pruner), [Diffusion Models](https://github.com/VainF/Diff-Pruning), [Yolov7](examples/yolov7/), [yolov8](examples/yolov8/), [ViT](examples/torchvision_models/), FasterRCNN, SSD, ResNe(X)t, ConvNext, DenseNet, ConvNext, RegNet, DeepLab, etc*. Different from [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) that zeroizes parameters through masking, Torch-Pruning deploys a (non-deep) graph algorithm called **DepGraph** to remove parameters physically. Currently, TP is able to prune approximately **81/85=95.3%** of the models from Torchvision 0.13.1. Try this [Colab Demo](https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing) for a quick start. +* **[Performance Benchmark](benchmarks)**: Reproduce the our results in the DepGraph paper. +* **[Tutorials and Documents](https://github.com/VainF/Torch-Pruning/wiki)** are available at the GitHub Wiki. + For more technical details, please refer to our CVPR'23 paper: > [**DepGraph: Towards Any Structural Pruning**](https://openaccess.thecvf.com/content/CVPR2023/html/Fang_DepGraph_Towards_Any_Structural_Pruning_CVPR_2023_paper.html) > [Gongfan Fang](https://fangggf.github.io/), [Xinyin Ma](https://horseee.github.io/), [Mingli Song](https://person.zju.edu.cn/en/msong), [Michael Bi Mi](https://dblp.org/pid/317/0937.html), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/) -Please do not hesitate to open a [discussion](https://github.com/VainF/Torch-Pruning/discussions) or [issue](https://github.com/VainF/Torch-Pruning/issues) if you encounter any problems with the library or the paper. +Please do not hesitate to open a [discussion](https://github.com/VainF/Torch-Pruning/discussions) or [issue](https://github.com/VainF/Torch-Pruning/issues) if you encounter any problems with the library or the paper. ### Update: + * 2023.07.19 :rocket: Support LLaMA, LLaMA-2, Vicuna, Baichuan in [LLM-Pruner](https://github.com/horseee/LLM-Pruner) * 2023.05.20 :rocket: [**LLM-Pruner: On the Structural Pruning of Large Language Models**](https://github.com/horseee/LLM-Pruner) [*[arXiv]*](https://arxiv.org/abs/2305.11627) * 2023.05.19 [Structural Pruning for Diffusion Models](https://github.com/VainF/Diff-Pruning) [*[arXiv]*](https://arxiv.org/abs/2305.10924) * 2023.04.15 [Pruning and Post-training for YOLOv7 / YOLOv8](benchmarks/prunability) -* 2023.04.21 Join our Telegram or Wechat group for casual discussions: +* 2023.04.21 Join our Telegram or Wechat group: * Telegram: https://t.me/+NwjbBDN2ao1lZjZl * WeChat: image - - - ### **Features:** -- [x] Structural pruning for CNNs, Transformers, Detectors, Language Models and Diffusion Models. Please refer to the [Prunability Benchmark](benchmarks/prunability). +- [x] Structural pruning for CNNs, Transformers, Detectors, Language Models and Diffusion Models. Please refer to the [examples](examples). - [x] High-level pruners: [MagnitudePruner](https://arxiv.org/abs/1608.08710), [BNScalePruner](https://arxiv.org/abs/1708.06519), [GroupNormPruner](https://arxiv.org/abs/2301.12900), RandomPruner, etc. - [x] Importance Criteria: L-p Norm, Taylor, Random, BNScaling, etc. -- [x] Dependency Graph +- [x] Dependency Graph for automatic structrual pruning - [x] Supported modules: Linear, (Transposed) Conv, Normalization, PReLU, Embedding, MultiheadAttention, nn.Parameters and [customized modules](tests/test_customized_layer.py). - [x] Supported operators: split, concatenation, skip connection, flatten, reshape, view, all element-wise ops, etc. - [x] [Low-level pruning functions](torch_pruning/pruner/function.py) -- [x] [Benchmarks](benchmarks) and [tutorials](tutorials) +- [x] [Benchmarks](benchmarks) and [Tutorials](https://github.com/VainF/Torch-Pruning/wiki) ### **TODO List:** - [ ] A strong baseline with bags of tricks from existing methods. @@ -63,10 +62,10 @@ Please do not hesitate to open a [discussion](https://github.com/VainF/Torch-Pru ## Installation -Torch-Pruning is compatible with PyTorch 1.x and 2.x. **PyTorch 1.12.1 is recommended!** +Torch-Pruning is compatible with both PyTorch 1.x and 2.x versions. However, it is highly recommended to use PyTorch 1.12.1 or higher. ```bash -pip install torch-pruning # v1.1.9 +pip install torch-pruning ``` or ```bash @@ -75,11 +74,11 @@ git clone https://github.com/VainF/Torch-Pruning.git ## Quickstart -Here we provide a quick start for Torch-Pruning. More explained details can be found in [tutorals](https://github.com/VainF/Torch-Pruning/wiki) +Here we provide a quick start for Torch-Pruning. More explained details can be found in [Tutorals](https://github.com/VainF/Torch-Pruning/wiki) ### 0. How It Works -In structural pruning, a ``Group`` is defined as the minimal removable unit within deep networks. Each group consists of multiple interdependent layers that need to be pruned simultaneously in order to preserve the integrity of the resulting structures. However, deep networks often exhibit intricate dependencies among layers, posing a significant challenge for structural pruning. This work tackles this challenge by introducing an automated mechanism called ``DepGraph``, which enables effortless parameter grouping and facilitates pruning for a diverse range of deep networks. +In structural pruning, a "Group" is defined as the minimal unit that can be removed within deep networks. These groups are composed of multiple layers that are interdependent and need to be pruned together in order to maintain the integrity of the resulting structures. However, deep networks often have complex dependencies among their layers, making structural pruning a challenging task. This work addresses this challenge by introducing an automated mechanism called "DepGraph." DepGraph allows for seamless parameter grouping and facilitates pruning in various types of deep networks.
@@ -110,7 +109,7 @@ torch.save(model, 'model.pth') # without .state_dict model = torch.load('model.pth') # load the model object ``` -The above example demonstrates the fundamental pruning pipeline using DepGraph. The target layer resnet.conv1 is coupled with several layers, which requires simultaneous removal in structural pruning. Let's print the group and observe how a pruning operation "triggers" other ones. In the following outputs, ``A => B`` means the pruning operation ``A`` triggers the pruning operation ``B``. group[0] refers to the pruning root in ``DG.get_pruning_group``. +The above example demonstrates the fundamental pruning pipeline utilizing DepGraph. The target layer resnet.conv1 is coupled with several layers, necessitating their simultaneous removal during structural pruning. To observe the cascading effect of pruning operations, we can print the groups and observe how one pruning operation can "trigger" others. In the subsequent outputs, "A => B" indicates that pruning operation "A" triggers pruning operation "B." The group[0] refers to the pruning root in DG.get_pruning_group. ``` -------------------------------- @@ -134,7 +133,7 @@ The above example demonstrates the fundamental pruning pipeline using DepGraph. [15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9] -------------------------------- ``` -For more details about grouping, please refer to [tutorials/2 - Exploring Dependency Groups](https://github.com/VainF/Torch-Pruning/blob/master/tutorials/2%20-%20Exploring%20Dependency%20Groups.ipynb) +For more details about grouping, please refer to [Wiki - DepGraph & Group](https://github.com/VainF/Torch-Pruning/wiki/3.-DepGraph-&-Group) #### How to scan all groups (Advanced): We can use ``DG.get_all_groups(ignored_layers, root_module_types)`` to scan all groups sequentially. Each group will begin with a layer that matches a type in the "root_module_types" parameter. Note that DG.get_all_groups is only responsible for grouping and does not have any knowledge or understanding of which parameters should be pruned. Therefore, it is necessary to specify the pruning idxs using ``group.prune(idxs=idxs)``. @@ -149,7 +148,7 @@ for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[ ### 2. High-level Pruners -Leveraging the DependencyGraph, we developed several high-level pruners in this repository to facilitate effortless pruning. By specifying the desired channel sparsity, you can prune the entire model and fine-tune it using your own training code. For detailed information on this process, please refer to [this tutorial](https://github.com/VainF/Torch-Pruning/blob/master/tutorials/1%20-%20Customize%20Your%20Own%20Pruners.ipynb), which shows how to implement a [slimming](https://arxiv.org/abs/1708.06519) pruner from scratch. Additionally, you can find more practical examples in [benchmarks/main.py](benchmarks/main.py). +Leveraging the DependencyGraph, we developed several high-level pruners in this repository to facilitate effortless pruning. By specifying the desired channel sparsity, the pruner will scan all prunable groups, prune the entire model, and fine-tune it using your own training code. For detailed information on this process, please refer to [this tutorial](https://github.com/VainF/Torch-Pruning/blob/master/tutorials/1%20-%20Customize%20Your%20Own%20Pruners.ipynb), which shows how to implement a [slimming](https://arxiv.org/abs/1708.06519) pruner from scratch. Additionally, a more practical example is available in [benchmarks/main.py](benchmarks/main.py). ```python import torch @@ -179,10 +178,13 @@ pruner = tp.pruner.MagnitudePruner( base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) for i in range(iterative_steps): + + # Taylor expansion requires gradients for importance estimation if isinstance(imp, tp.importance.TaylorImportance): - # Taylor expansion requires gradients for importance estimation - loss = model(example_inputs).sum() # a dummy loss for TaylorImportance + # A dummy loss, please replace it with your loss function and data! + loss = model(example_inputs).sum() loss.backward() # before pruner.step() + pruner.step() macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) # finetune your model here @@ -191,7 +193,7 @@ for i in range(iterative_steps): ``` #### Sparse Training -Some pruners like [BNScalePruner](https://github.com/VainF/Torch-Pruning/blob/dd59921365d72acb2857d3d74f75c03e477060fb/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py#L45) and [GroupNormPruner](https://github.com/VainF/Torch-Pruning/blob/dd59921365d72acb2857d3d74f75c03e477060fb/torch_pruning/pruner/algorithms/group_norm_pruner.py#L53) require sparse training before pruning. This can be easily achieved by inserting just one line of code ``pruner.regularize(model)`` in your training script. The pruner will update the gradient of trainable parameters. +Some pruners like [BNScalePruner](https://github.com/VainF/Torch-Pruning/blob/dd59921365d72acb2857d3d74f75c03e477060fb/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py#L45) and [GroupNormPruner](https://github.com/VainF/Torch-Pruning/blob/dd59921365d72acb2857d3d74f75c03e477060fb/torch_pruning/pruner/algorithms/group_norm_pruner.py#L53) require sparse training before pruning. This can be easily achieved by inserting one line of code ``pruner.regularize(model)`` just after ``loss.backward()`` and before ``optimizer.step()``. The pruner will update the gradient of trainable parameters. ```python for epoch in range(epochs): model.train() @@ -200,13 +202,13 @@ for epoch in range(epochs): optimizer.zero_grad() out = model(data) loss = F.cross_entropy(out, target) - loss.backward() - pruner.regularize(model) # <== for sparse learning - optimizer.step() + loss.backward() # after loss.backward() + pruner.regularize(model) # <== for sparse training + optimizer.step() # before optimizer.step() ``` #### Interactive Pruning (Advanced) -All high-level pruners support interactive pruning. Use ``pruner.step(interactive=True)`` to get all groups and interactively prune them by calling ``group.prune()``. This feature is useful if you want to control/monitor the pruning process. +All high-level pruners offer support for interactive pruning. You can utilize the method "pruner.step(interactive=True)" to retrieve all the groups and interactively prune them by calling "group.prune()". This feature is particularly useful when you want to have control over or monitor the pruning process. ```python for i in range(iterative_steps): @@ -216,10 +218,7 @@ for i in range(iterative_steps): dep, idxs = group[0] # get the idxs target_module = dep.target.module # get the root module pruning_fn = dep.handler # get the pruning function - - # Don't forget to prune the group group.prune() - # group.prune(idxs=[0, 2, 6]) # It is even possible to change the pruning behaviour with the idxs parameter macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) # finetune your model here @@ -239,7 +238,7 @@ With DepGraph, it is easy to design some "group-level" criteria to estimate the The following script saves the whole model object (structure+weights) as a 'model.pth'. ```python -model.zero_grad() # We don't want to store gradient information +model.zero_grad() # Remove gradients torch.save(model, 'model.pth') # without .state_dict model = torch.load('model.pth') # load the pruned model ``` @@ -256,14 +255,13 @@ new_model = resnet18().eval() # load the pruned state_dict into the unpruned model. loaded_state_dict = torch.load('pruned.pth', map_location='cpu') tp.load_state_dict(new_model, state_dict=loaded_state_dict) -print(new_model) # This will be a pruned model. ``` Refer to [tests/test_serialization.py](tests/test_serialization.py) for an ViT example. In this example, we will prune the model and modify some attributes like ``model.hidden_dims``. ### 4. Low-level Pruning Functions -While it is possible to manually prune your model using low-level functions, this approach can be quite laborious, as it requires careful management of the associated dependencies. As a result, we recommend utilizing the aforementioned high-level pruners to streamline the pruning process. +Although it is possible to manually prune your model using low-level functions, this approach can be cumbersome and time-consuming due to the need for meticulous management of dependencies. Therefore, we strongly recommend utilizing the high-level pruners mentioned earlier to streamline and simplify the pruning process. These pruners provide a more convenient and efficient way to perform pruning on your models. To manually prune the ``model.conv1`` of a ResNet-18, the pruning pipeline should look like this: ```python tp.prune_conv_out_channels( model.conv1, idxs=[2,6,9] ) @@ -274,7 +272,7 @@ tp.prune_conv_in_channels( model.layer2[0].conv1, idxs=[2,6,9] ) ... ``` -The following pruning functions are available: +The following [pruning functions](torch_pruning/pruner/function.py) are available: ```python 'prune_conv_out_channels', 'prune_conv_in_channels', diff --git a/benchmarks/main.py b/benchmarks/main.py index f7711be..8aaf3c7 100644 --- a/benchmarks/main.py +++ b/benchmarks/main.py @@ -33,6 +33,7 @@ parser.add_argument("--max-sparsity", type=float, default=1.0) parser.add_argument("--soft-keeping-ratio", type=float, default=0.0) parser.add_argument("--reg", type=float, default=5e-4) +parser.add_argument("--delta_reg", type=float, default=1e-4, help='for growing regularization') parser.add_argument("--weight-decay", type=float, default=5e-4) parser.add_argument("--seed", type=int, default=None) @@ -109,6 +110,7 @@ def train_model( best_acc = -1 for epoch in range(epochs): model.train() + for i, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() @@ -129,6 +131,10 @@ def train_model( optimizer.param_groups[0]["lr"], ) ) + + if pruner is not None and isinstance(pruner, tp.pruner.GrowingRegPruner): + pruner.update_reg() # increase the strength of regularization + #print(pruner.group_reg[pruner._groups[0]]) model.eval() acc, val_loss = eval(model, test_loader, device=device) @@ -171,6 +177,10 @@ def get_pruner(model, example_inputs): args.sparsity_learning = True imp = tp.importance.BNScaleImportance() pruner_entry = partial(tp.pruner.BNScalePruner, reg=args.reg, global_pruning=args.global_pruning) + elif args.method == "group_slim": + args.sparsity_learning = True + imp = tp.importance.BNScaleImportance() + pruner_entry = partial(tp.pruner.BNScalePruner, reg=args.reg, global_pruning=args.global_pruning, group_lasso=True) elif args.method == "group_norm": imp = tp.importance.GroupNormImportance(p=2) pruner_entry = partial(tp.pruner.GroupNormPruner, global_pruning=args.global_pruning) @@ -178,6 +188,10 @@ def get_pruner(model, example_inputs): args.sparsity_learning = True imp = tp.importance.GroupNormImportance(p=2) pruner_entry = partial(tp.pruner.GroupNormPruner, reg=args.reg, global_pruning=args.global_pruning) + elif args.method == "growing_reg": + args.sparsity_learning = True + imp = tp.importance.GroupNormImportance(p=2) + pruner_entry = partial(tp.pruner.GrowingRegPruner, reg=args.reg, delta_reg=args.delta_reg, global_pruning=args.global_pruning) else: raise NotImplementedError diff --git a/benchmarks/prunability/coco_image.jpg b/benchmarks/prunability/coco_image.jpg deleted file mode 100644 index 02e67fd..0000000 Binary files a/benchmarks/prunability/coco_image.jpg and /dev/null differ diff --git a/benchmarks/prunability/readme.md b/benchmarks/prunability/readme.md deleted file mode 100644 index a1963be..0000000 --- a/benchmarks/prunability/readme.md +++ /dev/null @@ -1,348 +0,0 @@ -# Prunability - -- [Prunability](#prunability) - - [Torchvision](#1-torchvision) - - [YOLO-v8](#2-yolo-v8) - - [YOLO-v7](#3-yolo-v7) - - -## 0. Requirements - -```bash -pip install -r requirements.txt -``` -Tested environment: -``` -Pytorch==1.12.1 -Torchvision==0.13.1 -``` - - -## 1. Torchvision - -```python -python torchvision_pruning.py -``` - -#### Outputs: -``` -Successful Pruning: 81 Models - ['ssdlite320_mobilenet_v3_large', 'ssd300_vgg16', 'fasterrcnn_resnet50_fpn', 'fasterrcnn_resnet50_fpn_v2', 'fasterrcnn_mobilenet_v3_large_320_fpn', 'fasterrcnn_mobilenet_v3_large_fpn', 'fcos_resnet50_fpn', 'keypointrcnn_resnet50_fpn', 'maskrcnn_resnet50_fpn_v2', 'retinanet_resnet50_fpn_v2', 'alexnet', 'vit_b_16', 'vit_b_32', 'vit_l_16', 'vit_l_32', 'vit_h_14', 'convnext_tiny', 'convnext_small', 'convnext_base', 'convnext_large', 'densenet121', 'densenet169', 'densenet201', 'densenet161', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l', 'googlenet', 'inception_v3', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'regnet_y_400mf', 'regnet_y_800mf', 'regnet_y_1_6gf', 'regnet_y_3_2gf', 'regnet_y_8gf', 'regnet_y_16gf', 'regnet_y_32gf', 'regnet_y_128gf', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2', 'fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101', 'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'] - -``` - -``` -Unsuccessful Pruning: 4 Models - ['raft_large', 'swin_t', 'swin_s', 'swin_b'] -``` - -#### Vision Transfomer Example -``` -==============Before pruning================= -Model Name: vit_b_32 -VisionTransformer( - (conv_proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32)) - (encoder): Encoder( - (dropout): Dropout(p=0.0, inplace=False) - (layers): Sequential( - (encoder_layer_0): EncoderBlock( - (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True) - (self_attention): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True) - ) - (dropout): Dropout(p=0.0, inplace=False) - (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True) - (mlp): MLPBlock( - (0): Linear(in_features=768, out_features=3072, bias=True) - (1): GELU(approximate=none) - (2): Dropout(p=0.0, inplace=False) - (3): Linear(in_features=3072, out_features=768, bias=True) - (4): Dropout(p=0.0, inplace=False) - ) - ) -... - (encoder_layer_10): EncoderBlock( - (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True) - (self_attention): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True) - ) - (dropout): Dropout(p=0.0, inplace=False) - (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True) - (mlp): MLPBlock( - (0): Linear(in_features=768, out_features=3072, bias=True) - (1): GELU(approximate=none) - (2): Dropout(p=0.0, inplace=False) - (3): Linear(in_features=3072, out_features=768, bias=True) - (4): Dropout(p=0.0, inplace=False) - ) - ) - (encoder_layer_11): EncoderBlock( - (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True) - (self_attention): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True) - ) - (dropout): Dropout(p=0.0, inplace=False) - (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True) - (mlp): MLPBlock( - (0): Linear(in_features=768, out_features=3072, bias=True) - (1): GELU(approximate=none) - (2): Dropout(p=0.0, inplace=False) - (3): Linear(in_features=3072, out_features=768, bias=True) - (4): Dropout(p=0.0, inplace=False) - ) - ) - ) - (ln): LayerNorm((768,), eps=1e-06, elementwise_affine=True) - ) - (heads): Sequential( - (head): Linear(in_features=768, out_features=1000, bias=True) - ) -) -torch.Size([1, 1, 384]) torch.Size([1, 50, 384]) -==============After pruning================= -VisionTransformer( - (conv_proj): Conv2d(3, 384, kernel_size=(32, 32), stride=(32, 32)) - (encoder): Encoder( - (dropout): Dropout(p=0.0, inplace=False) - (layers): Sequential( - (encoder_layer_0): EncoderBlock( - (ln_1): LayerNorm((384,), eps=1e-06, elementwise_affine=True) - (self_attention): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True) - ) - (dropout): Dropout(p=0.0, inplace=False) - (ln_2): LayerNorm((384,), eps=1e-06, elementwise_affine=True) - (mlp): MLPBlock( - (0): Linear(in_features=384, out_features=1536, bias=True) - (1): GELU(approximate=none) - (2): Dropout(p=0.0, inplace=False) - (3): Linear(in_features=1536, out_features=384, bias=True) - (4): Dropout(p=0.0, inplace=False) - ) - ) -... - (encoder_layer_10): EncoderBlock( - (ln_1): LayerNorm((384,), eps=1e-06, elementwise_affine=True) - (self_attention): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True) - ) - (dropout): Dropout(p=0.0, inplace=False) - (ln_2): LayerNorm((384,), eps=1e-06, elementwise_affine=True) - (mlp): MLPBlock( - (0): Linear(in_features=384, out_features=1536, bias=True) - (1): GELU(approximate=none) - (2): Dropout(p=0.0, inplace=False) - (3): Linear(in_features=1536, out_features=384, bias=True) - (4): Dropout(p=0.0, inplace=False) - ) - ) - (encoder_layer_11): EncoderBlock( - (ln_1): LayerNorm((384,), eps=1e-06, elementwise_affine=True) - (self_attention): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True) - ) - (dropout): Dropout(p=0.0, inplace=False) - (ln_2): LayerNorm((384,), eps=1e-06, elementwise_affine=True) - (mlp): MLPBlock( - (0): Linear(in_features=384, out_features=1536, bias=True) - (1): GELU(approximate=none) - (2): Dropout(p=0.0, inplace=False) - (3): Linear(in_features=1536, out_features=384, bias=True) - (4): Dropout(p=0.0, inplace=False) - ) - ) - ) - (ln): LayerNorm((384,), eps=1e-06, elementwise_affine=True) - ) - (heads): Sequential( - (head): Linear(in_features=384, out_features=1000, bias=True) - ) -) -Pruning vit_b_32: - Params: 88224232 => 22878952 - Output: torch.Size([1, 1000]) ------------------------------------------------------- -``` - -## 2. YOLO v8 -This example was implemented by [@Hyunseok-Kim0 (Hyunseok Kim)](https://github.com/Hyunseok-Kim0). Please refer to Issue [#147](https://github.com/VainF/Torch-Pruning/issues/147#issuecomment-1507475657) for more details. - -#### Ultralytics -```bash -git clone https://github.com/ultralytics/ultralytics.git -cp yolov8_pruning.py ultralytics/ -cd ultralytics -git checkout 44c7c3514d87a5e05cfb14dba5a3eeb6eb860e70 # for compatibility -``` - -#### Modification -Some functions will be automatically modified by the yolov8_pruning.py to prevent performance loss during model saving. - -##### 1. ```train``` in class ```YOLO``` -This function creates new trainer when called. Trainer loads model based on config file and reassign it to current model, which should be avoided for pruning. - -##### 2. ```save_model``` in class ```BaseTrainer``` -YOLO v8 saves trained model with half precision. Due to this precision loss, saved model shows different performance with validation result during fine-tuning. -This is modified to save the model with full precision because changing model to half precision can be done easily whenever after the pruning. - -##### 3. ```final_eval``` in class ```BaseTrainer``` -YOLO v8 replaces saved checkpoint file to half precision after training is done using ```strip_optimizer```. Half precision saving is changed with same reason above. - -#### Training -``` -# This example will craft yolov8-half and fine-tune it on the coco128 toy set. -python yolov8_pruning.py -``` - -#### Screenshot for coco128 post-training: -image - - -#### Outputs of yolov8_pruning.py: -``` -DetectionModel( - (model): Sequential( - (0): Conv( - (conv): Conv2d(3, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) - (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) - (act): SiLU(inplace=True) - ) - (1): Conv( - (conv): Conv2d(80, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) - (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) - (act): SiLU(inplace=True) - ) -... - (2): Sequential( - (0): Conv( - (conv): Conv2d(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) - (act): SiLU(inplace=True) - ) - (1): Conv( - (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) - (act): SiLU(inplace=True) - ) - (2): Conv2d(320, 80, kernel_size=(1, 1), stride=(1, 1)) - ) - ) - (dfl): DFL( - (conv): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1), bias=False) - ) - ) - ) -) - - -DetectionModel( - (model): Sequential( - (0): Conv( - (conv): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) - (bn): BatchNorm2d(40, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) - (act): SiLU(inplace=True) - ) - (1): Conv( - (conv): Conv2d(40, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) - (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) - (act): SiLU(inplace=True) - ) -... - (2): Sequential( - (0): Conv( - (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) - (act): SiLU(inplace=True) - ) - (1): Conv( - (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) - (act): SiLU(inplace=True) - ) - (2): Conv2d(320, 80, kernel_size=(1, 1), stride=(1, 1)) - ) - ) - (dfl): DFL( - (conv): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1), bias=False) - ) - ) - ) -) -Before Pruning: MACs=129.092051 G, #Params=68.229648 M -After Pruning: MACs=41.741203 G, #Params=20.787528 M -``` - - - -## 3. YOLO v7 - -The following scripts (adapted from [yolov7/detect.py](https://github.com/WongKinYiu/yolov7/blob/main/detect.py) and [yolov7/train.py](https://github.com/WongKinYiu/yolov7/blob/main/train.py)) provide the basic examples of pruning YOLOv7. It is important to note that the training part has not been validated yet due to the time-consuming training process. - -Note: [yolov7_detect_pruned.py](https://github.com/VainF/Torch-Pruning/blob/master/benchmarks/prunability/yolov7_detect_pruned.py) does not include any code for fine-tuning. - -```bash -git clone https://github.com/WongKinYiu/yolov7.git -cp yolov7_detect_pruned.py yolov7/ -cp yolov7_train_pruned.py yolov7/ -cd yolov7 - -# Test only: We only prune and test the YOLOv7 model in this script. COCO dataset is not required. -python yolov7_detect_pruned.py --weights yolov7.pt --conf 0.25 --img-size 640 --source inference/images/horses.jpg - -# Training with pruned yolov7 (The training part is not validated) -# Please download the pretrained yolov7_training.pt from https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt. -python yolov7_train_pruned.py --workers 8 --device 0 --batch-size 1 --data data/coco.yaml --img 640 640 --cfg cfg/training/yolov7.yaml --weights 'yolov7_training.pt' --name yolov7 --hyp data/hyp.scratch.p5.yaml -``` - -#### Screenshot for yolov7_train_pruned.py: -![image](https://user-images.githubusercontent.com/18592211/232129303-18a61be1-b505-4950-b6a1-c60b4974291b.png) - - -#### Outputs of yolov7_detect_pruned.py: -``` -Model( - (model): Sequential( - (0): Conv( - (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (act): SiLU(inplace=True) - ) -... - (104): RepConv( - (act): SiLU(inplace=True) - (rbr_reparam): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - ) - (105): Detect( - (m): ModuleList( - (0): Conv2d(256, 255, kernel_size=(1, 1), stride=(1, 1)) - (1): Conv2d(512, 255, kernel_size=(1, 1), stride=(1, 1)) - (2): Conv2d(1024, 255, kernel_size=(1, 1), stride=(1, 1)) - ) - ) - ) -) - - -Model( - (model): Sequential( - (0): Conv( - (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (act): SiLU(inplace=True) - ) -... - (104): RepConv( - (act): SiLU(inplace=True) - (rbr_reparam): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - ) - (105): Detect( - (m): ModuleList( - (0): Conv2d(128, 255, kernel_size=(1, 1), stride=(1, 1)) - (1): Conv2d(256, 255, kernel_size=(1, 1), stride=(1, 1)) - (2): Conv2d(512, 255, kernel_size=(1, 1), stride=(1, 1)) - ) - ) - ) -) -Before Pruning: MACs=6.413721 G, #Params=0.036905 G -After Pruning: MACs=1.639895 G, #Params=0.009347 G -``` - diff --git a/benchmarks/prunability/requirements.txt b/benchmarks/prunability/requirements.txt deleted file mode 100644 index 7c4c85a..0000000 --- a/benchmarks/prunability/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -torchvision>=0.13.1 -torch>=1.12.1 diff --git a/benchmarks/run/cifar10/prune/cifar10-global-group_sl-resnet56/cifar10-global-group_sl-resnet56.txt b/benchmarks/run/cifar10/prune/cifar10-global-group_sl-resnet56/cifar10-global-group_sl-resnet56.txt index 98c7268..cf06774 100644 --- a/benchmarks/run/cifar10/prune/cifar10-global-group_sl-resnet56/cifar10-global-group_sl-resnet56.txt +++ b/benchmarks/run/cifar10/prune/cifar10-global-group_sl-resnet56/cifar10-global-group_sl-resnet56.txt @@ -446,4 +446,452 @@ [01/03 20:19:34] cifar10-global-group_sl-resnet56 INFO: Epoch 97/100, Acc=0.9382, Val Loss=0.2336, lr=0.0001 [01/03 20:19:52] cifar10-global-group_sl-resnet56 INFO: Epoch 98/100, Acc=0.9383, Val Loss=0.2345, lr=0.0001 [01/03 20:20:09] cifar10-global-group_sl-resnet56 INFO: Epoch 99/100, Acc=0.9374, Val Loss=0.2349, lr=0.0001 -[01/03 20:20:09] cifar10-global-group_sl-resnet56 INFO: Best Acc=0.9391 \ No newline at end of file +[01/03 20:20:09] cifar10-global-group_sl-resnet56 INFO: Best Acc=0.9391[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: mode: prune +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: model: resnet56 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: verbose: False +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: dataset: cifar10 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: batch_size: 128 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: total_epochs: 100 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: lr_decay_milestones: 60,80 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: lr_decay_gamma: 0.1 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: lr: 0.01 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: restore: cifar10_resnet56.pth +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: output_dir: run/cifar10/prune/cifar10-global-group_sl-resnet56 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: method: group_sl +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: speed_up: 2.55 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: max_sparsity: 1.0 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: soft_keeping_ratio: 0.0 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: reg: 0.0005 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: weight_decay: 0.0005 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: seed: None +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: global_pruning: True +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: sl_total_epochs: 100 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: sl_lr: 0.01 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: sl_lr_decay_milestones: 60,80 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: sl_reg_warmup: 0 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: sl_restore: None +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: iterative_steps: 400 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: logger: +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: device: cuda +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: num_classes: 10 +[07/17 14:51:46] cifar10-global-group_sl-resnet56 INFO: Loading model from cifar10_resnet56.pth +[07/17 14:51:49] cifar10-global-group_sl-resnet56 INFO: Regularizing... +[07/17 14:52:26] cifar10-global-group_sl-resnet56 INFO: Epoch 0/100, Acc=0.8762, Val Loss=0.4190, lr=0.0100 +[07/17 14:53:02] cifar10-global-group_sl-resnet56 INFO: Epoch 1/100, Acc=0.8704, Val Loss=0.4188, lr=0.0100 +[07/17 14:53:38] cifar10-global-group_sl-resnet56 INFO: Epoch 2/100, Acc=0.8835, Val Loss=0.3664, lr=0.0100 +[07/17 14:54:14] cifar10-global-group_sl-resnet56 INFO: Epoch 3/100, Acc=0.8716, Val Loss=0.4082, lr=0.0100 +[07/17 14:54:50] cifar10-global-group_sl-resnet56 INFO: Epoch 4/100, Acc=0.8548, Val Loss=0.4428, lr=0.0100 +[07/17 14:55:26] cifar10-global-group_sl-resnet56 INFO: Epoch 5/100, Acc=0.8513, Val Loss=0.4576, lr=0.0100 +[07/17 14:56:02] cifar10-global-group_sl-resnet56 INFO: Epoch 6/100, Acc=0.8464, Val Loss=0.4702, lr=0.0100 +[07/17 14:56:38] cifar10-global-group_sl-resnet56 INFO: Epoch 7/100, Acc=0.8538, Val Loss=0.4510, lr=0.0100 +[07/17 14:57:14] cifar10-global-group_sl-resnet56 INFO: Epoch 8/100, Acc=0.8386, Val Loss=0.5299, lr=0.0100 +[07/17 14:57:49] cifar10-global-group_sl-resnet56 INFO: Epoch 9/100, Acc=0.8556, Val Loss=0.4549, lr=0.0100 +[07/17 14:58:26] cifar10-global-group_sl-resnet56 INFO: Epoch 10/100, Acc=0.8567, Val Loss=0.4271, lr=0.0100 +[07/17 14:59:01] cifar10-global-group_sl-resnet56 INFO: Epoch 11/100, Acc=0.8336, Val Loss=0.5003, lr=0.0100 +[07/17 14:59:37] cifar10-global-group_sl-resnet56 INFO: Epoch 12/100, Acc=0.8751, Val Loss=0.3694, lr=0.0100 +[07/17 15:00:12] cifar10-global-group_sl-resnet56 INFO: Epoch 13/100, Acc=0.8700, Val Loss=0.3963, lr=0.0100 +[07/17 15:00:48] cifar10-global-group_sl-resnet56 INFO: Epoch 14/100, Acc=0.8402, Val Loss=0.4835, lr=0.0100 +[07/17 15:01:24] cifar10-global-group_sl-resnet56 INFO: Epoch 15/100, Acc=0.8475, Val Loss=0.4640, lr=0.0100 +[07/17 15:01:59] cifar10-global-group_sl-resnet56 INFO: Epoch 16/100, Acc=0.8683, Val Loss=0.4127, lr=0.0100 +[07/17 15:02:36] cifar10-global-group_sl-resnet56 INFO: Epoch 17/100, Acc=0.8796, Val Loss=0.3654, lr=0.0100 +[07/17 15:03:11] cifar10-global-group_sl-resnet56 INFO: Epoch 18/100, Acc=0.8727, Val Loss=0.3757, lr=0.0100 +[07/17 15:03:47] cifar10-global-group_sl-resnet56 INFO: Epoch 19/100, Acc=0.8561, Val Loss=0.4348, lr=0.0100 +[07/17 15:04:23] cifar10-global-group_sl-resnet56 INFO: Epoch 20/100, Acc=0.8643, Val Loss=0.4157, lr=0.0100 +[07/17 15:04:58] cifar10-global-group_sl-resnet56 INFO: Epoch 21/100, Acc=0.8666, Val Loss=0.4049, lr=0.0100 +[07/17 15:05:33] cifar10-global-group_sl-resnet56 INFO: Epoch 22/100, Acc=0.8583, Val Loss=0.4221, lr=0.0100 +[07/17 15:06:09] cifar10-global-group_sl-resnet56 INFO: Epoch 23/100, Acc=0.8517, Val Loss=0.4598, lr=0.0100 +[07/17 15:06:46] cifar10-global-group_sl-resnet56 INFO: Epoch 24/100, Acc=0.8643, Val Loss=0.4025, lr=0.0100 +[07/17 15:07:22] cifar10-global-group_sl-resnet56 INFO: Epoch 25/100, Acc=0.8472, Val Loss=0.4778, lr=0.0100 +[07/17 15:07:58] cifar10-global-group_sl-resnet56 INFO: Epoch 26/100, Acc=0.8486, Val Loss=0.4611, lr=0.0100 +[07/17 15:08:33] cifar10-global-group_sl-resnet56 INFO: Epoch 27/100, Acc=0.8485, Val Loss=0.4629, lr=0.0100 +[07/17 15:09:09] cifar10-global-group_sl-resnet56 INFO: Epoch 28/100, Acc=0.8365, Val Loss=0.4989, lr=0.0100 +[07/17 15:09:44] cifar10-global-group_sl-resnet56 INFO: Epoch 29/100, Acc=0.8366, Val Loss=0.4896, lr=0.0100 +[07/17 15:10:20] cifar10-global-group_sl-resnet56 INFO: Epoch 30/100, Acc=0.8617, Val Loss=0.4181, lr=0.0100 +[07/17 15:10:56] cifar10-global-group_sl-resnet56 INFO: Epoch 31/100, Acc=0.8042, Val Loss=0.6639, lr=0.0100 +[07/17 15:11:32] cifar10-global-group_sl-resnet56 INFO: Epoch 32/100, Acc=0.8545, Val Loss=0.4262, lr=0.0100 +[07/17 15:12:07] cifar10-global-group_sl-resnet56 INFO: Epoch 33/100, Acc=0.8639, Val Loss=0.4203, lr=0.0100 +[07/17 15:12:43] cifar10-global-group_sl-resnet56 INFO: Epoch 34/100, Acc=0.8665, Val Loss=0.3984, lr=0.0100 +[07/17 15:13:19] cifar10-global-group_sl-resnet56 INFO: Epoch 35/100, Acc=0.8584, Val Loss=0.4302, lr=0.0100 +[07/17 15:13:55] cifar10-global-group_sl-resnet56 INFO: Epoch 36/100, Acc=0.8241, Val Loss=0.5497, lr=0.0100 +[07/17 15:14:30] cifar10-global-group_sl-resnet56 INFO: Epoch 37/100, Acc=0.8534, Val Loss=0.4458, lr=0.0100 +[07/17 15:15:06] cifar10-global-group_sl-resnet56 INFO: Epoch 38/100, Acc=0.8505, Val Loss=0.4428, lr=0.0100 +[07/17 15:15:42] cifar10-global-group_sl-resnet56 INFO: Epoch 39/100, Acc=0.8383, Val Loss=0.4894, lr=0.0100 +[07/17 15:16:17] cifar10-global-group_sl-resnet56 INFO: Epoch 40/100, Acc=0.8602, Val Loss=0.4187, lr=0.0100 +[07/17 15:16:53] cifar10-global-group_sl-resnet56 INFO: Epoch 41/100, Acc=0.8571, Val Loss=0.4353, lr=0.0100 +[07/17 15:17:29] cifar10-global-group_sl-resnet56 INFO: Epoch 42/100, Acc=0.8365, Val Loss=0.5006, lr=0.0100 +[07/17 15:18:05] cifar10-global-group_sl-resnet56 INFO: Epoch 43/100, Acc=0.8498, Val Loss=0.4590, lr=0.0100 +[07/17 15:18:41] cifar10-global-group_sl-resnet56 INFO: Epoch 44/100, Acc=0.8182, Val Loss=0.5726, lr=0.0100 +[07/17 15:19:17] cifar10-global-group_sl-resnet56 INFO: Epoch 45/100, Acc=0.8340, Val Loss=0.5198, lr=0.0100 +[07/17 15:19:53] cifar10-global-group_sl-resnet56 INFO: Epoch 46/100, Acc=0.8564, Val Loss=0.4327, lr=0.0100 +[07/17 15:20:29] cifar10-global-group_sl-resnet56 INFO: Epoch 47/100, Acc=0.8584, Val Loss=0.4190, lr=0.0100 +[07/17 15:21:04] cifar10-global-group_sl-resnet56 INFO: Epoch 48/100, Acc=0.8601, Val Loss=0.4170, lr=0.0100 +[07/17 15:21:40] cifar10-global-group_sl-resnet56 INFO: Epoch 49/100, Acc=0.8104, Val Loss=0.5965, lr=0.0100 +[07/17 15:22:17] cifar10-global-group_sl-resnet56 INFO: Epoch 50/100, Acc=0.8348, Val Loss=0.5012, lr=0.0100 +[07/17 15:22:53] cifar10-global-group_sl-resnet56 INFO: Epoch 51/100, Acc=0.8462, Val Loss=0.4680, lr=0.0100 +[07/17 15:23:28] cifar10-global-group_sl-resnet56 INFO: Epoch 52/100, Acc=0.8319, Val Loss=0.5172, lr=0.0100 +[07/17 15:24:04] cifar10-global-group_sl-resnet56 INFO: Epoch 53/100, Acc=0.8321, Val Loss=0.5268, lr=0.0100 +[07/17 15:24:40] cifar10-global-group_sl-resnet56 INFO: Epoch 54/100, Acc=0.8442, Val Loss=0.4576, lr=0.0100 +[07/17 15:25:16] cifar10-global-group_sl-resnet56 INFO: Epoch 55/100, Acc=0.8006, Val Loss=0.6409, lr=0.0100 +[07/17 15:25:51] cifar10-global-group_sl-resnet56 INFO: Epoch 56/100, Acc=0.8573, Val Loss=0.4338, lr=0.0100 +[07/17 15:26:27] cifar10-global-group_sl-resnet56 INFO: Epoch 57/100, Acc=0.8555, Val Loss=0.4434, lr=0.0100 +[07/17 15:27:02] cifar10-global-group_sl-resnet56 INFO: Epoch 58/100, Acc=0.8638, Val Loss=0.4128, lr=0.0100 +[07/17 15:27:38] cifar10-global-group_sl-resnet56 INFO: Epoch 59/100, Acc=0.8381, Val Loss=0.4763, lr=0.0100 +[07/17 15:28:14] cifar10-global-group_sl-resnet56 INFO: Epoch 60/100, Acc=0.9211, Val Loss=0.2360, lr=0.0010 +[07/17 15:28:50] cifar10-global-group_sl-resnet56 INFO: Epoch 61/100, Acc=0.9239, Val Loss=0.2320, lr=0.0010 +[07/17 15:29:26] cifar10-global-group_sl-resnet56 INFO: Epoch 62/100, Acc=0.9258, Val Loss=0.2243, lr=0.0010 +[07/17 15:30:01] cifar10-global-group_sl-resnet56 INFO: Epoch 63/100, Acc=0.9290, Val Loss=0.2245, lr=0.0010 +[07/17 15:30:37] cifar10-global-group_sl-resnet56 INFO: Epoch 64/100, Acc=0.9270, Val Loss=0.2269, lr=0.0010 +[07/17 15:31:13] cifar10-global-group_sl-resnet56 INFO: Epoch 65/100, Acc=0.9297, Val Loss=0.2263, lr=0.0010 +[07/17 15:31:48] cifar10-global-group_sl-resnet56 INFO: Epoch 66/100, Acc=0.9282, Val Loss=0.2279, lr=0.0010 +[07/17 15:32:24] cifar10-global-group_sl-resnet56 INFO: Epoch 67/100, Acc=0.9281, Val Loss=0.2369, lr=0.0010 +[07/17 15:33:00] cifar10-global-group_sl-resnet56 INFO: Epoch 68/100, Acc=0.9276, Val Loss=0.2361, lr=0.0010 +[07/17 15:33:35] cifar10-global-group_sl-resnet56 INFO: Epoch 69/100, Acc=0.9264, Val Loss=0.2405, lr=0.0010 +[07/17 15:34:11] cifar10-global-group_sl-resnet56 INFO: Epoch 70/100, Acc=0.9227, Val Loss=0.2533, lr=0.0010 +[07/17 15:34:47] cifar10-global-group_sl-resnet56 INFO: Epoch 71/100, Acc=0.9211, Val Loss=0.2561, lr=0.0010 +[07/17 15:35:23] cifar10-global-group_sl-resnet56 INFO: Epoch 72/100, Acc=0.9254, Val Loss=0.2441, lr=0.0010 +[07/17 15:35:59] cifar10-global-group_sl-resnet56 INFO: Epoch 73/100, Acc=0.9272, Val Loss=0.2440, lr=0.0010 +[07/17 15:36:34] cifar10-global-group_sl-resnet56 INFO: Epoch 74/100, Acc=0.9230, Val Loss=0.2592, lr=0.0010 +[07/17 15:37:10] cifar10-global-group_sl-resnet56 INFO: Epoch 75/100, Acc=0.9238, Val Loss=0.2486, lr=0.0010 +[07/17 15:37:46] cifar10-global-group_sl-resnet56 INFO: Epoch 76/100, Acc=0.9270, Val Loss=0.2514, lr=0.0010 +[07/17 15:38:22] cifar10-global-group_sl-resnet56 INFO: Epoch 77/100, Acc=0.9211, Val Loss=0.2699, lr=0.0010 +[07/17 15:38:58] cifar10-global-group_sl-resnet56 INFO: Epoch 78/100, Acc=0.9267, Val Loss=0.2453, lr=0.0010 +[07/17 15:39:33] cifar10-global-group_sl-resnet56 INFO: Epoch 79/100, Acc=0.9250, Val Loss=0.2691, lr=0.0010 +[07/17 15:40:09] cifar10-global-group_sl-resnet56 INFO: Epoch 80/100, Acc=0.9337, Val Loss=0.2275, lr=0.0001 +[07/17 15:40:45] cifar10-global-group_sl-resnet56 INFO: Epoch 81/100, Acc=0.9344, Val Loss=0.2258, lr=0.0001 +[07/17 15:41:20] cifar10-global-group_sl-resnet56 INFO: Epoch 82/100, Acc=0.9345, Val Loss=0.2274, lr=0.0001 +[07/17 15:41:57] cifar10-global-group_sl-resnet56 INFO: Epoch 83/100, Acc=0.9326, Val Loss=0.2325, lr=0.0001 +[07/17 15:42:33] cifar10-global-group_sl-resnet56 INFO: Epoch 84/100, Acc=0.9350, Val Loss=0.2271, lr=0.0001 +[07/17 15:43:09] cifar10-global-group_sl-resnet56 INFO: Epoch 85/100, Acc=0.9347, Val Loss=0.2293, lr=0.0001 +[07/17 15:43:45] cifar10-global-group_sl-resnet56 INFO: Epoch 86/100, Acc=0.9360, Val Loss=0.2296, lr=0.0001 +[07/17 15:44:21] cifar10-global-group_sl-resnet56 INFO: Epoch 87/100, Acc=0.9355, Val Loss=0.2312, lr=0.0001 +[07/17 15:44:56] cifar10-global-group_sl-resnet56 INFO: Epoch 88/100, Acc=0.9347, Val Loss=0.2336, lr=0.0001 +[07/17 15:45:32] cifar10-global-group_sl-resnet56 INFO: Epoch 89/100, Acc=0.9336, Val Loss=0.2346, lr=0.0001 +[07/17 15:46:08] cifar10-global-group_sl-resnet56 INFO: Epoch 90/100, Acc=0.9347, Val Loss=0.2335, lr=0.0001 +[07/17 15:46:43] cifar10-global-group_sl-resnet56 INFO: Epoch 91/100, Acc=0.9349, Val Loss=0.2345, lr=0.0001 +[07/17 15:47:20] cifar10-global-group_sl-resnet56 INFO: Epoch 92/100, Acc=0.9339, Val Loss=0.2367, lr=0.0001 +[07/17 15:47:55] cifar10-global-group_sl-resnet56 INFO: Epoch 93/100, Acc=0.9351, Val Loss=0.2379, lr=0.0001 +[07/17 15:48:31] cifar10-global-group_sl-resnet56 INFO: Epoch 94/100, Acc=0.9348, Val Loss=0.2363, lr=0.0001 +[07/17 15:49:06] cifar10-global-group_sl-resnet56 INFO: Epoch 95/100, Acc=0.9353, Val Loss=0.2401, lr=0.0001 +[07/17 15:49:42] cifar10-global-group_sl-resnet56 INFO: Epoch 96/100, Acc=0.9345, Val Loss=0.2397, lr=0.0001 +[07/17 15:50:18] cifar10-global-group_sl-resnet56 INFO: Epoch 97/100, Acc=0.9346, Val Loss=0.2436, lr=0.0001 +[07/17 15:50:53] cifar10-global-group_sl-resnet56 INFO: Epoch 98/100, Acc=0.9345, Val Loss=0.2444, lr=0.0001 +[07/17 15:51:29] cifar10-global-group_sl-resnet56 INFO: Epoch 99/100, Acc=0.9332, Val Loss=0.2426, lr=0.0001 +[07/17 15:51:29] cifar10-global-group_sl-resnet56 INFO: Best Acc=0.9360 +[07/17 15:51:29] cifar10-global-group_sl-resnet56 INFO: Loading the sparse model from run/cifar10/prune/cifar10-global-group_sl-resnet56/reg_cifar10_resnet56_group_sl_0.0005.pth... +[07/17 15:51:30] cifar10-global-group_sl-resnet56 INFO: Pruning... +[07/17 15:51:40] cifar10-global-group_sl-resnet56 INFO: ResNet( + (conv1): Conv2d(3, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (layer1): Sequential( + (0): BasicBlock( + (conv1): Conv2d(12, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(9, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (1): BasicBlock( + (conv1): Conv2d(12, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(5, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (2): BasicBlock( + (conv1): Conv2d(12, 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(7, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (3): BasicBlock( + (conv1): Conv2d(12, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(8, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (4): BasicBlock( + (conv1): Conv2d(12, 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(7, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (5): BasicBlock( + (conv1): Conv2d(12, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(10, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (6): BasicBlock( + (conv1): Conv2d(12, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(4, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (7): BasicBlock( + (conv1): Conv2d(12, 11, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(11, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(11, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (8): BasicBlock( + (conv1): Conv2d(12, 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(7, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (layer2): Sequential( + (0): BasicBlock( + (conv1): Conv2d(12, 28, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(28, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (downsample): Sequential( + (0): Conv2d(12, 30, kernel_size=(1, 1), stride=(2, 2), bias=False) + (1): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (1): BasicBlock( + (conv1): Conv2d(30, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(9, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (2): BasicBlock( + (conv1): Conv2d(30, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(24, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (3): BasicBlock( + (conv1): Conv2d(30, 26, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(26, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(26, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (4): BasicBlock( + (conv1): Conv2d(30, 23, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(23, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(23, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (5): BasicBlock( + (conv1): Conv2d(30, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(5, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (6): BasicBlock( + (conv1): Conv2d(30, 11, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(11, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(11, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (7): BasicBlock( + (conv1): Conv2d(30, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(6, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (8): BasicBlock( + (conv1): Conv2d(30, 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(7, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (layer3): Sequential( + (0): BasicBlock( + (conv1): Conv2d(30, 61, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(61, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (downsample): Sequential( + (0): Conv2d(30, 30, kernel_size=(1, 1), stride=(2, 2), bias=False) + (1): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (1): BasicBlock( + (conv1): Conv2d(30, 54, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(54, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(54, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (2): BasicBlock( + (conv1): Conv2d(30, 53, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(53, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(53, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (3): BasicBlock( + (conv1): Conv2d(30, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(44, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (4): BasicBlock( + (conv1): Conv2d(30, 56, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(56, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(56, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (5): BasicBlock( + (conv1): Conv2d(30, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(44, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (6): BasicBlock( + (conv1): Conv2d(30, 34, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(34, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(34, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (7): BasicBlock( + (conv1): Conv2d(30, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(40, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (8): BasicBlock( + (conv1): Conv2d(30, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(51, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0) + (fc): Linear(in_features=30, out_features=10, bias=True) +) +[07/17 15:51:41] cifar10-global-group_sl-resnet56 INFO: Params: 0.86 M => 0.33 M (38.07%) +[07/17 15:51:41] cifar10-global-group_sl-resnet56 INFO: FLOPs: 127.12 M => 49.53 M (38.97%, 2.57X ) +[07/17 15:51:41] cifar10-global-group_sl-resnet56 INFO: Acc: 0.9360 => 0.8107 +[07/17 15:51:41] cifar10-global-group_sl-resnet56 INFO: Val Loss: 0.2296 => 0.7205 +[07/17 15:51:41] cifar10-global-group_sl-resnet56 INFO: Finetuning... +[07/17 15:51:57] cifar10-global-group_sl-resnet56 INFO: Epoch 0/100, Acc=0.8335, Val Loss=0.4876, lr=0.0100 +[07/17 15:52:13] cifar10-global-group_sl-resnet56 INFO: Epoch 1/100, Acc=0.8719, Val Loss=0.3838, lr=0.0100 +[07/17 15:52:31] cifar10-global-group_sl-resnet56 INFO: Epoch 2/100, Acc=0.8684, Val Loss=0.3967, lr=0.0100 +[07/17 15:52:47] cifar10-global-group_sl-resnet56 INFO: Epoch 3/100, Acc=0.8409, Val Loss=0.4876, lr=0.0100 +[07/17 15:53:03] cifar10-global-group_sl-resnet56 INFO: Epoch 4/100, Acc=0.8854, Val Loss=0.3419, lr=0.0100 +[07/17 15:53:20] cifar10-global-group_sl-resnet56 INFO: Epoch 5/100, Acc=0.8803, Val Loss=0.3719, lr=0.0100 +[07/17 15:53:37] cifar10-global-group_sl-resnet56 INFO: Epoch 6/100, Acc=0.8971, Val Loss=0.3194, lr=0.0100 +[07/17 15:53:53] cifar10-global-group_sl-resnet56 INFO: Epoch 7/100, Acc=0.8770, Val Loss=0.3896, lr=0.0100 +[07/17 15:54:10] cifar10-global-group_sl-resnet56 INFO: Epoch 8/100, Acc=0.8884, Val Loss=0.3338, lr=0.0100 +[07/17 15:54:27] cifar10-global-group_sl-resnet56 INFO: Epoch 9/100, Acc=0.8638, Val Loss=0.4334, lr=0.0100 +[07/17 15:54:43] cifar10-global-group_sl-resnet56 INFO: Epoch 10/100, Acc=0.8684, Val Loss=0.4269, lr=0.0100 +[07/17 15:54:59] cifar10-global-group_sl-resnet56 INFO: Epoch 11/100, Acc=0.8873, Val Loss=0.3557, lr=0.0100 +[07/17 15:55:15] cifar10-global-group_sl-resnet56 INFO: Epoch 12/100, Acc=0.8710, Val Loss=0.4085, lr=0.0100 +[07/17 15:55:32] cifar10-global-group_sl-resnet56 INFO: Epoch 13/100, Acc=0.8988, Val Loss=0.3199, lr=0.0100 +[07/17 15:55:48] cifar10-global-group_sl-resnet56 INFO: Epoch 14/100, Acc=0.9013, Val Loss=0.3043, lr=0.0100 +[07/17 15:56:04] cifar10-global-group_sl-resnet56 INFO: Epoch 15/100, Acc=0.9060, Val Loss=0.2905, lr=0.0100 +[07/17 15:56:20] cifar10-global-group_sl-resnet56 INFO: Epoch 16/100, Acc=0.8745, Val Loss=0.4022, lr=0.0100 +[07/17 15:56:38] cifar10-global-group_sl-resnet56 INFO: Epoch 17/100, Acc=0.8897, Val Loss=0.3483, lr=0.0100 +[07/17 15:56:55] cifar10-global-group_sl-resnet56 INFO: Epoch 18/100, Acc=0.8897, Val Loss=0.3474, lr=0.0100 +[07/17 15:57:11] cifar10-global-group_sl-resnet56 INFO: Epoch 19/100, Acc=0.9052, Val Loss=0.3155, lr=0.0100 +[07/17 15:57:27] cifar10-global-group_sl-resnet56 INFO: Epoch 20/100, Acc=0.8948, Val Loss=0.3271, lr=0.0100 +[07/17 15:57:44] cifar10-global-group_sl-resnet56 INFO: Epoch 21/100, Acc=0.9012, Val Loss=0.3230, lr=0.0100 +[07/17 15:58:01] cifar10-global-group_sl-resnet56 INFO: Epoch 22/100, Acc=0.8832, Val Loss=0.3981, lr=0.0100 +[07/17 15:58:19] cifar10-global-group_sl-resnet56 INFO: Epoch 23/100, Acc=0.8996, Val Loss=0.3197, lr=0.0100 +[07/17 15:58:35] cifar10-global-group_sl-resnet56 INFO: Epoch 24/100, Acc=0.8874, Val Loss=0.3838, lr=0.0100 +[07/17 15:58:52] cifar10-global-group_sl-resnet56 INFO: Epoch 25/100, Acc=0.8954, Val Loss=0.3563, lr=0.0100 +[07/17 15:59:09] cifar10-global-group_sl-resnet56 INFO: Epoch 26/100, Acc=0.8936, Val Loss=0.3487, lr=0.0100 +[07/17 15:59:27] cifar10-global-group_sl-resnet56 INFO: Epoch 27/100, Acc=0.8908, Val Loss=0.3564, lr=0.0100 +[07/17 15:59:44] cifar10-global-group_sl-resnet56 INFO: Epoch 28/100, Acc=0.9010, Val Loss=0.3239, lr=0.0100 +[07/17 16:00:00] cifar10-global-group_sl-resnet56 INFO: Epoch 29/100, Acc=0.8991, Val Loss=0.3185, lr=0.0100 +[07/17 16:00:17] cifar10-global-group_sl-resnet56 INFO: Epoch 30/100, Acc=0.9027, Val Loss=0.3278, lr=0.0100 +[07/17 16:00:35] cifar10-global-group_sl-resnet56 INFO: Epoch 31/100, Acc=0.8962, Val Loss=0.3357, lr=0.0100 +[07/17 16:00:53] cifar10-global-group_sl-resnet56 INFO: Epoch 32/100, Acc=0.9047, Val Loss=0.3065, lr=0.0100 +[07/17 16:01:09] cifar10-global-group_sl-resnet56 INFO: Epoch 33/100, Acc=0.8976, Val Loss=0.3330, lr=0.0100 +[07/17 16:01:28] cifar10-global-group_sl-resnet56 INFO: Epoch 34/100, Acc=0.8948, Val Loss=0.3482, lr=0.0100 +[07/17 16:01:46] cifar10-global-group_sl-resnet56 INFO: Epoch 35/100, Acc=0.8967, Val Loss=0.3458, lr=0.0100 +[07/17 16:02:02] cifar10-global-group_sl-resnet56 INFO: Epoch 36/100, Acc=0.8997, Val Loss=0.3245, lr=0.0100 +[07/17 16:02:20] cifar10-global-group_sl-resnet56 INFO: Epoch 37/100, Acc=0.8890, Val Loss=0.3667, lr=0.0100 +[07/17 16:02:37] cifar10-global-group_sl-resnet56 INFO: Epoch 38/100, Acc=0.9017, Val Loss=0.3153, lr=0.0100 +[07/17 16:02:56] cifar10-global-group_sl-resnet56 INFO: Epoch 39/100, Acc=0.9044, Val Loss=0.3126, lr=0.0100 +[07/17 16:03:13] cifar10-global-group_sl-resnet56 INFO: Epoch 40/100, Acc=0.9029, Val Loss=0.3211, lr=0.0100 +[07/17 16:03:30] cifar10-global-group_sl-resnet56 INFO: Epoch 41/100, Acc=0.9052, Val Loss=0.3116, lr=0.0100 +[07/17 16:03:49] cifar10-global-group_sl-resnet56 INFO: Epoch 42/100, Acc=0.9080, Val Loss=0.3052, lr=0.0100 +[07/17 16:04:08] cifar10-global-group_sl-resnet56 INFO: Epoch 43/100, Acc=0.9068, Val Loss=0.3206, lr=0.0100 +[07/17 16:04:25] cifar10-global-group_sl-resnet56 INFO: Epoch 44/100, Acc=0.8890, Val Loss=0.3719, lr=0.0100 +[07/17 16:04:43] cifar10-global-group_sl-resnet56 INFO: Epoch 45/100, Acc=0.8822, Val Loss=0.4080, lr=0.0100 +[07/17 16:05:00] cifar10-global-group_sl-resnet56 INFO: Epoch 46/100, Acc=0.9019, Val Loss=0.3346, lr=0.0100 +[07/17 16:05:18] cifar10-global-group_sl-resnet56 INFO: Epoch 47/100, Acc=0.8992, Val Loss=0.3328, lr=0.0100 +[07/17 16:05:36] cifar10-global-group_sl-resnet56 INFO: Epoch 48/100, Acc=0.9128, Val Loss=0.2845, lr=0.0100 +[07/17 16:05:55] cifar10-global-group_sl-resnet56 INFO: Epoch 49/100, Acc=0.8810, Val Loss=0.4018, lr=0.0100 +[07/17 16:06:12] cifar10-global-group_sl-resnet56 INFO: Epoch 50/100, Acc=0.9117, Val Loss=0.2994, lr=0.0100 +[07/17 16:06:28] cifar10-global-group_sl-resnet56 INFO: Epoch 51/100, Acc=0.9052, Val Loss=0.3200, lr=0.0100 +[07/17 16:06:46] cifar10-global-group_sl-resnet56 INFO: Epoch 52/100, Acc=0.9105, Val Loss=0.2748, lr=0.0100 +[07/17 16:07:04] cifar10-global-group_sl-resnet56 INFO: Epoch 53/100, Acc=0.8952, Val Loss=0.3515, lr=0.0100 +[07/17 16:07:21] cifar10-global-group_sl-resnet56 INFO: Epoch 54/100, Acc=0.9036, Val Loss=0.3138, lr=0.0100 +[07/17 16:07:37] cifar10-global-group_sl-resnet56 INFO: Epoch 55/100, Acc=0.9034, Val Loss=0.3311, lr=0.0100 +[07/17 16:07:54] cifar10-global-group_sl-resnet56 INFO: Epoch 56/100, Acc=0.8829, Val Loss=0.3994, lr=0.0100 +[07/17 16:08:10] cifar10-global-group_sl-resnet56 INFO: Epoch 57/100, Acc=0.8882, Val Loss=0.3725, lr=0.0100 +[07/17 16:08:27] cifar10-global-group_sl-resnet56 INFO: Epoch 58/100, Acc=0.8950, Val Loss=0.3515, lr=0.0100 +[07/17 16:08:43] cifar10-global-group_sl-resnet56 INFO: Epoch 59/100, Acc=0.8932, Val Loss=0.3444, lr=0.0100 +[07/17 16:09:01] cifar10-global-group_sl-resnet56 INFO: Epoch 60/100, Acc=0.9319, Val Loss=0.2224, lr=0.0010 +[07/17 16:09:19] cifar10-global-group_sl-resnet56 INFO: Epoch 61/100, Acc=0.9337, Val Loss=0.2224, lr=0.0010 +[07/17 16:09:36] cifar10-global-group_sl-resnet56 INFO: Epoch 62/100, Acc=0.9343, Val Loss=0.2194, lr=0.0010 +[07/17 16:09:52] cifar10-global-group_sl-resnet56 INFO: Epoch 63/100, Acc=0.9362, Val Loss=0.2214, lr=0.0010 +[07/17 16:10:08] cifar10-global-group_sl-resnet56 INFO: Epoch 64/100, Acc=0.9363, Val Loss=0.2228, lr=0.0010 +[07/17 16:10:24] cifar10-global-group_sl-resnet56 INFO: Epoch 65/100, Acc=0.9361, Val Loss=0.2229, lr=0.0010 +[07/17 16:10:41] cifar10-global-group_sl-resnet56 INFO: Epoch 66/100, Acc=0.9364, Val Loss=0.2286, lr=0.0010 +[07/17 16:10:57] cifar10-global-group_sl-resnet56 INFO: Epoch 67/100, Acc=0.9373, Val Loss=0.2312, lr=0.0010 +[07/17 16:11:14] cifar10-global-group_sl-resnet56 INFO: Epoch 68/100, Acc=0.9376, Val Loss=0.2299, lr=0.0010 +[07/17 16:11:30] cifar10-global-group_sl-resnet56 INFO: Epoch 69/100, Acc=0.9360, Val Loss=0.2298, lr=0.0010 +[07/17 16:11:46] cifar10-global-group_sl-resnet56 INFO: Epoch 70/100, Acc=0.9365, Val Loss=0.2310, lr=0.0010 +[07/17 16:12:03] cifar10-global-group_sl-resnet56 INFO: Epoch 71/100, Acc=0.9369, Val Loss=0.2341, lr=0.0010 +[07/17 16:12:19] cifar10-global-group_sl-resnet56 INFO: Epoch 72/100, Acc=0.9372, Val Loss=0.2372, lr=0.0010 +[07/17 16:12:35] cifar10-global-group_sl-resnet56 INFO: Epoch 73/100, Acc=0.9370, Val Loss=0.2380, lr=0.0010 +[07/17 16:12:52] cifar10-global-group_sl-resnet56 INFO: Epoch 74/100, Acc=0.9353, Val Loss=0.2424, lr=0.0010 +[07/17 16:13:08] cifar10-global-group_sl-resnet56 INFO: Epoch 75/100, Acc=0.9366, Val Loss=0.2443, lr=0.0010 +[07/17 16:13:24] cifar10-global-group_sl-resnet56 INFO: Epoch 76/100, Acc=0.9357, Val Loss=0.2475, lr=0.0010 +[07/17 16:13:41] cifar10-global-group_sl-resnet56 INFO: Epoch 77/100, Acc=0.9354, Val Loss=0.2482, lr=0.0010 +[07/17 16:13:57] cifar10-global-group_sl-resnet56 INFO: Epoch 78/100, Acc=0.9348, Val Loss=0.2509, lr=0.0010 +[07/17 16:14:13] cifar10-global-group_sl-resnet56 INFO: Epoch 79/100, Acc=0.9354, Val Loss=0.2491, lr=0.0010 +[07/17 16:14:30] cifar10-global-group_sl-resnet56 INFO: Epoch 80/100, Acc=0.9363, Val Loss=0.2453, lr=0.0001 +[07/17 16:14:47] cifar10-global-group_sl-resnet56 INFO: Epoch 81/100, Acc=0.9364, Val Loss=0.2476, lr=0.0001 +[07/17 16:15:04] cifar10-global-group_sl-resnet56 INFO: Epoch 82/100, Acc=0.9362, Val Loss=0.2464, lr=0.0001 +[07/17 16:15:22] cifar10-global-group_sl-resnet56 INFO: Epoch 83/100, Acc=0.9365, Val Loss=0.2448, lr=0.0001 +[07/17 16:15:38] cifar10-global-group_sl-resnet56 INFO: Epoch 84/100, Acc=0.9370, Val Loss=0.2450, lr=0.0001 +[07/17 16:15:55] cifar10-global-group_sl-resnet56 INFO: Epoch 85/100, Acc=0.9366, Val Loss=0.2472, lr=0.0001 +[07/17 16:16:11] cifar10-global-group_sl-resnet56 INFO: Epoch 86/100, Acc=0.9369, Val Loss=0.2455, lr=0.0001 +[07/17 16:16:27] cifar10-global-group_sl-resnet56 INFO: Epoch 87/100, Acc=0.9373, Val Loss=0.2471, lr=0.0001 +[07/17 16:16:43] cifar10-global-group_sl-resnet56 INFO: Epoch 88/100, Acc=0.9371, Val Loss=0.2459, lr=0.0001 +[07/17 16:17:00] cifar10-global-group_sl-resnet56 INFO: Epoch 89/100, Acc=0.9370, Val Loss=0.2457, lr=0.0001 +[07/17 16:17:16] cifar10-global-group_sl-resnet56 INFO: Epoch 90/100, Acc=0.9372, Val Loss=0.2446, lr=0.0001 +[07/17 16:17:32] cifar10-global-group_sl-resnet56 INFO: Epoch 91/100, Acc=0.9373, Val Loss=0.2452, lr=0.0001 +[07/17 16:17:49] cifar10-global-group_sl-resnet56 INFO: Epoch 92/100, Acc=0.9368, Val Loss=0.2473, lr=0.0001 +[07/17 16:18:07] cifar10-global-group_sl-resnet56 INFO: Epoch 93/100, Acc=0.9365, Val Loss=0.2460, lr=0.0001 +[07/17 16:18:24] cifar10-global-group_sl-resnet56 INFO: Epoch 94/100, Acc=0.9376, Val Loss=0.2458, lr=0.0001 +[07/17 16:18:40] cifar10-global-group_sl-resnet56 INFO: Epoch 95/100, Acc=0.9372, Val Loss=0.2484, lr=0.0001 +[07/17 16:18:56] cifar10-global-group_sl-resnet56 INFO: Epoch 96/100, Acc=0.9366, Val Loss=0.2473, lr=0.0001 +[07/17 16:19:13] cifar10-global-group_sl-resnet56 INFO: Epoch 97/100, Acc=0.9376, Val Loss=0.2475, lr=0.0001 +[07/17 16:19:30] cifar10-global-group_sl-resnet56 INFO: Epoch 98/100, Acc=0.9367, Val Loss=0.2481, lr=0.0001 +[07/17 16:19:46] cifar10-global-group_sl-resnet56 INFO: Epoch 99/100, Acc=0.9368, Val Loss=0.2480, lr=0.0001 +[07/17 16:19:46] cifar10-global-group_sl-resnet56 INFO: Best Acc=0.9376 diff --git a/benchmarks/run/cifar10/prune/cifar10-global-growing_reg-resnet56/cifar10-global-growing_reg-resnet56.txt b/benchmarks/run/cifar10/prune/cifar10-global-growing_reg-resnet56/cifar10-global-growing_reg-resnet56.txt new file mode 100644 index 0000000..e7b8b33 --- /dev/null +++ b/benchmarks/run/cifar10/prune/cifar10-global-growing_reg-resnet56/cifar10-global-growing_reg-resnet56.txt @@ -0,0 +1,480 @@ + +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: mode: prune +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: model: resnet56 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: verbose: False +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: dataset: cifar10 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: batch_size: 128 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: total_epochs: 100 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: lr_decay_milestones: 60,80 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: lr_decay_gamma: 0.1 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: lr: 0.01 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: restore: cifar10_resnet56.pth +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: output_dir: run/cifar10/prune/cifar10-global-growing_reg-resnet56 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: method: growing_reg +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: speed_up: 2.11 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: max_sparsity: 1.0 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: soft_keeping_ratio: 0.0 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: reg: 0.0001 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: weight_decay: 0.0005 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: seed: None +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: global_pruning: True +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: sl_total_epochs: 100 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: sl_lr: 0.01 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: sl_lr_decay_milestones: 60,80 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: sl_reg_warmup: 0 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: sl_restore: None +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: iterative_steps: 400 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: logger: +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: device: cuda +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: num_classes: 10 +[07/20 21:09:06] cifar10-global-growing_reg-resnet56 INFO: Loading model from cifar10_resnet56.pth +[07/20 21:09:09] cifar10-global-growing_reg-resnet56 INFO: Regularizing... +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: mode: prune +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: model: resnet56 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: verbose: False +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: dataset: cifar10 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: batch_size: 128 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: total_epochs: 100 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: lr_decay_milestones: 60,80 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: lr_decay_gamma: 0.1 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: lr: 0.01 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: restore: cifar10_resnet56.pth +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: output_dir: run/cifar10/prune/cifar10-global-growing_reg-resnet56 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: method: growing_reg +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: speed_up: 2.11 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: max_sparsity: 1.0 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: soft_keeping_ratio: 0.0 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: reg: 0.0001 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: weight_decay: 0.0005 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: seed: None +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: global_pruning: True +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: sl_total_epochs: 100 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: sl_lr: 0.01 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: sl_lr_decay_milestones: 60,80 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: sl_reg_warmup: 0 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: sl_restore: None +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: iterative_steps: 400 +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: logger: +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: device: cuda +[07/20 21:10:13] cifar10-global-growing_reg-resnet56 INFO: num_classes: 10 +[07/20 21:10:14] cifar10-global-growing_reg-resnet56 INFO: Loading model from cifar10_resnet56.pth +[07/20 21:10:17] cifar10-global-growing_reg-resnet56 INFO: Regularizing... +[07/20 21:10:53] cifar10-global-growing_reg-resnet56 INFO: Epoch 0/100, Acc=0.9010, Val Loss=0.3507, lr=0.0100 +[07/20 21:11:29] cifar10-global-growing_reg-resnet56 INFO: Epoch 1/100, Acc=0.9039, Val Loss=0.3278, lr=0.0100 +[07/20 21:12:04] cifar10-global-growing_reg-resnet56 INFO: Epoch 2/100, Acc=0.8999, Val Loss=0.3665, lr=0.0100 +[07/20 21:12:40] cifar10-global-growing_reg-resnet56 INFO: Epoch 3/100, Acc=0.9076, Val Loss=0.3132, lr=0.0100 +[07/20 21:13:16] cifar10-global-growing_reg-resnet56 INFO: Epoch 4/100, Acc=0.9113, Val Loss=0.3128, lr=0.0100 +[07/20 21:13:51] cifar10-global-growing_reg-resnet56 INFO: Epoch 5/100, Acc=0.9133, Val Loss=0.3209, lr=0.0100 +[07/20 21:14:27] cifar10-global-growing_reg-resnet56 INFO: Epoch 6/100, Acc=0.9123, Val Loss=0.3094, lr=0.0100 +[07/20 21:15:03] cifar10-global-growing_reg-resnet56 INFO: Epoch 7/100, Acc=0.9141, Val Loss=0.3073, lr=0.0100 +[07/20 21:15:39] cifar10-global-growing_reg-resnet56 INFO: Epoch 8/100, Acc=0.9118, Val Loss=0.3245, lr=0.0100 +[07/20 21:16:15] cifar10-global-growing_reg-resnet56 INFO: Epoch 9/100, Acc=0.9162, Val Loss=0.3063, lr=0.0100 +[07/20 21:16:51] cifar10-global-growing_reg-resnet56 INFO: Epoch 10/100, Acc=0.9062, Val Loss=0.3562, lr=0.0100 +[07/20 21:17:26] cifar10-global-growing_reg-resnet56 INFO: Epoch 11/100, Acc=0.9039, Val Loss=0.3561, lr=0.0100 +[07/20 21:18:02] cifar10-global-growing_reg-resnet56 INFO: Epoch 12/100, Acc=0.9044, Val Loss=0.3435, lr=0.0100 +[07/20 21:18:37] cifar10-global-growing_reg-resnet56 INFO: Epoch 13/100, Acc=0.9089, Val Loss=0.3410, lr=0.0100 +[07/20 21:19:13] cifar10-global-growing_reg-resnet56 INFO: Epoch 14/100, Acc=0.9157, Val Loss=0.3082, lr=0.0100 +[07/20 21:19:49] cifar10-global-growing_reg-resnet56 INFO: Epoch 15/100, Acc=0.9140, Val Loss=0.3172, lr=0.0100 +[07/20 21:20:24] cifar10-global-growing_reg-resnet56 INFO: Epoch 16/100, Acc=0.9170, Val Loss=0.3085, lr=0.0100 +[07/20 21:21:00] cifar10-global-growing_reg-resnet56 INFO: Epoch 17/100, Acc=0.9102, Val Loss=0.3544, lr=0.0100 +[07/20 21:21:35] cifar10-global-growing_reg-resnet56 INFO: Epoch 18/100, Acc=0.9189, Val Loss=0.3087, lr=0.0100 +[07/20 21:22:11] cifar10-global-growing_reg-resnet56 INFO: Epoch 19/100, Acc=0.9144, Val Loss=0.3258, lr=0.0100 +[07/20 21:22:46] cifar10-global-growing_reg-resnet56 INFO: Epoch 20/100, Acc=0.9185, Val Loss=0.3099, lr=0.0100 +[07/20 21:23:21] cifar10-global-growing_reg-resnet56 INFO: Epoch 21/100, Acc=0.9083, Val Loss=0.3292, lr=0.0100 +[07/20 21:23:57] cifar10-global-growing_reg-resnet56 INFO: Epoch 22/100, Acc=0.9199, Val Loss=0.3005, lr=0.0100 +[07/20 21:24:33] cifar10-global-growing_reg-resnet56 INFO: Epoch 23/100, Acc=0.9140, Val Loss=0.2940, lr=0.0100 +[07/20 21:25:09] cifar10-global-growing_reg-resnet56 INFO: Epoch 24/100, Acc=0.9101, Val Loss=0.3384, lr=0.0100 +[07/20 21:25:45] cifar10-global-growing_reg-resnet56 INFO: Epoch 25/100, Acc=0.9116, Val Loss=0.3313, lr=0.0100 +[07/20 21:26:20] cifar10-global-growing_reg-resnet56 INFO: Epoch 26/100, Acc=0.9072, Val Loss=0.3519, lr=0.0100 +[07/20 21:26:55] cifar10-global-growing_reg-resnet56 INFO: Epoch 27/100, Acc=0.9132, Val Loss=0.3279, lr=0.0100 +[07/20 21:27:31] cifar10-global-growing_reg-resnet56 INFO: Epoch 28/100, Acc=0.9197, Val Loss=0.2922, lr=0.0100 +[07/20 21:28:07] cifar10-global-growing_reg-resnet56 INFO: Epoch 29/100, Acc=0.9116, Val Loss=0.3354, lr=0.0100 +[07/20 21:28:42] cifar10-global-growing_reg-resnet56 INFO: Epoch 30/100, Acc=0.9092, Val Loss=0.3516, lr=0.0100 +[07/20 21:29:18] cifar10-global-growing_reg-resnet56 INFO: Epoch 31/100, Acc=0.9103, Val Loss=0.3241, lr=0.0100 +[07/20 21:29:54] cifar10-global-growing_reg-resnet56 INFO: Epoch 32/100, Acc=0.8996, Val Loss=0.3969, lr=0.0100 +[07/20 21:30:29] cifar10-global-growing_reg-resnet56 INFO: Epoch 33/100, Acc=0.9136, Val Loss=0.3408, lr=0.0100 +[07/20 21:31:05] cifar10-global-growing_reg-resnet56 INFO: Epoch 34/100, Acc=0.9086, Val Loss=0.3312, lr=0.0100 +[07/20 21:31:41] cifar10-global-growing_reg-resnet56 INFO: Epoch 35/100, Acc=0.9108, Val Loss=0.3159, lr=0.0100 +[07/20 21:32:17] cifar10-global-growing_reg-resnet56 INFO: Epoch 36/100, Acc=0.9078, Val Loss=0.3389, lr=0.0100 +[07/20 21:32:52] cifar10-global-growing_reg-resnet56 INFO: Epoch 37/100, Acc=0.9117, Val Loss=0.3345, lr=0.0100 +[07/20 21:33:28] cifar10-global-growing_reg-resnet56 INFO: Epoch 38/100, Acc=0.9099, Val Loss=0.3437, lr=0.0100 +[07/20 21:34:03] cifar10-global-growing_reg-resnet56 INFO: Epoch 39/100, Acc=0.9146, Val Loss=0.3077, lr=0.0100 +[07/20 21:34:39] cifar10-global-growing_reg-resnet56 INFO: Epoch 40/100, Acc=0.9145, Val Loss=0.3083, lr=0.0100 +[07/20 21:35:15] cifar10-global-growing_reg-resnet56 INFO: Epoch 41/100, Acc=0.9034, Val Loss=0.3646, lr=0.0100 +[07/20 21:35:51] cifar10-global-growing_reg-resnet56 INFO: Epoch 42/100, Acc=0.9017, Val Loss=0.3844, lr=0.0100 +[07/20 21:36:27] cifar10-global-growing_reg-resnet56 INFO: Epoch 43/100, Acc=0.8897, Val Loss=0.4140, lr=0.0100 +[07/20 21:37:02] cifar10-global-growing_reg-resnet56 INFO: Epoch 44/100, Acc=0.9150, Val Loss=0.3001, lr=0.0100 +[07/20 21:37:38] cifar10-global-growing_reg-resnet56 INFO: Epoch 45/100, Acc=0.9061, Val Loss=0.3456, lr=0.0100 +[07/20 21:38:14] cifar10-global-growing_reg-resnet56 INFO: Epoch 46/100, Acc=0.9114, Val Loss=0.3167, lr=0.0100 +[07/20 21:38:49] cifar10-global-growing_reg-resnet56 INFO: Epoch 47/100, Acc=0.9116, Val Loss=0.3295, lr=0.0100 +[07/20 21:39:24] cifar10-global-growing_reg-resnet56 INFO: Epoch 48/100, Acc=0.8981, Val Loss=0.3903, lr=0.0100 +[07/20 21:40:00] cifar10-global-growing_reg-resnet56 INFO: Epoch 49/100, Acc=0.9068, Val Loss=0.3436, lr=0.0100 +[07/20 21:40:35] cifar10-global-growing_reg-resnet56 INFO: Epoch 50/100, Acc=0.9032, Val Loss=0.3658, lr=0.0100 +[07/20 21:41:11] cifar10-global-growing_reg-resnet56 INFO: Epoch 51/100, Acc=0.9038, Val Loss=0.3526, lr=0.0100 +[07/20 21:41:47] cifar10-global-growing_reg-resnet56 INFO: Epoch 52/100, Acc=0.9050, Val Loss=0.3494, lr=0.0100 +[07/20 21:42:22] cifar10-global-growing_reg-resnet56 INFO: Epoch 53/100, Acc=0.9040, Val Loss=0.3419, lr=0.0100 +[07/20 21:42:58] cifar10-global-growing_reg-resnet56 INFO: Epoch 54/100, Acc=0.8939, Val Loss=0.4083, lr=0.0100 +[07/20 21:43:33] cifar10-global-growing_reg-resnet56 INFO: Epoch 55/100, Acc=0.8925, Val Loss=0.3983, lr=0.0100 +[07/20 21:44:09] cifar10-global-growing_reg-resnet56 INFO: Epoch 56/100, Acc=0.9048, Val Loss=0.3488, lr=0.0100 +[07/20 21:44:44] cifar10-global-growing_reg-resnet56 INFO: Epoch 57/100, Acc=0.9016, Val Loss=0.3640, lr=0.0100 +[07/20 21:45:19] cifar10-global-growing_reg-resnet56 INFO: Epoch 58/100, Acc=0.9033, Val Loss=0.3388, lr=0.0100 +[07/20 21:45:54] cifar10-global-growing_reg-resnet56 INFO: Epoch 59/100, Acc=0.9051, Val Loss=0.3489, lr=0.0100 +[07/20 21:46:30] cifar10-global-growing_reg-resnet56 INFO: Epoch 60/100, Acc=0.9319, Val Loss=0.2429, lr=0.0010 +[07/20 21:47:06] cifar10-global-growing_reg-resnet56 INFO: Epoch 61/100, Acc=0.9327, Val Loss=0.2444, lr=0.0010 +[07/20 21:47:42] cifar10-global-growing_reg-resnet56 INFO: Epoch 62/100, Acc=0.9320, Val Loss=0.2437, lr=0.0010 +[07/20 21:48:17] cifar10-global-growing_reg-resnet56 INFO: Epoch 63/100, Acc=0.9331, Val Loss=0.2454, lr=0.0010 +[07/20 21:48:52] cifar10-global-growing_reg-resnet56 INFO: Epoch 64/100, Acc=0.9325, Val Loss=0.2469, lr=0.0010 +[07/20 21:49:28] cifar10-global-growing_reg-resnet56 INFO: Epoch 65/100, Acc=0.9337, Val Loss=0.2493, lr=0.0010 +[07/20 21:50:04] cifar10-global-growing_reg-resnet56 INFO: Epoch 66/100, Acc=0.9334, Val Loss=0.2554, lr=0.0010 +[07/20 21:50:40] cifar10-global-growing_reg-resnet56 INFO: Epoch 67/100, Acc=0.9339, Val Loss=0.2518, lr=0.0010 +[07/20 21:51:16] cifar10-global-growing_reg-resnet56 INFO: Epoch 68/100, Acc=0.9342, Val Loss=0.2507, lr=0.0010 +[07/20 21:51:52] cifar10-global-growing_reg-resnet56 INFO: Epoch 69/100, Acc=0.9340, Val Loss=0.2549, lr=0.0010 +[07/20 21:52:28] cifar10-global-growing_reg-resnet56 INFO: Epoch 70/100, Acc=0.9345, Val Loss=0.2563, lr=0.0010 +[07/20 21:53:03] cifar10-global-growing_reg-resnet56 INFO: Epoch 71/100, Acc=0.9349, Val Loss=0.2573, lr=0.0010 +[07/20 21:53:39] cifar10-global-growing_reg-resnet56 INFO: Epoch 72/100, Acc=0.9330, Val Loss=0.2588, lr=0.0010 +[07/20 21:54:14] cifar10-global-growing_reg-resnet56 INFO: Epoch 73/100, Acc=0.9350, Val Loss=0.2572, lr=0.0010 +[07/20 21:54:50] cifar10-global-growing_reg-resnet56 INFO: Epoch 74/100, Acc=0.9353, Val Loss=0.2575, lr=0.0010 +[07/20 21:55:26] cifar10-global-growing_reg-resnet56 INFO: Epoch 75/100, Acc=0.9353, Val Loss=0.2614, lr=0.0010 +[07/20 21:56:01] cifar10-global-growing_reg-resnet56 INFO: Epoch 76/100, Acc=0.9358, Val Loss=0.2605, lr=0.0010 +[07/20 21:56:37] cifar10-global-growing_reg-resnet56 INFO: Epoch 77/100, Acc=0.9345, Val Loss=0.2613, lr=0.0010 +[07/20 21:57:14] cifar10-global-growing_reg-resnet56 INFO: Epoch 78/100, Acc=0.9345, Val Loss=0.2645, lr=0.0010 +[07/20 21:57:49] cifar10-global-growing_reg-resnet56 INFO: Epoch 79/100, Acc=0.9343, Val Loss=0.2672, lr=0.0010 +[07/20 21:58:25] cifar10-global-growing_reg-resnet56 INFO: Epoch 80/100, Acc=0.9362, Val Loss=0.2660, lr=0.0001 +[07/20 21:59:00] cifar10-global-growing_reg-resnet56 INFO: Epoch 81/100, Acc=0.9356, Val Loss=0.2629, lr=0.0001 +[07/20 21:59:35] cifar10-global-growing_reg-resnet56 INFO: Epoch 82/100, Acc=0.9352, Val Loss=0.2645, lr=0.0001 +[07/20 22:00:10] cifar10-global-growing_reg-resnet56 INFO: Epoch 83/100, Acc=0.9344, Val Loss=0.2652, lr=0.0001 +[07/20 22:00:46] cifar10-global-growing_reg-resnet56 INFO: Epoch 84/100, Acc=0.9349, Val Loss=0.2621, lr=0.0001 +[07/20 22:01:22] cifar10-global-growing_reg-resnet56 INFO: Epoch 85/100, Acc=0.9351, Val Loss=0.2631, lr=0.0001 +[07/20 22:01:57] cifar10-global-growing_reg-resnet56 INFO: Epoch 86/100, Acc=0.9355, Val Loss=0.2613, lr=0.0001 +[07/20 22:02:33] cifar10-global-growing_reg-resnet56 INFO: Epoch 87/100, Acc=0.9355, Val Loss=0.2625, lr=0.0001 +[07/20 22:03:08] cifar10-global-growing_reg-resnet56 INFO: Epoch 88/100, Acc=0.9356, Val Loss=0.2623, lr=0.0001 +[07/20 22:03:43] cifar10-global-growing_reg-resnet56 INFO: Epoch 89/100, Acc=0.9353, Val Loss=0.2616, lr=0.0001 +[07/20 22:04:18] cifar10-global-growing_reg-resnet56 INFO: Epoch 90/100, Acc=0.9357, Val Loss=0.2628, lr=0.0001 +[07/20 22:04:54] cifar10-global-growing_reg-resnet56 INFO: Epoch 91/100, Acc=0.9353, Val Loss=0.2622, lr=0.0001 +[07/20 22:05:29] cifar10-global-growing_reg-resnet56 INFO: Epoch 92/100, Acc=0.9352, Val Loss=0.2619, lr=0.0001 +[07/20 22:06:05] cifar10-global-growing_reg-resnet56 INFO: Epoch 93/100, Acc=0.9366, Val Loss=0.2623, lr=0.0001 +[07/20 22:06:40] cifar10-global-growing_reg-resnet56 INFO: Epoch 94/100, Acc=0.9365, Val Loss=0.2627, lr=0.0001 +[07/20 22:07:16] cifar10-global-growing_reg-resnet56 INFO: Epoch 95/100, Acc=0.9364, Val Loss=0.2622, lr=0.0001 +[07/20 22:07:51] cifar10-global-growing_reg-resnet56 INFO: Epoch 96/100, Acc=0.9367, Val Loss=0.2625, lr=0.0001 +[07/20 22:08:27] cifar10-global-growing_reg-resnet56 INFO: Epoch 97/100, Acc=0.9362, Val Loss=0.2623, lr=0.0001 +[07/20 22:09:02] cifar10-global-growing_reg-resnet56 INFO: Epoch 98/100, Acc=0.9360, Val Loss=0.2623, lr=0.0001 +[07/20 22:09:38] cifar10-global-growing_reg-resnet56 INFO: Epoch 99/100, Acc=0.9366, Val Loss=0.2626, lr=0.0001 +[07/20 22:09:38] cifar10-global-growing_reg-resnet56 INFO: Best Acc=0.9367 +[07/20 22:09:38] cifar10-global-growing_reg-resnet56 INFO: Loading the sparse model from run/cifar10/prune/cifar10-global-growing_reg-resnet56/reg_cifar10_resnet56_growing_reg_0.0001.pth... +[07/20 22:09:39] cifar10-global-growing_reg-resnet56 INFO: Pruning... +[07/20 22:09:46] cifar10-global-growing_reg-resnet56 INFO: ResNet( + (conv1): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (layer1): Sequential( + (0): BasicBlock( + (conv1): Conv2d(5, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(8, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (1): BasicBlock( + (conv1): Conv2d(5, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(8, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (2): BasicBlock( + (conv1): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (3): BasicBlock( + (conv1): Conv2d(5, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(9, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (4): BasicBlock( + (conv1): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (5): BasicBlock( + (conv1): Conv2d(5, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(12, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (6): BasicBlock( + (conv1): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (7): BasicBlock( + (conv1): Conv2d(5, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(8, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (8): BasicBlock( + (conv1): Conv2d(5, 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(7, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (layer2): Sequential( + (0): BasicBlock( + (conv1): Conv2d(5, 29, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(29, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(29, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (downsample): Sequential( + (0): Conv2d(5, 31, kernel_size=(1, 1), stride=(2, 2), bias=False) + (1): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (1): BasicBlock( + (conv1): Conv2d(31, 13, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(13, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(13, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (2): BasicBlock( + (conv1): Conv2d(31, 26, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(26, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(26, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (3): BasicBlock( + (conv1): Conv2d(31, 26, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(26, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(26, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (4): BasicBlock( + (conv1): Conv2d(31, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(24, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (5): BasicBlock( + (conv1): Conv2d(31, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(3, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (6): BasicBlock( + (conv1): Conv2d(31, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(12, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (7): BasicBlock( + (conv1): Conv2d(31, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(5, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (8): BasicBlock( + (conv1): Conv2d(31, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(4, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (layer3): Sequential( + (0): BasicBlock( + (conv1): Conv2d(31, 63, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(63, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(63, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (downsample): Sequential( + (0): Conv2d(31, 58, kernel_size=(1, 1), stride=(2, 2), bias=False) + (1): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (1): BasicBlock( + (conv1): Conv2d(58, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(62, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(62, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (2): BasicBlock( + (conv1): Conv2d(58, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(58, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (3): BasicBlock( + (conv1): Conv2d(58, 53, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(53, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(53, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (4): BasicBlock( + (conv1): Conv2d(58, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(62, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(62, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (5): BasicBlock( + (conv1): Conv2d(58, 60, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(60, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(60, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (6): BasicBlock( + (conv1): Conv2d(58, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(57, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (7): BasicBlock( + (conv1): Conv2d(58, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(51, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (8): BasicBlock( + (conv1): Conv2d(58, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(62, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(62, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0) + (fc): Linear(in_features=58, out_features=10, bias=True) +) +[07/20 22:09:47] cifar10-global-growing_reg-resnet56 INFO: Params: 0.86 M => 0.62 M (72.50%) +[07/20 22:09:47] cifar10-global-growing_reg-resnet56 INFO: FLOPs: 127.12 M => 60.08 M (47.26%, 2.12X ) +[07/20 22:09:47] cifar10-global-growing_reg-resnet56 INFO: Acc: 0.9367 => 0.1056 +[07/20 22:09:47] cifar10-global-growing_reg-resnet56 INFO: Val Loss: 0.2625 => 14.5350 +[07/20 22:09:47] cifar10-global-growing_reg-resnet56 INFO: Finetuning... +[07/20 22:10:03] cifar10-global-growing_reg-resnet56 INFO: Epoch 0/100, Acc=0.8464, Val Loss=0.4670, lr=0.0100 +[07/20 22:10:19] cifar10-global-growing_reg-resnet56 INFO: Epoch 1/100, Acc=0.8504, Val Loss=0.4879, lr=0.0100 +[07/20 22:10:36] cifar10-global-growing_reg-resnet56 INFO: Epoch 2/100, Acc=0.8790, Val Loss=0.3792, lr=0.0100 +[07/20 22:10:52] cifar10-global-growing_reg-resnet56 INFO: Epoch 3/100, Acc=0.8671, Val Loss=0.4206, lr=0.0100 +[07/20 22:11:08] cifar10-global-growing_reg-resnet56 INFO: Epoch 4/100, Acc=0.8834, Val Loss=0.3751, lr=0.0100 +[07/20 22:11:25] cifar10-global-growing_reg-resnet56 INFO: Epoch 5/100, Acc=0.8817, Val Loss=0.3690, lr=0.0100 +[07/20 22:11:41] cifar10-global-growing_reg-resnet56 INFO: Epoch 6/100, Acc=0.8847, Val Loss=0.3741, lr=0.0100 +[07/20 22:11:59] cifar10-global-growing_reg-resnet56 INFO: Epoch 7/100, Acc=0.8921, Val Loss=0.3625, lr=0.0100 +[07/20 22:12:16] cifar10-global-growing_reg-resnet56 INFO: Epoch 8/100, Acc=0.8901, Val Loss=0.3556, lr=0.0100 +[07/20 22:12:32] cifar10-global-growing_reg-resnet56 INFO: Epoch 9/100, Acc=0.8910, Val Loss=0.3616, lr=0.0100 +[07/20 22:12:49] cifar10-global-growing_reg-resnet56 INFO: Epoch 10/100, Acc=0.8913, Val Loss=0.3662, lr=0.0100 +[07/20 22:13:07] cifar10-global-growing_reg-resnet56 INFO: Epoch 11/100, Acc=0.8976, Val Loss=0.3364, lr=0.0100 +[07/20 22:13:23] cifar10-global-growing_reg-resnet56 INFO: Epoch 12/100, Acc=0.8809, Val Loss=0.3922, lr=0.0100 +[07/20 22:13:40] cifar10-global-growing_reg-resnet56 INFO: Epoch 13/100, Acc=0.8966, Val Loss=0.3360, lr=0.0100 +[07/20 22:13:57] cifar10-global-growing_reg-resnet56 INFO: Epoch 14/100, Acc=0.8971, Val Loss=0.3418, lr=0.0100 +[07/20 22:14:14] cifar10-global-growing_reg-resnet56 INFO: Epoch 15/100, Acc=0.8927, Val Loss=0.3633, lr=0.0100 +[07/20 22:14:30] cifar10-global-growing_reg-resnet56 INFO: Epoch 16/100, Acc=0.9037, Val Loss=0.3227, lr=0.0100 +[07/20 22:14:47] cifar10-global-growing_reg-resnet56 INFO: Epoch 17/100, Acc=0.8930, Val Loss=0.3709, lr=0.0100 +[07/20 22:15:05] cifar10-global-growing_reg-resnet56 INFO: Epoch 18/100, Acc=0.8703, Val Loss=0.4620, lr=0.0100 +[07/20 22:15:22] cifar10-global-growing_reg-resnet56 INFO: Epoch 19/100, Acc=0.8781, Val Loss=0.4193, lr=0.0100 +[07/20 22:15:39] cifar10-global-growing_reg-resnet56 INFO: Epoch 20/100, Acc=0.8894, Val Loss=0.3809, lr=0.0100 +[07/20 22:15:55] cifar10-global-growing_reg-resnet56 INFO: Epoch 21/100, Acc=0.8942, Val Loss=0.3765, lr=0.0100 +[07/20 22:16:12] cifar10-global-growing_reg-resnet56 INFO: Epoch 22/100, Acc=0.9015, Val Loss=0.3425, lr=0.0100 +[07/20 22:16:29] cifar10-global-growing_reg-resnet56 INFO: Epoch 23/100, Acc=0.8988, Val Loss=0.3396, lr=0.0100 +[07/20 22:16:45] cifar10-global-growing_reg-resnet56 INFO: Epoch 24/100, Acc=0.8924, Val Loss=0.3714, lr=0.0100 +[07/20 22:17:02] cifar10-global-growing_reg-resnet56 INFO: Epoch 25/100, Acc=0.8948, Val Loss=0.3539, lr=0.0100 +[07/20 22:17:19] cifar10-global-growing_reg-resnet56 INFO: Epoch 26/100, Acc=0.8965, Val Loss=0.3575, lr=0.0100 +[07/20 22:17:35] cifar10-global-growing_reg-resnet56 INFO: Epoch 27/100, Acc=0.8917, Val Loss=0.3648, lr=0.0100 +[07/20 22:17:52] cifar10-global-growing_reg-resnet56 INFO: Epoch 28/100, Acc=0.8935, Val Loss=0.3799, lr=0.0100 +[07/20 22:18:09] cifar10-global-growing_reg-resnet56 INFO: Epoch 29/100, Acc=0.9025, Val Loss=0.3301, lr=0.0100 +[07/20 22:18:25] cifar10-global-growing_reg-resnet56 INFO: Epoch 30/100, Acc=0.8985, Val Loss=0.3437, lr=0.0100 +[07/20 22:18:41] cifar10-global-growing_reg-resnet56 INFO: Epoch 31/100, Acc=0.8884, Val Loss=0.3965, lr=0.0100 +[07/20 22:18:58] cifar10-global-growing_reg-resnet56 INFO: Epoch 32/100, Acc=0.8928, Val Loss=0.3817, lr=0.0100 +[07/20 22:19:14] cifar10-global-growing_reg-resnet56 INFO: Epoch 33/100, Acc=0.8964, Val Loss=0.3708, lr=0.0100 +[07/20 22:19:31] cifar10-global-growing_reg-resnet56 INFO: Epoch 34/100, Acc=0.8949, Val Loss=0.3630, lr=0.0100 +[07/20 22:19:48] cifar10-global-growing_reg-resnet56 INFO: Epoch 35/100, Acc=0.8965, Val Loss=0.3642, lr=0.0100 +[07/20 22:20:05] cifar10-global-growing_reg-resnet56 INFO: Epoch 36/100, Acc=0.8934, Val Loss=0.3533, lr=0.0100 +[07/20 22:20:21] cifar10-global-growing_reg-resnet56 INFO: Epoch 37/100, Acc=0.8935, Val Loss=0.3790, lr=0.0100 +[07/20 22:20:39] cifar10-global-growing_reg-resnet56 INFO: Epoch 38/100, Acc=0.8914, Val Loss=0.4001, lr=0.0100 +[07/20 22:20:55] cifar10-global-growing_reg-resnet56 INFO: Epoch 39/100, Acc=0.8922, Val Loss=0.3721, lr=0.0100 +[07/20 22:21:12] cifar10-global-growing_reg-resnet56 INFO: Epoch 40/100, Acc=0.8993, Val Loss=0.3483, lr=0.0100 +[07/20 22:21:28] cifar10-global-growing_reg-resnet56 INFO: Epoch 41/100, Acc=0.8911, Val Loss=0.3775, lr=0.0100 +[07/20 22:21:45] cifar10-global-growing_reg-resnet56 INFO: Epoch 42/100, Acc=0.9030, Val Loss=0.3436, lr=0.0100 +[07/20 22:22:01] cifar10-global-growing_reg-resnet56 INFO: Epoch 43/100, Acc=0.9008, Val Loss=0.3399, lr=0.0100 +[07/20 22:22:18] cifar10-global-growing_reg-resnet56 INFO: Epoch 44/100, Acc=0.8807, Val Loss=0.4193, lr=0.0100 +[07/20 22:22:35] cifar10-global-growing_reg-resnet56 INFO: Epoch 45/100, Acc=0.8935, Val Loss=0.3627, lr=0.0100 +[07/20 22:22:52] cifar10-global-growing_reg-resnet56 INFO: Epoch 46/100, Acc=0.8954, Val Loss=0.3573, lr=0.0100 +[07/20 22:23:08] cifar10-global-growing_reg-resnet56 INFO: Epoch 47/100, Acc=0.8960, Val Loss=0.3660, lr=0.0100 +[07/20 22:23:25] cifar10-global-growing_reg-resnet56 INFO: Epoch 48/100, Acc=0.8912, Val Loss=0.3603, lr=0.0100 +[07/20 22:23:42] cifar10-global-growing_reg-resnet56 INFO: Epoch 49/100, Acc=0.8733, Val Loss=0.4392, lr=0.0100 +[07/20 22:24:00] cifar10-global-growing_reg-resnet56 INFO: Epoch 50/100, Acc=0.8937, Val Loss=0.3832, lr=0.0100 +[07/20 22:24:16] cifar10-global-growing_reg-resnet56 INFO: Epoch 51/100, Acc=0.8982, Val Loss=0.3662, lr=0.0100 +[07/20 22:24:34] cifar10-global-growing_reg-resnet56 INFO: Epoch 52/100, Acc=0.8930, Val Loss=0.3685, lr=0.0100 +[07/20 22:24:50] cifar10-global-growing_reg-resnet56 INFO: Epoch 53/100, Acc=0.8999, Val Loss=0.3362, lr=0.0100 +[07/20 22:25:07] cifar10-global-growing_reg-resnet56 INFO: Epoch 54/100, Acc=0.8966, Val Loss=0.3780, lr=0.0100 +[07/20 22:25:26] cifar10-global-growing_reg-resnet56 INFO: Epoch 55/100, Acc=0.9075, Val Loss=0.3295, lr=0.0100 +[07/20 22:25:42] cifar10-global-growing_reg-resnet56 INFO: Epoch 56/100, Acc=0.8924, Val Loss=0.4048, lr=0.0100 +[07/20 22:25:59] cifar10-global-growing_reg-resnet56 INFO: Epoch 57/100, Acc=0.8995, Val Loss=0.3683, lr=0.0100 +[07/20 22:26:16] cifar10-global-growing_reg-resnet56 INFO: Epoch 58/100, Acc=0.8992, Val Loss=0.3402, lr=0.0100 +[07/20 22:26:33] cifar10-global-growing_reg-resnet56 INFO: Epoch 59/100, Acc=0.9032, Val Loss=0.3296, lr=0.0100 +[07/20 22:26:49] cifar10-global-growing_reg-resnet56 INFO: Epoch 60/100, Acc=0.9235, Val Loss=0.2597, lr=0.0010 +[07/20 22:27:06] cifar10-global-growing_reg-resnet56 INFO: Epoch 61/100, Acc=0.9281, Val Loss=0.2566, lr=0.0010 +[07/20 22:27:22] cifar10-global-growing_reg-resnet56 INFO: Epoch 62/100, Acc=0.9302, Val Loss=0.2541, lr=0.0010 +[07/20 22:27:39] cifar10-global-growing_reg-resnet56 INFO: Epoch 63/100, Acc=0.9300, Val Loss=0.2585, lr=0.0010 +[07/20 22:27:56] cifar10-global-growing_reg-resnet56 INFO: Epoch 64/100, Acc=0.9305, Val Loss=0.2566, lr=0.0010 +[07/20 22:28:12] cifar10-global-growing_reg-resnet56 INFO: Epoch 65/100, Acc=0.9303, Val Loss=0.2565, lr=0.0010 +[07/20 22:28:29] cifar10-global-growing_reg-resnet56 INFO: Epoch 66/100, Acc=0.9295, Val Loss=0.2584, lr=0.0010 +[07/20 22:28:45] cifar10-global-growing_reg-resnet56 INFO: Epoch 67/100, Acc=0.9329, Val Loss=0.2582, lr=0.0010 +[07/20 22:29:02] cifar10-global-growing_reg-resnet56 INFO: Epoch 68/100, Acc=0.9316, Val Loss=0.2568, lr=0.0010 +[07/20 22:29:19] cifar10-global-growing_reg-resnet56 INFO: Epoch 69/100, Acc=0.9324, Val Loss=0.2605, lr=0.0010 +[07/20 22:29:35] cifar10-global-growing_reg-resnet56 INFO: Epoch 70/100, Acc=0.9340, Val Loss=0.2591, lr=0.0010 +[07/20 22:29:52] cifar10-global-growing_reg-resnet56 INFO: Epoch 71/100, Acc=0.9325, Val Loss=0.2632, lr=0.0010 +[07/20 22:30:08] cifar10-global-growing_reg-resnet56 INFO: Epoch 72/100, Acc=0.9319, Val Loss=0.2646, lr=0.0010 +[07/20 22:30:25] cifar10-global-growing_reg-resnet56 INFO: Epoch 73/100, Acc=0.9321, Val Loss=0.2641, lr=0.0010 +[07/20 22:30:41] cifar10-global-growing_reg-resnet56 INFO: Epoch 74/100, Acc=0.9324, Val Loss=0.2713, lr=0.0010 +[07/20 22:30:58] cifar10-global-growing_reg-resnet56 INFO: Epoch 75/100, Acc=0.9331, Val Loss=0.2678, lr=0.0010 +[07/20 22:31:15] cifar10-global-growing_reg-resnet56 INFO: Epoch 76/100, Acc=0.9319, Val Loss=0.2693, lr=0.0010 +[07/20 22:31:32] cifar10-global-growing_reg-resnet56 INFO: Epoch 77/100, Acc=0.9331, Val Loss=0.2661, lr=0.0010 +[07/20 22:31:49] cifar10-global-growing_reg-resnet56 INFO: Epoch 78/100, Acc=0.9318, Val Loss=0.2689, lr=0.0010 +[07/20 22:32:05] cifar10-global-growing_reg-resnet56 INFO: Epoch 79/100, Acc=0.9335, Val Loss=0.2651, lr=0.0010 +[07/20 22:32:22] cifar10-global-growing_reg-resnet56 INFO: Epoch 80/100, Acc=0.9339, Val Loss=0.2660, lr=0.0001 +[07/20 22:32:39] cifar10-global-growing_reg-resnet56 INFO: Epoch 81/100, Acc=0.9348, Val Loss=0.2620, lr=0.0001 +[07/20 22:32:56] cifar10-global-growing_reg-resnet56 INFO: Epoch 82/100, Acc=0.9342, Val Loss=0.2629, lr=0.0001 +[07/20 22:33:13] cifar10-global-growing_reg-resnet56 INFO: Epoch 83/100, Acc=0.9354, Val Loss=0.2622, lr=0.0001 +[07/20 22:33:30] cifar10-global-growing_reg-resnet56 INFO: Epoch 84/100, Acc=0.9342, Val Loss=0.2639, lr=0.0001 +[07/20 22:33:47] cifar10-global-growing_reg-resnet56 INFO: Epoch 85/100, Acc=0.9344, Val Loss=0.2632, lr=0.0001 +[07/20 22:34:05] cifar10-global-growing_reg-resnet56 INFO: Epoch 86/100, Acc=0.9355, Val Loss=0.2624, lr=0.0001 +[07/20 22:34:21] cifar10-global-growing_reg-resnet56 INFO: Epoch 87/100, Acc=0.9351, Val Loss=0.2615, lr=0.0001 +[07/20 22:34:37] cifar10-global-growing_reg-resnet56 INFO: Epoch 88/100, Acc=0.9354, Val Loss=0.2646, lr=0.0001 +[07/20 22:34:54] cifar10-global-growing_reg-resnet56 INFO: Epoch 89/100, Acc=0.9349, Val Loss=0.2644, lr=0.0001 +[07/20 22:35:10] cifar10-global-growing_reg-resnet56 INFO: Epoch 90/100, Acc=0.9350, Val Loss=0.2628, lr=0.0001 +[07/20 22:35:27] cifar10-global-growing_reg-resnet56 INFO: Epoch 91/100, Acc=0.9344, Val Loss=0.2633, lr=0.0001 +[07/20 22:35:43] cifar10-global-growing_reg-resnet56 INFO: Epoch 92/100, Acc=0.9349, Val Loss=0.2633, lr=0.0001 +[07/20 22:36:01] cifar10-global-growing_reg-resnet56 INFO: Epoch 93/100, Acc=0.9343, Val Loss=0.2652, lr=0.0001 +[07/20 22:36:17] cifar10-global-growing_reg-resnet56 INFO: Epoch 94/100, Acc=0.9342, Val Loss=0.2658, lr=0.0001 +[07/20 22:36:34] cifar10-global-growing_reg-resnet56 INFO: Epoch 95/100, Acc=0.9347, Val Loss=0.2644, lr=0.0001 +[07/20 22:36:50] cifar10-global-growing_reg-resnet56 INFO: Epoch 96/100, Acc=0.9348, Val Loss=0.2649, lr=0.0001 +[07/20 22:37:07] cifar10-global-growing_reg-resnet56 INFO: Epoch 97/100, Acc=0.9355, Val Loss=0.2640, lr=0.0001 +[07/20 22:37:24] cifar10-global-growing_reg-resnet56 INFO: Epoch 98/100, Acc=0.9354, Val Loss=0.2627, lr=0.0001 +[07/20 22:37:41] cifar10-global-growing_reg-resnet56 INFO: Epoch 99/100, Acc=0.9348, Val Loss=0.2656, lr=0.0001 +[07/20 22:37:41] cifar10-global-growing_reg-resnet56 INFO: Best Acc=0.9355 diff --git a/benchmarks/run/cifar10/prune/cifar10-global-slim-resnet56/cifar10-global-slim-resnet56.txt b/benchmarks/run/cifar10/prune/cifar10-global-slim-resnet56/cifar10-global-slim-resnet56.txt index 6932237..69d4609 100644 --- a/benchmarks/run/cifar10/prune/cifar10-global-slim-resnet56/cifar10-global-slim-resnet56.txt +++ b/benchmarks/run/cifar10/prune/cifar10-global-slim-resnet56/cifar10-global-slim-resnet56.txt @@ -477,3 +477,452 @@ [01/03 18:35:14] cifar10-global-slim-resnet56 INFO: Epoch 98/100, Acc=0.9304, Val Loss=0.2677, lr=0.0001 [01/03 18:35:31] cifar10-global-slim-resnet56 INFO: Epoch 99/100, Acc=0.9317, Val Loss=0.2714, lr=0.0001 [01/03 18:35:31] cifar10-global-slim-resnet56 INFO: Best Acc=0.9329 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: mode: prune +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: model: resnet56 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: verbose: False +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: dataset: cifar10 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: batch_size: 128 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: total_epochs: 100 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: lr_decay_milestones: 60,80 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: lr_decay_gamma: 0.1 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: lr: 0.01 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: restore: cifar10_resnet56.pth +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: output_dir: run/cifar10/prune/cifar10-global-slim-resnet56 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: method: slim +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: speed_up: 2.11 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: max_sparsity: 1.0 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: soft_keeping_ratio: 0.0 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: reg: 0.0001 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: weight_decay: 0.0005 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: seed: None +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: global_pruning: True +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: sl_total_epochs: 100 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: sl_lr: 0.01 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: sl_lr_decay_milestones: 60,80 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: sl_reg_warmup: 0 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: sl_restore: None +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: iterative_steps: 400 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: logger: +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: device: cuda +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: num_classes: 10 +[07/20 15:24:45] cifar10-global-slim-resnet56 INFO: Loading model from cifar10_resnet56.pth +[07/20 15:24:49] cifar10-global-slim-resnet56 INFO: Regularizing... +[07/20 15:25:14] cifar10-global-slim-resnet56 INFO: Epoch 0/100, Acc=0.9088, Val Loss=0.3215, lr=0.0100 +[07/20 15:25:39] cifar10-global-slim-resnet56 INFO: Epoch 1/100, Acc=0.9114, Val Loss=0.3050, lr=0.0100 +[07/20 15:26:03] cifar10-global-slim-resnet56 INFO: Epoch 2/100, Acc=0.9098, Val Loss=0.3070, lr=0.0100 +[07/20 15:26:29] cifar10-global-slim-resnet56 INFO: Epoch 3/100, Acc=0.9183, Val Loss=0.3023, lr=0.0100 +[07/20 15:26:54] cifar10-global-slim-resnet56 INFO: Epoch 4/100, Acc=0.9066, Val Loss=0.3324, lr=0.0100 +[07/20 15:27:19] cifar10-global-slim-resnet56 INFO: Epoch 5/100, Acc=0.9101, Val Loss=0.3055, lr=0.0100 +[07/20 15:27:43] cifar10-global-slim-resnet56 INFO: Epoch 6/100, Acc=0.9141, Val Loss=0.3261, lr=0.0100 +[07/20 15:28:08] cifar10-global-slim-resnet56 INFO: Epoch 7/100, Acc=0.9102, Val Loss=0.3205, lr=0.0100 +[07/20 15:28:33] cifar10-global-slim-resnet56 INFO: Epoch 8/100, Acc=0.9131, Val Loss=0.3259, lr=0.0100 +[07/20 15:28:58] cifar10-global-slim-resnet56 INFO: Epoch 9/100, Acc=0.9129, Val Loss=0.3311, lr=0.0100 +[07/20 15:29:23] cifar10-global-slim-resnet56 INFO: Epoch 10/100, Acc=0.9164, Val Loss=0.3079, lr=0.0100 +[07/20 15:29:47] cifar10-global-slim-resnet56 INFO: Epoch 11/100, Acc=0.9076, Val Loss=0.3582, lr=0.0100 +[07/20 15:30:12] cifar10-global-slim-resnet56 INFO: Epoch 12/100, Acc=0.9134, Val Loss=0.3269, lr=0.0100 +[07/20 15:30:36] cifar10-global-slim-resnet56 INFO: Epoch 13/100, Acc=0.9182, Val Loss=0.3122, lr=0.0100 +[07/20 15:31:01] cifar10-global-slim-resnet56 INFO: Epoch 14/100, Acc=0.9163, Val Loss=0.3232, lr=0.0100 +[07/20 15:31:26] cifar10-global-slim-resnet56 INFO: Epoch 15/100, Acc=0.9136, Val Loss=0.3256, lr=0.0100 +[07/20 15:31:51] cifar10-global-slim-resnet56 INFO: Epoch 16/100, Acc=0.9156, Val Loss=0.3194, lr=0.0100 +[07/20 15:32:15] cifar10-global-slim-resnet56 INFO: Epoch 17/100, Acc=0.9226, Val Loss=0.3104, lr=0.0100 +[07/20 15:32:40] cifar10-global-slim-resnet56 INFO: Epoch 18/100, Acc=0.9159, Val Loss=0.3228, lr=0.0100 +[07/20 15:33:04] cifar10-global-slim-resnet56 INFO: Epoch 19/100, Acc=0.9186, Val Loss=0.2996, lr=0.0100 +[07/20 15:33:30] cifar10-global-slim-resnet56 INFO: Epoch 20/100, Acc=0.9215, Val Loss=0.3191, lr=0.0100 +[07/20 15:33:54] cifar10-global-slim-resnet56 INFO: Epoch 21/100, Acc=0.9215, Val Loss=0.2943, lr=0.0100 +[07/20 15:34:19] cifar10-global-slim-resnet56 INFO: Epoch 22/100, Acc=0.9253, Val Loss=0.2928, lr=0.0100 +[07/20 15:34:44] cifar10-global-slim-resnet56 INFO: Epoch 23/100, Acc=0.9214, Val Loss=0.3091, lr=0.0100 +[07/20 15:35:08] cifar10-global-slim-resnet56 INFO: Epoch 24/100, Acc=0.9148, Val Loss=0.3374, lr=0.0100 +[07/20 15:35:33] cifar10-global-slim-resnet56 INFO: Epoch 25/100, Acc=0.9197, Val Loss=0.3213, lr=0.0100 +[07/20 15:35:58] cifar10-global-slim-resnet56 INFO: Epoch 26/100, Acc=0.9246, Val Loss=0.3028, lr=0.0100 +[07/20 15:36:22] cifar10-global-slim-resnet56 INFO: Epoch 27/100, Acc=0.9154, Val Loss=0.3415, lr=0.0100 +[07/20 15:36:48] cifar10-global-slim-resnet56 INFO: Epoch 28/100, Acc=0.9295, Val Loss=0.2897, lr=0.0100 +[07/20 15:37:12] cifar10-global-slim-resnet56 INFO: Epoch 29/100, Acc=0.9253, Val Loss=0.3089, lr=0.0100 +[07/20 15:37:37] cifar10-global-slim-resnet56 INFO: Epoch 30/100, Acc=0.9173, Val Loss=0.3387, lr=0.0100 +[07/20 15:38:02] cifar10-global-slim-resnet56 INFO: Epoch 31/100, Acc=0.9225, Val Loss=0.3249, lr=0.0100 +[07/20 15:38:26] cifar10-global-slim-resnet56 INFO: Epoch 32/100, Acc=0.9253, Val Loss=0.3102, lr=0.0100 +[07/20 15:38:51] cifar10-global-slim-resnet56 INFO: Epoch 33/100, Acc=0.9253, Val Loss=0.3109, lr=0.0100 +[07/20 15:39:16] cifar10-global-slim-resnet56 INFO: Epoch 34/100, Acc=0.9276, Val Loss=0.3078, lr=0.0100 +[07/20 15:39:40] cifar10-global-slim-resnet56 INFO: Epoch 35/100, Acc=0.9220, Val Loss=0.3240, lr=0.0100 +[07/20 15:40:05] cifar10-global-slim-resnet56 INFO: Epoch 36/100, Acc=0.9252, Val Loss=0.3213, lr=0.0100 +[07/20 15:40:29] cifar10-global-slim-resnet56 INFO: Epoch 37/100, Acc=0.9177, Val Loss=0.3261, lr=0.0100 +[07/20 15:40:54] cifar10-global-slim-resnet56 INFO: Epoch 38/100, Acc=0.9270, Val Loss=0.3065, lr=0.0100 +[07/20 15:41:19] cifar10-global-slim-resnet56 INFO: Epoch 39/100, Acc=0.9251, Val Loss=0.3181, lr=0.0100 +[07/20 15:41:44] cifar10-global-slim-resnet56 INFO: Epoch 40/100, Acc=0.9229, Val Loss=0.3284, lr=0.0100 +[07/20 15:42:09] cifar10-global-slim-resnet56 INFO: Epoch 41/100, Acc=0.9206, Val Loss=0.3486, lr=0.0100 +[07/20 15:42:34] cifar10-global-slim-resnet56 INFO: Epoch 42/100, Acc=0.9238, Val Loss=0.3341, lr=0.0100 +[07/20 15:42:59] cifar10-global-slim-resnet56 INFO: Epoch 43/100, Acc=0.9222, Val Loss=0.3298, lr=0.0100 +[07/20 15:43:24] cifar10-global-slim-resnet56 INFO: Epoch 44/100, Acc=0.9193, Val Loss=0.3465, lr=0.0100 +[07/20 15:43:48] cifar10-global-slim-resnet56 INFO: Epoch 45/100, Acc=0.9251, Val Loss=0.3247, lr=0.0100 +[07/20 15:44:13] cifar10-global-slim-resnet56 INFO: Epoch 46/100, Acc=0.9229, Val Loss=0.3312, lr=0.0100 +[07/20 15:44:38] cifar10-global-slim-resnet56 INFO: Epoch 47/100, Acc=0.9217, Val Loss=0.3416, lr=0.0100 +[07/20 15:45:03] cifar10-global-slim-resnet56 INFO: Epoch 48/100, Acc=0.9260, Val Loss=0.3246, lr=0.0100 +[07/20 15:45:27] cifar10-global-slim-resnet56 INFO: Epoch 49/100, Acc=0.9249, Val Loss=0.3358, lr=0.0100 +[07/20 15:45:52] cifar10-global-slim-resnet56 INFO: Epoch 50/100, Acc=0.9210, Val Loss=0.3448, lr=0.0100 +[07/20 15:46:17] cifar10-global-slim-resnet56 INFO: Epoch 51/100, Acc=0.9266, Val Loss=0.3221, lr=0.0100 +[07/20 15:46:42] cifar10-global-slim-resnet56 INFO: Epoch 52/100, Acc=0.9129, Val Loss=0.3905, lr=0.0100 +[07/20 15:47:07] cifar10-global-slim-resnet56 INFO: Epoch 53/100, Acc=0.9196, Val Loss=0.3450, lr=0.0100 +[07/20 15:47:31] cifar10-global-slim-resnet56 INFO: Epoch 54/100, Acc=0.9285, Val Loss=0.3155, lr=0.0100 +[07/20 15:47:56] cifar10-global-slim-resnet56 INFO: Epoch 55/100, Acc=0.9274, Val Loss=0.3298, lr=0.0100 +[07/20 15:48:21] cifar10-global-slim-resnet56 INFO: Epoch 56/100, Acc=0.9220, Val Loss=0.3644, lr=0.0100 +[07/20 15:48:46] cifar10-global-slim-resnet56 INFO: Epoch 57/100, Acc=0.9254, Val Loss=0.3275, lr=0.0100 +[07/20 15:49:11] cifar10-global-slim-resnet56 INFO: Epoch 58/100, Acc=0.9173, Val Loss=0.3706, lr=0.0100 +[07/20 15:49:36] cifar10-global-slim-resnet56 INFO: Epoch 59/100, Acc=0.9218, Val Loss=0.3564, lr=0.0100 +[07/20 15:50:00] cifar10-global-slim-resnet56 INFO: Epoch 60/100, Acc=0.9316, Val Loss=0.3092, lr=0.0010 +[07/20 15:50:26] cifar10-global-slim-resnet56 INFO: Epoch 61/100, Acc=0.9322, Val Loss=0.3062, lr=0.0010 +[07/20 15:50:51] cifar10-global-slim-resnet56 INFO: Epoch 62/100, Acc=0.9339, Val Loss=0.3026, lr=0.0010 +[07/20 15:51:15] cifar10-global-slim-resnet56 INFO: Epoch 63/100, Acc=0.9339, Val Loss=0.3032, lr=0.0010 +[07/20 15:51:40] cifar10-global-slim-resnet56 INFO: Epoch 64/100, Acc=0.9334, Val Loss=0.3050, lr=0.0010 +[07/20 15:52:04] cifar10-global-slim-resnet56 INFO: Epoch 65/100, Acc=0.9350, Val Loss=0.3049, lr=0.0010 +[07/20 15:52:29] cifar10-global-slim-resnet56 INFO: Epoch 66/100, Acc=0.9357, Val Loss=0.3058, lr=0.0010 +[07/20 15:52:54] cifar10-global-slim-resnet56 INFO: Epoch 67/100, Acc=0.9350, Val Loss=0.3043, lr=0.0010 +[07/20 15:53:19] cifar10-global-slim-resnet56 INFO: Epoch 68/100, Acc=0.9363, Val Loss=0.3048, lr=0.0010 +[07/20 15:53:43] cifar10-global-slim-resnet56 INFO: Epoch 69/100, Acc=0.9358, Val Loss=0.3050, lr=0.0010 +[07/20 15:54:08] cifar10-global-slim-resnet56 INFO: Epoch 70/100, Acc=0.9364, Val Loss=0.3049, lr=0.0010 +[07/20 15:54:33] cifar10-global-slim-resnet56 INFO: Epoch 71/100, Acc=0.9352, Val Loss=0.3066, lr=0.0010 +[07/20 15:54:57] cifar10-global-slim-resnet56 INFO: Epoch 72/100, Acc=0.9361, Val Loss=0.3059, lr=0.0010 +[07/20 15:55:23] cifar10-global-slim-resnet56 INFO: Epoch 73/100, Acc=0.9358, Val Loss=0.3061, lr=0.0010 +[07/20 15:55:48] cifar10-global-slim-resnet56 INFO: Epoch 74/100, Acc=0.9349, Val Loss=0.3074, lr=0.0010 +[07/20 15:56:15] cifar10-global-slim-resnet56 INFO: Epoch 75/100, Acc=0.9360, Val Loss=0.3032, lr=0.0010 +[07/20 15:56:40] cifar10-global-slim-resnet56 INFO: Epoch 76/100, Acc=0.9368, Val Loss=0.3062, lr=0.0010 +[07/20 15:57:08] cifar10-global-slim-resnet56 INFO: Epoch 77/100, Acc=0.9366, Val Loss=0.3018, lr=0.0010 +[07/20 15:57:33] cifar10-global-slim-resnet56 INFO: Epoch 78/100, Acc=0.9361, Val Loss=0.3050, lr=0.0010 +[07/20 15:57:59] cifar10-global-slim-resnet56 INFO: Epoch 79/100, Acc=0.9365, Val Loss=0.3057, lr=0.0010 +[07/20 15:58:25] cifar10-global-slim-resnet56 INFO: Epoch 80/100, Acc=0.9373, Val Loss=0.3081, lr=0.0001 +[07/20 15:58:50] cifar10-global-slim-resnet56 INFO: Epoch 81/100, Acc=0.9362, Val Loss=0.3065, lr=0.0001 +[07/20 15:59:16] cifar10-global-slim-resnet56 INFO: Epoch 82/100, Acc=0.9358, Val Loss=0.3069, lr=0.0001 +[07/20 15:59:43] cifar10-global-slim-resnet56 INFO: Epoch 83/100, Acc=0.9373, Val Loss=0.3061, lr=0.0001 +[07/20 16:00:08] cifar10-global-slim-resnet56 INFO: Epoch 84/100, Acc=0.9372, Val Loss=0.3050, lr=0.0001 +[07/20 16:00:35] cifar10-global-slim-resnet56 INFO: Epoch 85/100, Acc=0.9353, Val Loss=0.3068, lr=0.0001 +[07/20 16:01:02] cifar10-global-slim-resnet56 INFO: Epoch 86/100, Acc=0.9362, Val Loss=0.3048, lr=0.0001 +[07/20 16:01:30] cifar10-global-slim-resnet56 INFO: Epoch 87/100, Acc=0.9369, Val Loss=0.3051, lr=0.0001 +[07/20 16:01:57] cifar10-global-slim-resnet56 INFO: Epoch 88/100, Acc=0.9367, Val Loss=0.3068, lr=0.0001 +[07/20 16:02:22] cifar10-global-slim-resnet56 INFO: Epoch 89/100, Acc=0.9362, Val Loss=0.3056, lr=0.0001 +[07/20 16:02:48] cifar10-global-slim-resnet56 INFO: Epoch 90/100, Acc=0.9364, Val Loss=0.3068, lr=0.0001 +[07/20 16:03:13] cifar10-global-slim-resnet56 INFO: Epoch 91/100, Acc=0.9372, Val Loss=0.3058, lr=0.0001 +[07/20 16:03:40] cifar10-global-slim-resnet56 INFO: Epoch 92/100, Acc=0.9367, Val Loss=0.3044, lr=0.0001 +[07/20 16:04:07] cifar10-global-slim-resnet56 INFO: Epoch 93/100, Acc=0.9356, Val Loss=0.3067, lr=0.0001 +[07/20 16:04:33] cifar10-global-slim-resnet56 INFO: Epoch 94/100, Acc=0.9366, Val Loss=0.3080, lr=0.0001 +[07/20 16:04:58] cifar10-global-slim-resnet56 INFO: Epoch 95/100, Acc=0.9367, Val Loss=0.3025, lr=0.0001 +[07/20 16:05:24] cifar10-global-slim-resnet56 INFO: Epoch 96/100, Acc=0.9357, Val Loss=0.3045, lr=0.0001 +[07/20 16:05:50] cifar10-global-slim-resnet56 INFO: Epoch 97/100, Acc=0.9368, Val Loss=0.3046, lr=0.0001 +[07/20 16:06:16] cifar10-global-slim-resnet56 INFO: Epoch 98/100, Acc=0.9360, Val Loss=0.3069, lr=0.0001 +[07/20 16:06:42] cifar10-global-slim-resnet56 INFO: Epoch 99/100, Acc=0.9357, Val Loss=0.3064, lr=0.0001 +[07/20 16:06:42] cifar10-global-slim-resnet56 INFO: Best Acc=0.9373 +[07/20 16:06:42] cifar10-global-slim-resnet56 INFO: Loading the sparse model from run/cifar10/prune/cifar10-global-slim-resnet56/reg_cifar10_resnet56_slim_0.0001.pth... +[07/20 16:06:44] cifar10-global-slim-resnet56 INFO: Pruning... +[07/20 16:06:50] cifar10-global-slim-resnet56 INFO: ResNet( + (conv1): Conv2d(3, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (layer1): Sequential( + (0): BasicBlock( + (conv1): Conv2d(9, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(12, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (1): BasicBlock( + (conv1): Conv2d(9, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(3, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (2): BasicBlock( + (conv1): Conv2d(9, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(4, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (3): BasicBlock( + (conv1): Conv2d(9, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(5, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (4): BasicBlock( + (conv1): Conv2d(9, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(12, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (5): BasicBlock( + (conv1): Conv2d(9, 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(7, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (6): BasicBlock( + (conv1): Conv2d(9, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(9, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (7): BasicBlock( + (conv1): Conv2d(9, 15, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(15, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (8): BasicBlock( + (conv1): Conv2d(9, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(8, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (layer2): Sequential( + (0): BasicBlock( + (conv1): Conv2d(9, 30, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(30, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (downsample): Sequential( + (0): Conv2d(9, 31, kernel_size=(1, 1), stride=(2, 2), bias=False) + (1): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (1): BasicBlock( + (conv1): Conv2d(31, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(16, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (2): BasicBlock( + (conv1): Conv2d(31, 28, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(28, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (3): BasicBlock( + (conv1): Conv2d(31, 27, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(27, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(27, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (4): BasicBlock( + (conv1): Conv2d(31, 27, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(27, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(27, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (5): BasicBlock( + (conv1): Conv2d(31, 11, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(11, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(11, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (6): BasicBlock( + (conv1): Conv2d(31, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(10, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (7): BasicBlock( + (conv1): Conv2d(31, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(9, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (8): BasicBlock( + (conv1): Conv2d(31, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(4, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(31, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (layer3): Sequential( + (0): BasicBlock( + (conv1): Conv2d(31, 61, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(61, 41, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(41, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (downsample): Sequential( + (0): Conv2d(31, 41, kernel_size=(1, 1), stride=(2, 2), bias=False) + (1): BatchNorm2d(41, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (1): BasicBlock( + (conv1): Conv2d(41, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(62, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(62, 41, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(41, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (2): BasicBlock( + (conv1): Conv2d(41, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(57, 41, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(41, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (3): BasicBlock( + (conv1): Conv2d(41, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(57, 41, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(41, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (4): BasicBlock( + (conv1): Conv2d(41, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(58, 41, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(41, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (5): BasicBlock( + (conv1): Conv2d(41, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(57, 41, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(41, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (6): BasicBlock( + (conv1): Conv2d(41, 53, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(53, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(53, 41, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(41, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (7): BasicBlock( + (conv1): Conv2d(41, 50, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(50, 41, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(41, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (8): BasicBlock( + (conv1): Conv2d(41, 53, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(53, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(53, 41, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(41, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0) + (fc): Linear(in_features=41, out_features=10, bias=True) +) +[07/20 16:06:51] cifar10-global-slim-resnet56 INFO: Params: 0.86 M => 0.47 M (55.08%) +[07/20 16:06:51] cifar10-global-slim-resnet56 INFO: FLOPs: 127.12 M => 58.97 M (46.39%, 2.16X ) +[07/20 16:06:51] cifar10-global-slim-resnet56 INFO: Acc: 0.9373 => 0.0968 +[07/20 16:06:51] cifar10-global-slim-resnet56 INFO: Val Loss: 0.3081 => 3.2124 +[07/20 16:06:51] cifar10-global-slim-resnet56 INFO: Finetuning... +[07/20 16:07:09] cifar10-global-slim-resnet56 INFO: Epoch 0/100, Acc=0.8456, Val Loss=0.4697, lr=0.0100 +[07/20 16:07:27] cifar10-global-slim-resnet56 INFO: Epoch 1/100, Acc=0.8797, Val Loss=0.3772, lr=0.0100 +[07/20 16:07:44] cifar10-global-slim-resnet56 INFO: Epoch 2/100, Acc=0.8804, Val Loss=0.3950, lr=0.0100 +[07/20 16:08:02] cifar10-global-slim-resnet56 INFO: Epoch 3/100, Acc=0.8888, Val Loss=0.3768, lr=0.0100 +[07/20 16:08:19] cifar10-global-slim-resnet56 INFO: Epoch 4/100, Acc=0.8961, Val Loss=0.3320, lr=0.0100 +[07/20 16:08:37] cifar10-global-slim-resnet56 INFO: Epoch 5/100, Acc=0.8966, Val Loss=0.3522, lr=0.0100 +[07/20 16:08:54] cifar10-global-slim-resnet56 INFO: Epoch 6/100, Acc=0.8839, Val Loss=0.4093, lr=0.0100 +[07/20 16:09:11] cifar10-global-slim-resnet56 INFO: Epoch 7/100, Acc=0.8917, Val Loss=0.3489, lr=0.0100 +[07/20 16:09:29] cifar10-global-slim-resnet56 INFO: Epoch 8/100, Acc=0.8997, Val Loss=0.3340, lr=0.0100 +[07/20 16:09:46] cifar10-global-slim-resnet56 INFO: Epoch 9/100, Acc=0.8990, Val Loss=0.3386, lr=0.0100 +[07/20 16:10:03] cifar10-global-slim-resnet56 INFO: Epoch 10/100, Acc=0.8918, Val Loss=0.3756, lr=0.0100 +[07/20 16:10:20] cifar10-global-slim-resnet56 INFO: Epoch 11/100, Acc=0.8993, Val Loss=0.3459, lr=0.0100 +[07/20 16:10:38] cifar10-global-slim-resnet56 INFO: Epoch 12/100, Acc=0.8931, Val Loss=0.3621, lr=0.0100 +[07/20 16:10:55] cifar10-global-slim-resnet56 INFO: Epoch 13/100, Acc=0.8978, Val Loss=0.3433, lr=0.0100 +[07/20 16:11:13] cifar10-global-slim-resnet56 INFO: Epoch 14/100, Acc=0.9045, Val Loss=0.3137, lr=0.0100 +[07/20 16:11:32] cifar10-global-slim-resnet56 INFO: Epoch 15/100, Acc=0.8996, Val Loss=0.3362, lr=0.0100 +[07/20 16:11:49] cifar10-global-slim-resnet56 INFO: Epoch 16/100, Acc=0.8937, Val Loss=0.3583, lr=0.0100 +[07/20 16:12:07] cifar10-global-slim-resnet56 INFO: Epoch 17/100, Acc=0.8960, Val Loss=0.3604, lr=0.0100 +[07/20 16:12:25] cifar10-global-slim-resnet56 INFO: Epoch 18/100, Acc=0.9003, Val Loss=0.3391, lr=0.0100 +[07/20 16:12:43] cifar10-global-slim-resnet56 INFO: Epoch 19/100, Acc=0.8912, Val Loss=0.3608, lr=0.0100 +[07/20 16:13:01] cifar10-global-slim-resnet56 INFO: Epoch 20/100, Acc=0.8875, Val Loss=0.3981, lr=0.0100 +[07/20 16:13:18] cifar10-global-slim-resnet56 INFO: Epoch 21/100, Acc=0.8993, Val Loss=0.3319, lr=0.0100 +[07/20 16:13:36] cifar10-global-slim-resnet56 INFO: Epoch 22/100, Acc=0.8982, Val Loss=0.3408, lr=0.0100 +[07/20 16:13:53] cifar10-global-slim-resnet56 INFO: Epoch 23/100, Acc=0.8962, Val Loss=0.3459, lr=0.0100 +[07/20 16:14:11] cifar10-global-slim-resnet56 INFO: Epoch 24/100, Acc=0.9010, Val Loss=0.3386, lr=0.0100 +[07/20 16:14:28] cifar10-global-slim-resnet56 INFO: Epoch 25/100, Acc=0.8812, Val Loss=0.4142, lr=0.0100 +[07/20 16:14:45] cifar10-global-slim-resnet56 INFO: Epoch 26/100, Acc=0.8973, Val Loss=0.3714, lr=0.0100 +[07/20 16:15:03] cifar10-global-slim-resnet56 INFO: Epoch 27/100, Acc=0.8996, Val Loss=0.3361, lr=0.0100 +[07/20 16:15:20] cifar10-global-slim-resnet56 INFO: Epoch 28/100, Acc=0.8920, Val Loss=0.3883, lr=0.0100 +[07/20 16:15:37] cifar10-global-slim-resnet56 INFO: Epoch 29/100, Acc=0.8970, Val Loss=0.3645, lr=0.0100 +[07/20 16:15:55] cifar10-global-slim-resnet56 INFO: Epoch 30/100, Acc=0.8948, Val Loss=0.3601, lr=0.0100 +[07/20 16:16:13] cifar10-global-slim-resnet56 INFO: Epoch 31/100, Acc=0.8975, Val Loss=0.3512, lr=0.0100 +[07/20 16:16:30] cifar10-global-slim-resnet56 INFO: Epoch 32/100, Acc=0.8945, Val Loss=0.3835, lr=0.0100 +[07/20 16:16:47] cifar10-global-slim-resnet56 INFO: Epoch 33/100, Acc=0.8973, Val Loss=0.3387, lr=0.0100 +[07/20 16:17:05] cifar10-global-slim-resnet56 INFO: Epoch 34/100, Acc=0.9011, Val Loss=0.3381, lr=0.0100 +[07/20 16:17:23] cifar10-global-slim-resnet56 INFO: Epoch 35/100, Acc=0.9029, Val Loss=0.3298, lr=0.0100 +[07/20 16:17:40] cifar10-global-slim-resnet56 INFO: Epoch 36/100, Acc=0.8751, Val Loss=0.4391, lr=0.0100 +[07/20 16:17:57] cifar10-global-slim-resnet56 INFO: Epoch 37/100, Acc=0.8892, Val Loss=0.4121, lr=0.0100 +[07/20 16:18:15] cifar10-global-slim-resnet56 INFO: Epoch 38/100, Acc=0.8917, Val Loss=0.3832, lr=0.0100 +[07/20 16:18:32] cifar10-global-slim-resnet56 INFO: Epoch 39/100, Acc=0.9002, Val Loss=0.3376, lr=0.0100 +[07/20 16:18:50] cifar10-global-slim-resnet56 INFO: Epoch 40/100, Acc=0.8981, Val Loss=0.3492, lr=0.0100 +[07/20 16:19:07] cifar10-global-slim-resnet56 INFO: Epoch 41/100, Acc=0.8897, Val Loss=0.3946, lr=0.0100 +[07/20 16:19:24] cifar10-global-slim-resnet56 INFO: Epoch 42/100, Acc=0.9044, Val Loss=0.3255, lr=0.0100 +[07/20 16:19:41] cifar10-global-slim-resnet56 INFO: Epoch 43/100, Acc=0.9136, Val Loss=0.2942, lr=0.0100 +[07/20 16:19:59] cifar10-global-slim-resnet56 INFO: Epoch 44/100, Acc=0.8877, Val Loss=0.3956, lr=0.0100 +[07/20 16:20:16] cifar10-global-slim-resnet56 INFO: Epoch 45/100, Acc=0.9028, Val Loss=0.3327, lr=0.0100 +[07/20 16:20:33] cifar10-global-slim-resnet56 INFO: Epoch 46/100, Acc=0.8652, Val Loss=0.4811, lr=0.0100 +[07/20 16:20:51] cifar10-global-slim-resnet56 INFO: Epoch 47/100, Acc=0.8985, Val Loss=0.3806, lr=0.0100 +[07/20 16:21:09] cifar10-global-slim-resnet56 INFO: Epoch 48/100, Acc=0.9056, Val Loss=0.3378, lr=0.0100 +[07/20 16:21:26] cifar10-global-slim-resnet56 INFO: Epoch 49/100, Acc=0.9033, Val Loss=0.3469, lr=0.0100 +[07/20 16:21:44] cifar10-global-slim-resnet56 INFO: Epoch 50/100, Acc=0.8981, Val Loss=0.3595, lr=0.0100 +[07/20 16:22:01] cifar10-global-slim-resnet56 INFO: Epoch 51/100, Acc=0.8894, Val Loss=0.3978, lr=0.0100 +[07/20 16:22:19] cifar10-global-slim-resnet56 INFO: Epoch 52/100, Acc=0.8940, Val Loss=0.3741, lr=0.0100 +[07/20 16:22:36] cifar10-global-slim-resnet56 INFO: Epoch 53/100, Acc=0.8926, Val Loss=0.3807, lr=0.0100 +[07/20 16:22:54] cifar10-global-slim-resnet56 INFO: Epoch 54/100, Acc=0.8945, Val Loss=0.3592, lr=0.0100 +[07/20 16:23:12] cifar10-global-slim-resnet56 INFO: Epoch 55/100, Acc=0.8756, Val Loss=0.4583, lr=0.0100 +[07/20 16:23:29] cifar10-global-slim-resnet56 INFO: Epoch 56/100, Acc=0.8883, Val Loss=0.4007, lr=0.0100 +[07/20 16:23:48] cifar10-global-slim-resnet56 INFO: Epoch 57/100, Acc=0.8993, Val Loss=0.3497, lr=0.0100 +[07/20 16:24:06] cifar10-global-slim-resnet56 INFO: Epoch 58/100, Acc=0.9023, Val Loss=0.3308, lr=0.0100 +[07/20 16:24:23] cifar10-global-slim-resnet56 INFO: Epoch 59/100, Acc=0.8827, Val Loss=0.4252, lr=0.0100 +[07/20 16:24:41] cifar10-global-slim-resnet56 INFO: Epoch 60/100, Acc=0.9240, Val Loss=0.2592, lr=0.0010 +[07/20 16:24:59] cifar10-global-slim-resnet56 INFO: Epoch 61/100, Acc=0.9267, Val Loss=0.2516, lr=0.0010 +[07/20 16:25:17] cifar10-global-slim-resnet56 INFO: Epoch 62/100, Acc=0.9283, Val Loss=0.2534, lr=0.0010 +[07/20 16:25:35] cifar10-global-slim-resnet56 INFO: Epoch 63/100, Acc=0.9297, Val Loss=0.2532, lr=0.0010 +[07/20 16:25:53] cifar10-global-slim-resnet56 INFO: Epoch 64/100, Acc=0.9304, Val Loss=0.2553, lr=0.0010 +[07/20 16:26:10] cifar10-global-slim-resnet56 INFO: Epoch 65/100, Acc=0.9298, Val Loss=0.2562, lr=0.0010 +[07/20 16:26:28] cifar10-global-slim-resnet56 INFO: Epoch 66/100, Acc=0.9296, Val Loss=0.2584, lr=0.0010 +[07/20 16:26:44] cifar10-global-slim-resnet56 INFO: Epoch 67/100, Acc=0.9322, Val Loss=0.2561, lr=0.0010 +[07/20 16:27:03] cifar10-global-slim-resnet56 INFO: Epoch 68/100, Acc=0.9305, Val Loss=0.2606, lr=0.0010 +[07/20 16:27:20] cifar10-global-slim-resnet56 INFO: Epoch 69/100, Acc=0.9293, Val Loss=0.2606, lr=0.0010 +[07/20 16:27:39] cifar10-global-slim-resnet56 INFO: Epoch 70/100, Acc=0.9303, Val Loss=0.2627, lr=0.0010 +[07/20 16:27:56] cifar10-global-slim-resnet56 INFO: Epoch 71/100, Acc=0.9304, Val Loss=0.2641, lr=0.0010 +[07/20 16:28:14] cifar10-global-slim-resnet56 INFO: Epoch 72/100, Acc=0.9310, Val Loss=0.2587, lr=0.0010 +[07/20 16:28:31] cifar10-global-slim-resnet56 INFO: Epoch 73/100, Acc=0.9312, Val Loss=0.2602, lr=0.0010 +[07/20 16:28:49] cifar10-global-slim-resnet56 INFO: Epoch 74/100, Acc=0.9316, Val Loss=0.2684, lr=0.0010 +[07/20 16:29:06] cifar10-global-slim-resnet56 INFO: Epoch 75/100, Acc=0.9294, Val Loss=0.2710, lr=0.0010 +[07/20 16:29:22] cifar10-global-slim-resnet56 INFO: Epoch 76/100, Acc=0.9279, Val Loss=0.2701, lr=0.0010 +[07/20 16:29:40] cifar10-global-slim-resnet56 INFO: Epoch 77/100, Acc=0.9312, Val Loss=0.2656, lr=0.0010 +[07/20 16:29:58] cifar10-global-slim-resnet56 INFO: Epoch 78/100, Acc=0.9297, Val Loss=0.2684, lr=0.0010 +[07/20 16:30:16] cifar10-global-slim-resnet56 INFO: Epoch 79/100, Acc=0.9299, Val Loss=0.2707, lr=0.0010 +[07/20 16:30:34] cifar10-global-slim-resnet56 INFO: Epoch 80/100, Acc=0.9305, Val Loss=0.2680, lr=0.0001 +[07/20 16:30:51] cifar10-global-slim-resnet56 INFO: Epoch 81/100, Acc=0.9311, Val Loss=0.2654, lr=0.0001 +[07/20 16:31:09] cifar10-global-slim-resnet56 INFO: Epoch 82/100, Acc=0.9305, Val Loss=0.2662, lr=0.0001 +[07/20 16:31:27] cifar10-global-slim-resnet56 INFO: Epoch 83/100, Acc=0.9307, Val Loss=0.2642, lr=0.0001 +[07/20 16:31:44] cifar10-global-slim-resnet56 INFO: Epoch 84/100, Acc=0.9301, Val Loss=0.2659, lr=0.0001 +[07/20 16:32:01] cifar10-global-slim-resnet56 INFO: Epoch 85/100, Acc=0.9299, Val Loss=0.2686, lr=0.0001 +[07/20 16:32:19] cifar10-global-slim-resnet56 INFO: Epoch 86/100, Acc=0.9302, Val Loss=0.2651, lr=0.0001 +[07/20 16:32:37] cifar10-global-slim-resnet56 INFO: Epoch 87/100, Acc=0.9298, Val Loss=0.2665, lr=0.0001 +[07/20 16:32:53] cifar10-global-slim-resnet56 INFO: Epoch 88/100, Acc=0.9293, Val Loss=0.2680, lr=0.0001 +[07/20 16:33:11] cifar10-global-slim-resnet56 INFO: Epoch 89/100, Acc=0.9317, Val Loss=0.2664, lr=0.0001 +[07/20 16:33:27] cifar10-global-slim-resnet56 INFO: Epoch 90/100, Acc=0.9305, Val Loss=0.2694, lr=0.0001 +[07/20 16:33:44] cifar10-global-slim-resnet56 INFO: Epoch 91/100, Acc=0.9313, Val Loss=0.2676, lr=0.0001 +[07/20 16:34:02] cifar10-global-slim-resnet56 INFO: Epoch 92/100, Acc=0.9304, Val Loss=0.2695, lr=0.0001 +[07/20 16:34:19] cifar10-global-slim-resnet56 INFO: Epoch 93/100, Acc=0.9302, Val Loss=0.2655, lr=0.0001 +[07/20 16:34:35] cifar10-global-slim-resnet56 INFO: Epoch 94/100, Acc=0.9298, Val Loss=0.2667, lr=0.0001 +[07/20 16:34:52] cifar10-global-slim-resnet56 INFO: Epoch 95/100, Acc=0.9317, Val Loss=0.2675, lr=0.0001 +[07/20 16:35:09] cifar10-global-slim-resnet56 INFO: Epoch 96/100, Acc=0.9300, Val Loss=0.2664, lr=0.0001 +[07/20 16:35:26] cifar10-global-slim-resnet56 INFO: Epoch 97/100, Acc=0.9313, Val Loss=0.2654, lr=0.0001 +[07/20 16:35:44] cifar10-global-slim-resnet56 INFO: Epoch 98/100, Acc=0.9317, Val Loss=0.2660, lr=0.0001 +[07/20 16:36:01] cifar10-global-slim-resnet56 INFO: Epoch 99/100, Acc=0.9300, Val Loss=0.2677, lr=0.0001 +[07/20 16:36:01] cifar10-global-slim-resnet56 INFO: Best Acc=0.9322 diff --git a/tutorials/0 - QuickStart.ipynb b/examples/notebook/0 - QuickStart.ipynb similarity index 100% rename from tutorials/0 - QuickStart.ipynb rename to examples/notebook/0 - QuickStart.ipynb diff --git a/tutorials/1 - Customize Your Own Pruners.ipynb b/examples/notebook/1 - Customize Your Own Pruners.ipynb similarity index 100% rename from tutorials/1 - Customize Your Own Pruners.ipynb rename to examples/notebook/1 - Customize Your Own Pruners.ipynb diff --git a/tutorials/2 - Exploring Dependency Groups.ipynb b/examples/notebook/2 - Exploring Dependency Groups.ipynb similarity index 100% rename from tutorials/2 - Exploring Dependency Groups.ipynb rename to examples/notebook/2 - Exploring Dependency Groups.ipynb diff --git a/examples/timm_models/readme.md b/examples/timm_models/readme.md new file mode 100644 index 0000000..70eff00 --- /dev/null +++ b/examples/timm_models/readme.md @@ -0,0 +1,29 @@ +# Pruning Models from Torchvision + +## 0. Requirements + +```bash +pip install -r requirements.txt +``` +Tested environment: +``` +pytorch==1.12.1 +timm=0.9.2 +``` + +## 1. Pruning + +```python +python torchvision_pruning.py +``` + +#### Outputs: +Prunable: 119 models, +``` + ['beit_base_patch16_224', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', 'beitv2_large_patch16_224', 'botnet26t_256', 'botnet50ts_256', 'convmixer_768_32', 'convmixer_1024_20_ks9_p14', 'convmixer_1536_20', 'convnext_atto', 'convnext_atto_ols', 'convnext_base', 'convnext_femto', 'convnext_femto_ols', 'convnext_large', 'convnext_large_mlp', 'convnext_nano', 'convnext_nano_ols', 'convnext_pico', 'convnext_pico_ols', 'convnext_small', 'convnext_tiny', 'convnext_tiny_hnf', 'convnext_xlarge', 'convnext_xxlarge', 'convnextv2_atto', 'convnextv2_base', 'convnextv2_femto', 'convnextv2_huge', 'convnextv2_large', 'convnextv2_nano', 'convnextv2_pico', 'convnextv2_small', 'convnextv2_tiny', 'darknet17', 'darknet21', 'darknet53', 'darknetaa53', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'densenet264d', 'dla34', 'dla46_c', 'dla46x_c', 'dla60', 'dla60x', 'dla60x_c', 'dla102', 'dla102x', 'dla102x2', 'dla169', 'eca_botnext26ts_256', 'eca_resnet33ts', 'eca_resnext26ts', 'eca_vovnet39b', 'ecaresnet26t', 'ecaresnet50d', 'ecaresnet50d_pruned', 'ecaresnet50t', 'ecaresnet101d', 'ecaresnet101d_pruned', 'ecaresnet200d', 'ecaresnet269d', 'ecaresnetlight', 'ecaresnext26t_32x4d', 'ecaresnext50t_32x4d', 'efficientnet_b0', 'efficientnet_b0_g8_gn', 'efficientnet_b0_g16_evos', 'efficientnet_b0_gn', 'efficientnet_b1', 'efficientnet_b1_pruned', 'efficientnet_b2', 'efficientnet_b2_pruned', 'efficientnet_b2a', 'efficientnet_b3', 'efficientnet_b3_gn', 'efficientnet_b3_pruned', 'efficientnet_b3a', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8', 'efficientnet_el', 'efficientnet_el_pruned', 'efficientnet_em', 'efficientnet_es', 'efficientnet_es_pruned', 'efficientnet_l2', 'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4', 'efficientnetv2_l', 'efficientnetv2_m', 'efficientnetv2_rw_m', 'efficientnetv2_rw_s', 'efficientnetv2_rw_t', 'efficientnetv2_s', 'efficientnetv2_xl', 'ese_vovnet19b_dw', 'ese_vovnet19b_slim', 'ese_vovnet19b_slim_dw', 'ese_vovnet39b', 'ese_vovnet57b', 'ese_vovnet99b', 'fbnetc_100', 'fbnetv3_b', 'fbnetv3_d', 'fbnetv3_g', 'gc_efficientnetv2_rw_t', 'gcresnet33ts'] +``` + +Unprunable: 175 models, +``` + ['bat_resnext26ts', 'caformer_b36', 'caformer_m36', 'caformer_s18', 'caformer_s36', 'cait_m36_384', 'cait_m48_448', 'cait_s24_224', 'cait_s24_384', 'cait_s36_384', 'cait_xs24_384', 'cait_xxs24_224', 'cait_xxs24_384', 'cait_xxs36_224', 'cait_xxs36_384', 'coat_lite_medium', 'coat_lite_medium_384', 'coat_lite_mini', 'coat_lite_small', 'coat_lite_tiny', 'coat_mini', 'coat_small', 'coat_tiny', 'coatnet_0_224', 'coatnet_0_rw_224', 'coatnet_1_224', 'coatnet_1_rw_224', 'coatnet_2_224', 'coatnet_2_rw_224', 'coatnet_3_224', 'coatnet_3_rw_224', 'coatnet_4_224', 'coatnet_5_224', 'coatnet_bn_0_rw_224', 'coatnet_nano_cc_224', 'coatnet_nano_rw_224', 'coatnet_pico_rw_224', 'coatnet_rmlp_0_rw_224', 'coatnet_rmlp_1_rw2_224', 'coatnet_rmlp_1_rw_224', 'coatnet_rmlp_2_rw_224', 'coatnet_rmlp_2_rw_384', 'coatnet_rmlp_3_rw_224', 'coatnet_rmlp_nano_rw_224', 'coatnext_nano_rw_224', 'convformer_b36', 'convformer_m36', 'convformer_s18', 'convformer_s36', 'convit_base', 'convit_small', 'convit_tiny', 'crossvit_9_240', 'crossvit_9_dagger_240', 'crossvit_15_240', 'crossvit_15_dagger_240', 'crossvit_15_dagger_408', 'crossvit_18_240', 'crossvit_18_dagger_240', 'crossvit_18_dagger_408', 'crossvit_base_240', 'crossvit_small_240', 'crossvit_tiny_240', 'cs3darknet_focus_l', 'cs3darknet_focus_m', 'cs3darknet_focus_s', 'cs3darknet_focus_x', 'cs3darknet_l', 'cs3darknet_m', 'cs3darknet_s', 'cs3darknet_x', 'cs3edgenet_x', 'cs3se_edgenet_x', 'cs3sedarknet_l', 'cs3sedarknet_x', 'cs3sedarknet_xdw', 'cspdarknet53', 'cspresnet50', 'cspresnet50d', 'cspresnet50w', 'cspresnext50', 'davit_base', 'davit_giant', 'davit_huge', 'davit_large', 'davit_small', 'davit_tiny', 'deit3_base_patch16_224', 'deit3_base_patch16_384', 'deit3_huge_patch14_224', 'deit3_large_patch16_224', 'deit3_large_patch16_384', 'deit3_medium_patch16_224', 'deit3_small_patch16_224', 'deit3_small_patch16_384', 'deit_base_distilled_patch16_224', 'deit_base_distilled_patch16_384', 'deit_base_patch16_224', 'deit_base_patch16_384', 'deit_small_distilled_patch16_224', 'deit_small_patch16_224', 'deit_tiny_distilled_patch16_224', 'deit_tiny_patch16_224', 'densenetblur121d', 'dla60_res2net', 'dla60_res2next', 'dm_nfnet_f0', 'dm_nfnet_f1', 'dm_nfnet_f2', 'dm_nfnet_f3', 'dm_nfnet_f4', 'dm_nfnet_f5', 'dm_nfnet_f6', 'dpn48b', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn107', 'dpn131', 'eca_halonext26ts', 'eca_nfnet_l0', 'eca_nfnet_l1', 'eca_nfnet_l2', 'eca_nfnet_l3', 'edgenext_base', 'edgenext_small', 'edgenext_small_rw', 'edgenext_x_small', 'edgenext_xx_small', 'efficientformer_l1', 'efficientformer_l3', 'efficientformer_l7', 'efficientformerv2_l', 'efficientformerv2_s0', 'efficientformerv2_s1', 'efficientformerv2_s2', 'efficientnet_b3_g8_gn', 'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e', 'ese_vovnet39b_evos', 'eva02_base_patch14_224', 'eva02_base_patch14_448', 'eva02_base_patch16_clip_224', 'eva02_enormous_patch14_clip_224', 'eva02_large_patch14_224', 'eva02_large_patch14_448', 'eva02_large_patch14_clip_224', 'eva02_large_patch14_clip_336', 'eva02_small_patch14_224', 'eva02_small_patch14_336', 'eva02_tiny_patch14_224', 'eva02_tiny_patch14_336', 'eva_giant_patch14_224', 'eva_giant_patch14_336', 'eva_giant_patch14_560', 'eva_giant_patch14_clip_224', 'eva_large_patch14_196', 'eva_large_patch14_336', 'flexivit_base', 'flexivit_large', 'flexivit_small', 'focalnet_base_lrf', 'focalnet_base_srf', 'focalnet_huge_fl3', 'focalnet_huge_fl4', 'focalnet_large_fl3', 'focalnet_large_fl4', 'focalnet_small_lrf', 'focalnet_small_srf', 'focalnet_tiny_lrf', 'focalnet_tiny_srf', 'focalnet_xlarge_fl3', 'focalnet_xlarge_fl4'] +``` \ No newline at end of file diff --git a/examples/timm_models/timm_pruning.py b/examples/timm_models/timm_pruning.py new file mode 100644 index 0000000..5d98a27 --- /dev/null +++ b/examples/timm_models/timm_pruning.py @@ -0,0 +1,61 @@ +import os, sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))) + +import torch +import timm +import torch_pruning as tp + +# timm==0.9.2 +# torch==1.12.1 + +timm_models = timm.list_models() +print(timm_models) +example_inputs = torch.randn(1,3,224,224) +imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean") +prunable_list = [] +unprunable_list = [] +problem_with_input_shape = [] +for i, model_name in enumerate(timm_models): + print("Pruning %s..."%model_name) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + #if 'rexnet' in model_name or 'sequencer' in model_name or 'botnet' in model_name: # pruning process stuck with that architectures - skip them. + # unprunable_list.append(model_name) + # continue + try: + model = timm.create_model(model_name, pretrained=False, no_jit=True).eval().to(device) + except: # out of memory error + model = timm.create_model(model_name, pretrained=False, no_jit=True).eval() + device = 'cpu' + + input_size = model.default_cfg['input_size'] + example_inputs = torch.randn(1, *input_size).to(device) + test_output = model(example_inputs) + + print(model) + prunable = True + try: + base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs) + pruner = tp.pruner.MagnitudePruner( + model, + example_inputs, + global_pruning=False, # If False, a uniform sparsity will be assigned to different layers. + importance=imp, # importance criterion for parameter selection + iterative_steps=1, # the number of iterations to achieve target sparsity + ch_sparsity=0.5, + ignored_layers=[], + ) + pruner.step() + test_output = model(example_inputs) + pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs) + print("Base MACs: %d, Pruned MACs: %d"%(base_macs, pruned_macs)) + print("Base Params: %d, Pruned Params: %d"%(base_params, pruned_params)) + except Exception as e: + prunable = False + + if prunable: + prunable_list.append(model_name) + else: + unprunable_list.append(model_name) + + print("Prunable: %d models, \n %s\n"%(len(prunable_list), prunable_list)) + print("Unprunable: %d models, \n %s\n"%(len(unprunable_list), unprunable_list)) \ No newline at end of file diff --git a/examples/torchvision_models/readme.md b/examples/torchvision_models/readme.md new file mode 100644 index 0000000..d198cd9 --- /dev/null +++ b/examples/torchvision_models/readme.md @@ -0,0 +1,158 @@ +# Pruning Models from Torchvision + +## 0. Requirements + +```bash +pip install -r requirements.txt +``` +Tested environment: +``` +Pytorch==1.12.1 +Torchvision==0.13.1 +``` + +## 1. Pruning + +```python +python torchvision_pruning.py +``` + +#### Outputs: +``` +Successful Pruning: 81 Models + ['ssdlite320_mobilenet_v3_large', 'ssd300_vgg16', 'fasterrcnn_resnet50_fpn', 'fasterrcnn_resnet50_fpn_v2', 'fasterrcnn_mobilenet_v3_large_320_fpn', 'fasterrcnn_mobilenet_v3_large_fpn', 'fcos_resnet50_fpn', 'keypointrcnn_resnet50_fpn', 'maskrcnn_resnet50_fpn_v2', 'retinanet_resnet50_fpn_v2', 'alexnet', 'vit_b_16', 'vit_b_32', 'vit_l_16', 'vit_l_32', 'vit_h_14', 'convnext_tiny', 'convnext_small', 'convnext_base', 'convnext_large', 'densenet121', 'densenet169', 'densenet201', 'densenet161', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l', 'googlenet', 'inception_v3', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'regnet_y_400mf', 'regnet_y_800mf', 'regnet_y_1_6gf', 'regnet_y_3_2gf', 'regnet_y_8gf', 'regnet_y_16gf', 'regnet_y_32gf', 'regnet_y_128gf', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2', 'fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101', 'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'] + +``` + +``` +Unsuccessful Pruning: 4 Models + ['raft_large', 'swin_t', 'swin_s', 'swin_b'] +``` + +#### Vision Transfomer Example +``` +==============Before pruning================= +Model Name: vit_b_32 +VisionTransformer( + (conv_proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32)) + (encoder): Encoder( + (dropout): Dropout(p=0.0, inplace=False) + (layers): Sequential( + (encoder_layer_0): EncoderBlock( + (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True) + (self_attention): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True) + ) + (dropout): Dropout(p=0.0, inplace=False) + (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True) + (mlp): MLPBlock( + (0): Linear(in_features=768, out_features=3072, bias=True) + (1): GELU(approximate=none) + (2): Dropout(p=0.0, inplace=False) + (3): Linear(in_features=3072, out_features=768, bias=True) + (4): Dropout(p=0.0, inplace=False) + ) + ) +... + (encoder_layer_10): EncoderBlock( + (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True) + (self_attention): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True) + ) + (dropout): Dropout(p=0.0, inplace=False) + (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True) + (mlp): MLPBlock( + (0): Linear(in_features=768, out_features=3072, bias=True) + (1): GELU(approximate=none) + (2): Dropout(p=0.0, inplace=False) + (3): Linear(in_features=3072, out_features=768, bias=True) + (4): Dropout(p=0.0, inplace=False) + ) + ) + (encoder_layer_11): EncoderBlock( + (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True) + (self_attention): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True) + ) + (dropout): Dropout(p=0.0, inplace=False) + (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True) + (mlp): MLPBlock( + (0): Linear(in_features=768, out_features=3072, bias=True) + (1): GELU(approximate=none) + (2): Dropout(p=0.0, inplace=False) + (3): Linear(in_features=3072, out_features=768, bias=True) + (4): Dropout(p=0.0, inplace=False) + ) + ) + ) + (ln): LayerNorm((768,), eps=1e-06, elementwise_affine=True) + ) + (heads): Sequential( + (head): Linear(in_features=768, out_features=1000, bias=True) + ) +) +torch.Size([1, 1, 384]) torch.Size([1, 50, 384]) +==============After pruning================= +VisionTransformer( + (conv_proj): Conv2d(3, 384, kernel_size=(32, 32), stride=(32, 32)) + (encoder): Encoder( + (dropout): Dropout(p=0.0, inplace=False) + (layers): Sequential( + (encoder_layer_0): EncoderBlock( + (ln_1): LayerNorm((384,), eps=1e-06, elementwise_affine=True) + (self_attention): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True) + ) + (dropout): Dropout(p=0.0, inplace=False) + (ln_2): LayerNorm((384,), eps=1e-06, elementwise_affine=True) + (mlp): MLPBlock( + (0): Linear(in_features=384, out_features=1536, bias=True) + (1): GELU(approximate=none) + (2): Dropout(p=0.0, inplace=False) + (3): Linear(in_features=1536, out_features=384, bias=True) + (4): Dropout(p=0.0, inplace=False) + ) + ) +... + (encoder_layer_10): EncoderBlock( + (ln_1): LayerNorm((384,), eps=1e-06, elementwise_affine=True) + (self_attention): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True) + ) + (dropout): Dropout(p=0.0, inplace=False) + (ln_2): LayerNorm((384,), eps=1e-06, elementwise_affine=True) + (mlp): MLPBlock( + (0): Linear(in_features=384, out_features=1536, bias=True) + (1): GELU(approximate=none) + (2): Dropout(p=0.0, inplace=False) + (3): Linear(in_features=1536, out_features=384, bias=True) + (4): Dropout(p=0.0, inplace=False) + ) + ) + (encoder_layer_11): EncoderBlock( + (ln_1): LayerNorm((384,), eps=1e-06, elementwise_affine=True) + (self_attention): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True) + ) + (dropout): Dropout(p=0.0, inplace=False) + (ln_2): LayerNorm((384,), eps=1e-06, elementwise_affine=True) + (mlp): MLPBlock( + (0): Linear(in_features=384, out_features=1536, bias=True) + (1): GELU(approximate=none) + (2): Dropout(p=0.0, inplace=False) + (3): Linear(in_features=1536, out_features=384, bias=True) + (4): Dropout(p=0.0, inplace=False) + ) + ) + ) + (ln): LayerNorm((384,), eps=1e-06, elementwise_affine=True) + ) + (heads): Sequential( + (head): Linear(in_features=384, out_features=1000, bias=True) + ) +) +Pruning vit_b_32: + Params: 88224232 => 22878952 + Output: torch.Size([1, 1000]) +------------------------------------------------------ +``` diff --git a/benchmarks/prunability/torchvision_pruning.py b/examples/torchvision_models/torchvision_pruning.py similarity index 89% rename from benchmarks/prunability/torchvision_pruning.py rename to examples/torchvision_models/torchvision_pruning.py index 38d5352..95c577f 100644 --- a/benchmarks/prunability/torchvision_pruning.py +++ b/examples/torchvision_models/torchvision_pruning.py @@ -203,6 +203,16 @@ def my_prune(model, example_inputs, output_transform, model_name): print("==============Before pruning=================") print("Model Name: {}".format(model_name)) print(model) + + layer_channel_cfg = {} + for module in model.modules(): + if module not in pruner.ignored_layers: + #print(module) + if isinstance(module, nn.Conv2d): + layer_channel_cfg[module] = module.out_channels + elif isinstance(module, nn.Linear): + layer_channel_cfg[module] = module.out_features + pruner.step() if isinstance( model, VisionTransformer @@ -223,7 +233,18 @@ def my_prune(model, example_inputs, output_transform, model_name): if output_transform: out = output_transform(out) print("{} Pruning: ".format(model_name)) - print(" Params: %s => %s" % (ori_size, tp.utils.count_params(model))) + params_after_prune = tp.utils.count_params(model) + print(" Params: %s => %s" % (ori_size, params_after_prune)) + + if 'rcnn' not in model_name and model_name!='ssdlite320_mobilenet_v3_large': # RCNN may return 0 proposals, making some layers unreachable during tracing. + for module, ch in layer_channel_cfg.items(): + if isinstance(module, nn.Conv2d): + #print(module.out_channels, layer_channel_cfg[module]) + assert int(0.5*layer_channel_cfg[module]) == module.out_channels + elif isinstance(module, nn.Linear): + #print(module.out_features, layer_channel_cfg[module]) + assert int(0.5*layer_channel_cfg[module]) == module.out_features + if isinstance(out, (dict,list,tuple)): print(" Output:") for o in tp.utils.flatten_as_list(out): diff --git a/examples/yolov7/readme.md b/examples/yolov7/readme.md new file mode 100644 index 0000000..f62428c --- /dev/null +++ b/examples/yolov7/readme.md @@ -0,0 +1,84 @@ +# YOLOv7 Pruning + +## 0. Requirements + +```bash +pip install -r requirements.txt +``` +Tested environment: +``` +Pytorch==1.12.1 +Torchvision==0.13.1 +``` + +## 1. Pruning +The following scripts (adapted from [yolov7/detect.py](https://github.com/WongKinYiu/yolov7/blob/main/detect.py) and [yolov7/train.py](https://github.com/WongKinYiu/yolov7/blob/main/train.py)) provide the basic examples of pruning YOLOv7. It is important to note that the training part has not been validated yet due to the time-consuming training process. + +Note: [yolov7_detect_pruned.py](https://github.com/VainF/Torch-Pruning/blob/master/benchmarks/prunability/yolov7_detect_pruned.py) does not include any code for fine-tuning. + +```bash +git clone https://github.com/WongKinYiu/yolov7.git +cp yolov7_detect_pruned.py yolov7/ +cp yolov7_train_pruned.py yolov7/ +cd yolov7 + +# Test only: We only prune and test the YOLOv7 model in this script. COCO dataset is not required. +python yolov7_detect_pruned.py --weights yolov7.pt --conf 0.25 --img-size 640 --source inference/images/horses.jpg + +# Training with pruned yolov7 (The training part is not validated) +# Please download the pretrained yolov7_training.pt from https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt. +python yolov7_train_pruned.py --workers 8 --device 0 --batch-size 1 --data data/coco.yaml --img 640 640 --cfg cfg/training/yolov7.yaml --weights 'yolov7_training.pt' --name yolov7 --hyp data/hyp.scratch.p5.yaml +``` + +#### Screenshot for yolov7_train_pruned.py: +![image](https://user-images.githubusercontent.com/18592211/232129303-18a61be1-b505-4950-b6a1-c60b4974291b.png) + + +#### Outputs of yolov7_detect_pruned.py: +``` +Model( + (model): Sequential( + (0): Conv( + (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (act): SiLU(inplace=True) + ) +... + (104): RepConv( + (act): SiLU(inplace=True) + (rbr_reparam): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + ) + (105): Detect( + (m): ModuleList( + (0): Conv2d(256, 255, kernel_size=(1, 1), stride=(1, 1)) + (1): Conv2d(512, 255, kernel_size=(1, 1), stride=(1, 1)) + (2): Conv2d(1024, 255, kernel_size=(1, 1), stride=(1, 1)) + ) + ) + ) +) + + +Model( + (model): Sequential( + (0): Conv( + (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (act): SiLU(inplace=True) + ) +... + (104): RepConv( + (act): SiLU(inplace=True) + (rbr_reparam): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + ) + (105): Detect( + (m): ModuleList( + (0): Conv2d(128, 255, kernel_size=(1, 1), stride=(1, 1)) + (1): Conv2d(256, 255, kernel_size=(1, 1), stride=(1, 1)) + (2): Conv2d(512, 255, kernel_size=(1, 1), stride=(1, 1)) + ) + ) + ) +) +Before Pruning: MACs=6.413721 G, #Params=0.036905 G +After Pruning: MACs=1.639895 G, #Params=0.009347 G +``` + diff --git a/benchmarks/prunability/yolov7_detect_pruned.py b/examples/yolov7/yolov7_detect_pruned.py similarity index 100% rename from benchmarks/prunability/yolov7_detect_pruned.py rename to examples/yolov7/yolov7_detect_pruned.py diff --git a/benchmarks/prunability/yolov7_train_pruned.py b/examples/yolov7/yolov7_train_pruned.py similarity index 100% rename from benchmarks/prunability/yolov7_train_pruned.py rename to examples/yolov7/yolov7_train_pruned.py diff --git a/examples/yolov8/readme.md b/examples/yolov8/readme.md new file mode 100644 index 0000000..76cefcd --- /dev/null +++ b/examples/yolov8/readme.md @@ -0,0 +1,121 @@ +# YOLOv8 Pruning + +## 0. Requirements + +```bash +pip install -r requirements.txt +``` +Tested environment: +``` +Pytorch==1.12.1 +Torchvision==0.13.1 +``` + +## 1. Pruning + +This example was implemented by [@Hyunseok-Kim0 (Hyunseok Kim)](https://github.com/Hyunseok-Kim0). Please refer to Issue [#147](https://github.com/VainF/Torch-Pruning/issues/147#issuecomment-1507475657) for more details. + +#### Ultralytics +```bash +git clone https://github.com/ultralytics/ultralytics.git +cp yolov8_pruning.py ultralytics/ +cd ultralytics +git checkout 44c7c3514d87a5e05cfb14dba5a3eeb6eb860e70 # for compatibility +``` + +#### Modification +Some functions will be automatically modified by the yolov8_pruning.py to prevent performance loss during model saving. + +##### 1. ```train``` in class ```YOLO``` +This function creates new trainer when called. Trainer loads model based on config file and reassign it to current model, which should be avoided for pruning. + +##### 2. ```save_model``` in class ```BaseTrainer``` +YOLO v8 saves trained model with half precision. Due to this precision loss, saved model shows different performance with validation result during fine-tuning. +This is modified to save the model with full precision because changing model to half precision can be done easily whenever after the pruning. + +##### 3. ```final_eval``` in class ```BaseTrainer``` +YOLO v8 replaces saved checkpoint file to half precision after training is done using ```strip_optimizer```. Half precision saving is changed with same reason above. + +#### Training +``` +# This example will craft yolov8-half and fine-tune it on the coco128 toy set. +python yolov8_pruning.py +``` + +#### Screenshot for coco128 post-training: +image + + +#### Outputs of yolov8_pruning.py: +``` +DetectionModel( + (model): Sequential( + (0): Conv( + (conv): Conv2d(3, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) + (act): SiLU(inplace=True) + ) + (1): Conv( + (conv): Conv2d(80, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) + (act): SiLU(inplace=True) + ) +... + (2): Sequential( + (0): Conv( + (conv): Conv2d(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) + (act): SiLU(inplace=True) + ) + (1): Conv( + (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) + (act): SiLU(inplace=True) + ) + (2): Conv2d(320, 80, kernel_size=(1, 1), stride=(1, 1)) + ) + ) + (dfl): DFL( + (conv): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1), bias=False) + ) + ) + ) +) + + +DetectionModel( + (model): Sequential( + (0): Conv( + (conv): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (bn): BatchNorm2d(40, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) + (act): SiLU(inplace=True) + ) + (1): Conv( + (conv): Conv2d(40, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) + (act): SiLU(inplace=True) + ) +... + (2): Sequential( + (0): Conv( + (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) + (act): SiLU(inplace=True) + ) + (1): Conv( + (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True) + (act): SiLU(inplace=True) + ) + (2): Conv2d(320, 80, kernel_size=(1, 1), stride=(1, 1)) + ) + ) + (dfl): DFL( + (conv): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1), bias=False) + ) + ) + ) +) +Before Pruning: MACs=129.092051 G, #Params=68.229648 M +After Pruning: MACs=41.741203 G, #Params=20.787528 M +``` diff --git a/benchmarks/prunability/yolov8_pruning.py b/examples/yolov8/yolov8_pruning.py similarity index 100% rename from benchmarks/prunability/yolov8_pruning.py rename to examples/yolov8/yolov8_pruning.py diff --git a/tests/test_customized_layer.py b/tests/test_customized_layer.py index 8d00cbc..0d8d8a4 100644 --- a/tests/test_customized_layer.py +++ b/tests/test_customized_layer.py @@ -1,4 +1,5 @@ import sys, os + sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) import torch @@ -55,8 +56,8 @@ def prune_out_channels(self, layer: CustomizedLayer, idxs: Sequence[int]) -> nn. keep_idxs = list(set(range(layer.in_dim)) - set(idxs)) keep_idxs.sort() layer.in_dim = layer.in_dim-len(idxs) - layer.scale = torch.nn.Parameter(layer.scale.data.clone()[keep_idxs]) - layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs]) + layer.scale = self._prune_parameter_and_grad(layer.scale, keep_idxs, pruning_dim=0) + layer.bias = self._prune_parameter_and_grad(layer.bias, keep_idxs, pruning_dim=0) tp.prune_linear_in_channels(layer.fc, idxs) tp.prune_linear_out_channels(layer.fc, idxs) return layer @@ -68,6 +69,15 @@ def get_out_channels(self, layer): prune_in_channels = prune_out_channels get_in_channels = get_out_channels +class MyLinearPruner(tp.function.LinearPruner): + def prune_out_channels(self, layer: nn.Linear, idxs: Sequence[int]) -> nn.Linear: + print("MyLinearPruner applied to layer: ", layer) + return super().prune_out_channels(layer, idxs) + + def prune_in_channels(self, layer: nn.Linear, idxs: Sequence[int]) -> nn.Linear: + print("MyLinearPruner applied to layer: ", layer) + return super().prune_in_channels(layer, idxs) + def test_customization(): model = FullyConnectedNet(128, 10, 256) @@ -78,12 +88,17 @@ def test_customization(): DG.register_customized_layer( CustomizedLayer, my_pruner) + + my_linear_pruner = MyLinearPruner() + DG.register_customized_layer( + nn.Linear, my_linear_pruner + ) # 2. Build dependency graph DG.build_dependency(model, example_inputs=torch.randn(1,128)) # 3. get a pruning group according to the dependency graph. idxs is the indices of pruned filters. - pruning_group = DG.get_pruning_group( model.fc1, tp.prune_linear_out_channels, idxs=[0, 1, 6] ) + pruning_group = DG.get_pruning_group( model.fc1, my_linear_pruner.prune_out_channels, idxs=[0, 1, 6] ) print(pruning_group) # 4. execute this group (prune the model) diff --git a/tests/test_importance.py b/tests/test_importance.py deleted file mode 100644 index 5c42611..0000000 --- a/tests/test_importance.py +++ /dev/null @@ -1,47 +0,0 @@ -import sys, os -sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) - -import torch -from torchvision.models import resnet18 -import torch_pruning as tp -model = resnet18() - -# Global metrics -def test_imp(): - DG = tp.DependencyGraph() - example_inputs = torch.randn(1,3,224,224) - DG.build_dependency(model, example_inputs=example_inputs) - pruning_idxs = list( range( DG.get_out_channels(model.conv1) )) - pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs) - - random_importance = tp.importance.RandomImportance() - rand_imp = random_importance(pruning_group) - print("Random: ", rand_imp) - - magnitude_importance = tp.importance.MagnitudeImportance(p=1) - mag_imp = magnitude_importance(pruning_group) - print("L-1 Norm, Group Mean: ", mag_imp) - - magnitude_importance = tp.importance.MagnitudeImportance(p=2) - mag_imp = magnitude_importance(pruning_group) - print("L-2 Norm, Group Mean: ", mag_imp) - - magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='sum') - mag_imp = magnitude_importance(pruning_group) - print("L-2 Norm, Group Sum: ", mag_imp) - - magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction=None) - mag_imp = magnitude_importance(pruning_group) - print("L-2 Norm, No Reduction: ", mag_imp) - - bn_scale_importance = tp.importance.BNScaleImportance() - bn_imp = bn_scale_importance(pruning_group) - print("BN Scaling, Group mean: ", bn_imp) - - lamp_importance = tp.importance.LAMPImportance() - lamp_imp = lamp_importance(pruning_group) - print("LAMP: ", lamp_imp) - - -if __name__=='__main__': - test_imp() \ No newline at end of file diff --git a/tests/test_importance_reduction.py b/tests/test_importance_reduction.py new file mode 100644 index 0000000..9160a27 --- /dev/null +++ b/tests/test_importance_reduction.py @@ -0,0 +1,69 @@ +import sys, os +sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + +import torch +from torchvision.models import resnet18 +import torch_pruning as tp +model = resnet18() + +# Global metrics +def test_imp(): + DG = tp.DependencyGraph() + example_inputs = torch.randn(1,3,224,224) + DG.build_dependency(model, example_inputs=example_inputs) + pruning_idxs = list( range( DG.get_out_channels(model.conv1) )) + pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs) + + random_importance = tp.importance.RandomImportance() + rand_imp = random_importance(pruning_group) + print("Random: ", rand_imp) + + magnitude_importance = tp.importance.MagnitudeImportance(p=1, group_reduction=None, normalizer=None) + mag_imp_raw = magnitude_importance(pruning_group) + print("L-1 Norm, No Reduction: ", mag_imp_raw) + + magnitude_importance = tp.importance.MagnitudeImportance(p=1, normalizer=None) + mag_imp = magnitude_importance(pruning_group) + print("L-1 Norm, Group Mean: ", mag_imp) + assert torch.allclose(mag_imp, mag_imp_raw.mean(0)) + + magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction=None, normalizer=None) + mag_imp_raw = magnitude_importance(pruning_group) + print("L-2 Norm, No Reduction: ", mag_imp_raw) + + magnitude_importance = tp.importance.MagnitudeImportance(p=2, normalizer=None) + mag_imp = magnitude_importance(pruning_group) + print("L-2 Norm, Group Mean: ", mag_imp) + assert torch.allclose(mag_imp, mag_imp_raw.mean(0)) + + magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='sum', normalizer=None) + mag_imp = magnitude_importance(pruning_group) + print("L-2 Norm, Group Sum: ", mag_imp) + assert torch.allclose(mag_imp, mag_imp_raw.sum(0)) + + magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='max', normalizer=None) + mag_imp = magnitude_importance(pruning_group) + print("L-2 Norm, Group Max: ", mag_imp) + assert torch.allclose(mag_imp, mag_imp_raw.max(0)[0]) + + magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='gate', normalizer=None) + mag_imp = magnitude_importance(pruning_group) + print("L-2 Norm, Group Gate: ", mag_imp) + assert torch.allclose(mag_imp, mag_imp_raw[-1]) + + magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='prod', normalizer=None) + mag_imp = magnitude_importance(pruning_group) + print("L-2 Norm, Group Prod: ", mag_imp) + print(mag_imp, torch.prod(mag_imp_raw, dim=0)) + assert torch.allclose(mag_imp, torch.prod(mag_imp_raw, dim=0)) + + bn_scale_importance = tp.importance.BNScaleImportance(normalizer=None) + bn_imp = bn_scale_importance(pruning_group) + print("BN Scaling, Group mean: ", bn_imp) + + lamp_importance = tp.importance.LAMPImportance(normalizer=None) + lamp_imp = lamp_importance(pruning_group) + print("LAMP: ", lamp_imp) + +if __name__=='__main__': + test_imp() \ No newline at end of file diff --git a/tests/test_single_channel_output.py b/tests/test_single_channel_output.py index 27ee8e9..99f5a84 100644 --- a/tests/test_single_channel_output.py +++ b/tests/test_single_channel_output.py @@ -3,7 +3,7 @@ import torch.nn.functional as F import torch_pruning as tp -class TestModel(nn.Module): +class Model(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3, stride=1, padding=1) @@ -20,7 +20,7 @@ def forward(self, x): return x def test_single_channel_output(): - model = TestModel() + model = Model() example_inputs = torch.randn(1, 3, 224, 224) DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs) diff --git a/torch_pruning/__init__.py b/torch_pruning/__init__.py index cef0c3d..f532b39 100644 --- a/torch_pruning/__init__.py +++ b/torch_pruning/__init__.py @@ -1,5 +1,6 @@ +from .pruner import importance from .dependency import * from .pruner import * -from . import _helpers, utils, importance +from . import _helpers, utils from .serialization import save, load, state_dict, load_state_dict \ No newline at end of file diff --git a/torch_pruning/_helpers.py b/torch_pruning/_helpers.py index c18d1b1..df71913 100644 --- a/torch_pruning/_helpers.py +++ b/torch_pruning/_helpers.py @@ -3,7 +3,35 @@ import torch from operator import add from numbers import Number - +from collections import namedtuple + +UnwrappedParameters = namedtuple('UnwrappedParameters', ['parameters', 'pruning_dim']) + +class GroupItem(namedtuple('_GroupItem', ['dep', 'idxs'])): + def __new__(cls, dep, idxs): + """ A tuple of (dep, idxs) where dep is the dependency of the group, and idxs is the list of indices in the group.""" + cls.root_idxs = None # a placeholder. Will be filled by DepGraph + return super(GroupItem, cls).__new__(cls, dep, idxs) + + def __repr__(self): + return str( (self.dep, self.idxs) ) + +class _HybridIndex(namedtuple("_PruingIndex", ["idx", "root_idx"])): + """ A tuple of (idx, root_idx) where idx is the index of the pruned dimension in the current layer, + and root_idx is the index of the pruned dimension in the root layer. + """ + def __repr__(self): + return str( (self.idx, self.root_idx) ) + +def to_plain_idxs(idxs: _HybridIndex): + if len(idxs)==0 or not isinstance(idxs[0], _HybridIndex): + return idxs + return [i.idx for i in idxs] + +def to_root_idxs(idxs: _HybridIndex): + if len(idxs)==0 or not isinstance(idxs[0], _HybridIndex): + return idxs + return [i.root_idx for i in idxs] def is_scalar(x): if isinstance(x, torch.Tensor): @@ -20,16 +48,18 @@ def __init__(self, stride=1, reverse=False): self._stride = stride self.reverse = reverse - def __call__(self, idxs): + def __call__(self, idxs: _HybridIndex): new_idxs = [] + if self.reverse == True: for i in idxs: - new_idxs.append(i // self._stride) - new_idxs = list(set(new_idxs)) + new_idxs.append( _HybridIndex( idx = (i.idx // self._stride), root_idx=i.root_idx ) ) + new_idxs = list(set(new_idxs)) else: for i in idxs: new_idxs.extend( - list(range(i * self._stride, (i + 1) * self._stride))) + [ _HybridIndex(idx=k, root_idx=i.root_idx) for k in range(i.idx * self._stride, (i.idx + 1) * self._stride) ] + ) return new_idxs @@ -38,16 +68,16 @@ def __init__(self, offset, reverse=False): self.offset = offset self.reverse = reverse - def __call__(self, idxs): + def __call__(self, idxs: _HybridIndex): if self.reverse == True: new_idxs = [ - i - self.offset[0] + _HybridIndex(idx = i.idx - self.offset[0], root_idx=i.root_idx ) for i in idxs - if (i >= self.offset[0] and i < self.offset[1]) + if (i.idx >= self.offset[0] and i.idx < self.offset[1]) ] else: - new_idxs = [i + self.offset[0] for i in idxs] + new_idxs = [ _HybridIndex(idx=i.idx + self.offset[0], root_idx=i.root_idx) for i in idxs] return new_idxs @@ -56,36 +86,17 @@ def __init__(self, offset, reverse=False): self.offset = offset self.reverse = reverse - def __call__(self, idxs): + def __call__(self, idxs: _HybridIndex): if self.reverse == True: - new_idxs = [i + self.offset[0] for i in idxs] + new_idxs = [ _HybridIndex(idx=i.idx + self.offset[0], root_idx=i.root_idx) for i in idxs] else: new_idxs = [ - i - self.offset[0] + _HybridIndex(idx = i.idx - self.offset[0], root_idx=i.root_idx) for i in idxs - if (i >= self.offset[0] and i < self.offset[1]) + if (i.idx >= self.offset[0] and i.idx < self.offset[1]) ] return new_idxs - -class _GroupConvIndexMapping(object): - def __init__(self, in_channels, out_channels, groups, reverse=False): - self.in_channels = in_channels - self.out_channels = out_channels - self.groups = groups - self.reverse = reverse - - def __call__(self, idxs): - if self.reverse == True: - new_idxs = [i + self.offset[0] for i in idxs] - else: - group_histgram = np.histogram( - idxs, bins=self.groups, range=(0, self.out_channels) - ) - max_group_size = int(group_histgram.max()) - return new_idxs - - class ScalarSum: def __init__(self): self._results = {} diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py index 5a3c329..b4fec44 100644 --- a/torch_pruning/dependency.py +++ b/torch_pruning/dependency.py @@ -9,27 +9,38 @@ from .pruner import function from . import _helpers, utils, ops +from ._helpers import UnwrappedParameters, _HybridIndex, GroupItem + __all__ = ["Dependency", "Group", "DependencyGraph"] +_PLACEHOLDER = None + +def equal_func(func1, func2): + return ( + hasattr(func1, '__self__') and + hasattr(func2, '__self__') and + isinstance(func1.__self__, type(func2.__self__)) and + func1.__name__ == func2.__name__ + ) class Node(object): - """ Nodes of DepGraph + """ Node of DepGraph """ - def __init__(self, module: nn.Module, grad_fn, name: str = None): # For Computational Graph (Tracing) - self.inputs = [] - self.outputs = [] - self.module = module - self.grad_fn = grad_fn - self._name = name - self.type = ops.module2type(module) - self.class_type = module.__class__ + self.inputs = [] # input nodes + self.outputs = [] # output nodes + self.module = module # reference to torch.nn.Module + + self.grad_fn = grad_fn # gradient function of module output + self._name = name # node name + self.type = ops.module2type(module) # node type (enum) + self.module_class = module.__class__ # class type of the module # For Dependency Graph self.dependencies = [] # Adjacency List - self.enable_index_mapping = True - self.pruning_dim = -1 + self.enable_index_mapping = True # whether to enable index mapping + self.pruning_dim = -1 # the dimension to be pruned @property def name(self): @@ -42,7 +53,6 @@ def name(self): return fmt def add_input(self, node, allow_dumplicated=False): - #if node not in self.inputs: if allow_dumplicated is True: self.inputs.append(node) else: @@ -57,7 +67,7 @@ def add_output(self, node, allow_dumplicated=False): self.outputs.append(node) def __repr__(self): - return "".format(self.name) + return str(self) def __str__(self): return "".format(self.name) @@ -74,13 +84,13 @@ def details(self): fmt += " " * 4 + "DEP:\n" for dep in self.dependencies: fmt += " " * 8 + "{}\n".format(dep) - fmt += "\tEnable_index_mapping={}\n".format( - self.enable_index_mapping) + fmt += "\tEnable_index_mapping={}, pruning_dim={}\n".format( + self.enable_index_mapping, self.pruning_dim) fmt = "-" * 32 + "\n" return fmt -class Edge(): # for readability +class Edge(): # for readability pass @@ -103,15 +113,16 @@ def __init__( self.trigger = trigger self.handler = handler self.source = source - self.target = target - self.index_mapping = [None, None] + self.target = target + # Current coordinate system => Standard coordinate system => target coordinate system + # index_mapping[0] index_mapping[1] + self.index_mapping = [None, None] def __call__(self, idxs: list): - self.handler.__self__.pruning_dim = self.target.pruning_dim - result = self.handler( - self.target.module, - idxs, - ) + self.handler.__self__.pruning_dim = self.target.pruning_dim # set pruning_dim + if len(idxs)>0 and isinstance(idxs[0], _HybridIndex): + idxs = _helpers.to_plain_idxs(idxs) + result = self.handler(self.target.module, idxs) return result def __repr__(self): @@ -130,41 +141,47 @@ def is_triggered_by(self, pruning_fn): def __eq__(self, other): return ( - self.source == other.source + self.source == other.source and self.trigger == other.trigger and self.handler == other.handler and self.target == other.target ) + + @property + def layer(self): + return self.target.module + + @property + def pruning_fn(self): + return self.handler def __hash__(self): return hash((self.source, self.target, self.trigger, self.handler)) -GroupItem = namedtuple('GroupItem', ['dep', 'idxs']) - - class Group(object): """A group that contains dependencies and pruning indices. Each element is defined as a namedtuple('GroupItem', ['dep', 'idxs']). - A group is a iterable list + A group is a iterable List just like [ [Dep1, Indices1], [Dep2, Indices2], ..., [DepK, IndicesK] ] """ def __init__(self): self._group = list() - self._DG = None # for group.prune(idxs=NEW_IDXS) + self._DG = None # link to the DependencyGraph that produces this group. Will be filled by DependencyGraph.get_pruning_group. def prune(self, idxs=None, record_history=True): """Prune all coupled layers in the group """ - if idxs is not None: + if idxs is not None: # prune the group with the specified indices module = self._group[0].dep.target.module pruning_fn = self._group[0].dep.handler - new_group = self._DG.get_pruning_group(module, pruning_fn, idxs) + new_group = self._DG.get_pruning_group(module, pruning_fn, idxs) # create a new group with the specified indices new_group.prune() else: for dep, idxs in self._group: - if dep.target.type == ops.OPTYPE.PARAMETER: + if dep.target.type == ops.OPTYPE.PARAMETER: + # prune unwrapped nn.Parameter old_parameter = dep.target.module name = self._DG._param_to_name[old_parameter] self._DG._param_to_name.pop(old_parameter) @@ -177,19 +194,23 @@ def prune(self, idxs=None, record_history=True): self._DG._param_to_name[pruned_parameter] = name self._DG.module2node[pruned_parameter] = self._DG.module2node.pop(old_parameter) self._DG.module2node[pruned_parameter].module = pruned_parameter - else: + else: # prune nn.Module dep(idxs) + if record_history: root_module, pruning_fn, root_pruning_idx = self[0][0].target.module, self[0][0].trigger, self[0][1] root_module_name = self._DG._module2name[root_module] self._DG._pruning_history.append([root_module_name, self._DG.is_out_channel_pruning_fn(pruning_fn), root_pruning_idx]) - + def add_dep(self, dep, idxs): self._group.append(GroupItem(dep=dep, idxs=idxs)) def __getitem__(self, k): return self._group[k] + def __setitem__(self, k, v): + self._group[k] = v + @property def items(self): return self._group @@ -200,8 +221,9 @@ def has_dep(self, dep): return True return False - def has_pruning_op(self, dep, idxs): + def has_pruning_op(self, dep: Dependency, idxs: _HybridIndex): for _dep, _idxs in self._group: + #_idxs = _helpers.to_plain_idxs(_idxs) if ( _dep.target == dep.target and _dep.handler == dep.handler @@ -216,7 +238,7 @@ def __len__(self): def add_and_merge(self, dep, idxs): for i, (_dep, _idxs) in enumerate(self._group): if _dep.target == dep.target and _dep.handler == dep.handler: - self._group[i] = (_dep, list(set(_idxs + idxs))) + self._group[i] = GroupItem(dep=_dep, idxs=list(set(_idxs + idxs))) return self.add_dep(dep, idxs) @@ -244,13 +266,13 @@ def details(self): return fmt def exec(self): - """old interface, replaced by group.prune()""" + """old interface, will be deprecated in the future.""" + warnings.warn("Group.exec() will be deprecated in the future. Please use Group.prune() instead.") self.prune() def __call__(self): return self.prune() -UnwrappedParameters = namedtuple('UnwrappedParameters', ['parameters', 'pruning_dim']) class DependencyGraph(object): @@ -260,25 +282,26 @@ def __init__(self): ops.OPTYPE.SPLIT: ops.SplitPruner(), ops.OPTYPE.ELEMENTWISE: ops.ElementWisePruner(), ops.OPTYPE.RESHAPE: ops.ReshapePruner(), - ops.OPTYPE.CUSTOMIZED: None, + ops.OPTYPE.CUSTOMIZED: ops.CustomizedPruner(), # just a placeholder } self.REGISTERED_PRUNERS = function.PrunerBox.copy() # shallow copy - self.REGISTERED_PRUNERS.update(_dummy_pruners) - self.CUSTOMIZED_PRUNERS = {} + self.REGISTERED_PRUNERS.update(_dummy_pruners) # merge dummy pruners + self.CUSTOMIZED_PRUNERS = {} # user-customized pruners self.IGNORED_LAYERS = [] - # cache + # cache pruning functions for fast lookup self._in_channel_pruning_fn = set([p.prune_in_channels for p in self.REGISTERED_PRUNERS.values() if p is not None] + [p.prune_in_channels for p in self.CUSTOMIZED_PRUNERS.values() if p is not None]) self._out_channel_pruning_fn = set([p.prune_out_channels for p in self.REGISTERED_PRUNERS.values() if p is not None] + [p.prune_out_channels for p in self.CUSTOMIZED_PRUNERS.values() if p is not None]) - self._op_id = 0 + self._op_id = 0 # operatior id # Pruning History self._pruning_history = [] - def pruning_history(self): + def pruning_history(self) -> typing.List[typing.Tuple[str, bool, typing.Union[list, tuple]]]: return self._pruning_history def load_pruning_history(self, pruning_history): + """Redo the pruning history""" self._pruning_history = pruning_history for module_name, is_out_channel_pruning, pruning_idx in self._pruning_history: module = self.model @@ -295,79 +318,49 @@ def load_pruning_history(self, pruning_history): def build_dependency( self, model: torch.nn.Module, - example_inputs: typing.Union[torch.Tensor, typing.Sequence], - forward_fn: typing.Callable[[ - torch.nn.Module, typing.Union[torch.Tensor, typing.Sequence]], torch.Tensor] = None, + example_inputs: typing.Union[torch.Tensor, typing.Sequence, typing.Dict], + forward_fn: typing.Callable[[torch.nn.Module, typing.Union[torch.Tensor, typing.Sequence]], torch.Tensor] = None, output_transform: typing.Callable = None, unwrapped_parameters: typing.Dict[nn.Parameter, int] = None, - customized_pruners: typing.Dict[typing.Any, - function.BasePruningFunc] = None, + customized_pruners: typing.Dict[ typing.Union[typing.Any, torch.nn.Module],function.BasePruningFunc] = None, verbose: bool = True, - ): + ) -> "DependencyGraph": """Build a dependency graph through tracing. Args: model (class): the model to be pruned. example_inputs (torch.Tensor or List): dummy inputs for tracing. - forward_fn (Callable): a function to run the model with example_inputs, which should return a reduced tensor for backpropagation. + forward_fn (Callable): a function to forward the model with example_inputs, which should return a reduced scalr tensor for backpropagation. output_transform (Callable): a function to transform network outputs. - unwrapped_parameters (List): unwrapped nn.parameters defined by parameters. - customized_pruners (typing.Dict[typing.Any, function.BasePruningFunc]): pruners for customized layers. + unwrapped_parameters (typing.Dict[nn.Parameter, int]): unwrapped nn.parameters that do not belong to standard nn.Module. + customized_pruners (typing.Dict[ typing.Union[typing.Any, torch.nn.Module],function.BasePruningFunc]): customized pruners for a specific layer type or a specific layer instance. verbose (bool): verbose mode. """ self.verbose = verbose self.model = model - self._module2name = {module: name for ( - name, module) in model.named_modules()} + self._module2name = {module: name for (name, module) in model.named_modules()} # nn.Module => module name # Register customized pruners if customized_pruners is not None: - for customized_module, customized_pruner in customized_pruners.items(): - self.register_customized_layer( - customized_module, customized_pruner) - - # Ignore all sub-modules of customized layers - for layer_type in self.CUSTOMIZED_PRUNERS.keys(): + for customized_type, customized_pruner in customized_pruners.items(): + self.register_customized_layer(customized_type, customized_pruner) + + # Ignore all sub-modules of customized layers as they will be handled by the customized pruners + for layer_type_or_instance in self.CUSTOMIZED_PRUNERS.keys(): for m in self.model.modules(): - if isinstance(m, layer_type): - for sub_module in m.modules(): + # a layer instance or a layer type + if (m==layer_type_or_instance) or (not isinstance(layer_type_or_instance, torch.nn.Module) and isinstance(m, layer_type_or_instance)): + for sub_module in m.modules(): if sub_module != m: self.IGNORED_LAYERS.append(sub_module) # Detect unwrapped nn.parameters - wrapped_parameters = [] - prunable_module_types = self.REGISTERED_PRUNERS.keys() - for m in self.model.modules(): - op_type = ops.module2type(m) - if ( op_type in prunable_module_types and op_type!=ops.OPTYPE.ELEMENTWISE ) or m.__class__ in self.CUSTOMIZED_PRUNERS.keys(): - wrapped_parameters.extend(list(m.parameters())) - unwrapped_detected = [] - _param_to_name = {} - for name, p in self.model.named_parameters(): - is_wrapped = False - for p_wrapped in wrapped_parameters: - if p is p_wrapped: - is_wrapped = True - break - if not is_wrapped: - unwrapped_detected.append(p) - _param_to_name[p] = name - if unwrapped_parameters is None: - unwrapped_parameters = [] - self._param_to_name = _param_to_name - unwrapped_detected = list( set(unwrapped_detected) - set([p for (p, _) in unwrapped_parameters]) ) - if len(unwrapped_detected)>0 and self.verbose: - warnings.warn("Unwrapped parameters detected: {}.\n Torch-Pruning will prune the last non-singleton dimension of a parameter. If you wish to customize this behavior, please provide an unwrapped_parameters argument.".format([_param_to_name[p] for p in unwrapped_detected])) - for p in unwrapped_detected: - # get the last dimension that >1 - def last_non_singleton_dim(tensor): - non_singleton_dims = [i for i, s in enumerate(tensor.shape) if s > 1] - return non_singleton_dims[-1] if non_singleton_dims else None - pruning_dim = last_non_singleton_dim(p) - if pruning_dim is not None: - unwrapped_parameters.append( UnwrappedParameters(parameters=p, pruning_dim=pruning_dim) ) # prune the last non-singleton dim by daufault - self.unwrapped_parameters = unwrapped_parameters - # Build computational graph by tracing. + self._param_to_name, self.unwrapped_parameters = self._detect_unwrapped_parameters(unwrapped_parameters) + + # Detect torch.no_grad() + assert torch.is_grad_enabled(), "Dependency graph relies on backward. Please enable gradient computation." + + # Build computational graph through tracing. self.module2node = self._trace( model, example_inputs, forward_fn, output_transform=output_transform ) @@ -384,7 +377,7 @@ def last_non_singleton_dim(tensor): def register_customized_layer( self, - layer_type: typing.Type, + layer_type_or_instance: typing.Union[typing.Any, torch.nn.Module], layer_pruner: function.BasePruningFunc, ): """Register a customized pruner @@ -392,7 +385,8 @@ def register_customized_layer( layer_type (class): the type of target layer pruner (tp.pruner.BasePruningFunc): a pruner for the specified layer type. """ - self.CUSTOMIZED_PRUNERS[layer_type] = layer_pruner + self.CUSTOMIZED_PRUNERS[layer_type_or_instance] = layer_pruner + # Update cache self._in_channel_pruning_fn = set([p.prune_in_channels for p in self.REGISTERED_PRUNERS.values() if p is not None] + [p.prune_in_channels for p in self.CUSTOMIZED_PRUNERS.values() if p is not None]) self._out_channel_pruning_fn = set([p.prune_out_channels for p in self.REGISTERED_PRUNERS.values() if p is not None] + [p.prune_out_channels for p in self.CUSTOMIZED_PRUNERS.values() if p is not None]) @@ -425,22 +419,18 @@ def is_out_channel_pruning_fn(self, fn: typing.Callable) -> bool: def is_in_channel_pruning_fn(self, fn: typing.Callable) -> bool: return (fn in self._in_channel_pruning_fn) - def get_pruning_plan(self, module: nn.Module, pruning_fn: typing.Callable, idxs: typing.Union[list, tuple]) -> Group: - """ An alias of DependencyGraph.get_pruning_group for compatibility. - """ - return self.get_pruning_group(module, pruning_fn, idxs) - def get_pruning_group( self, module: nn.Module, pruning_fn: typing.Callable, - idxs: typing.Union[list, tuple], + idxs: typing.Sequence[int], ) -> Group: """Get the pruning group of pruning_fn. Args: module (nn.Module): the to-be-pruned module/layer. pruning_fn (Callable): the pruning function. idxs (list or tuple): the indices of channels/dimensions. + grouped_idxs (bool): whether the indices are grouped. If True, idxs is a list of list, e.g., [[0,1,2], [3,4,5]], where each sublist is a group. """ if module not in self.module2node: raise ValueError( @@ -450,18 +440,22 @@ def get_pruning_group( pruning_fn = function.prune_depthwise_conv_out_channels if isinstance(idxs, Number): idxs = [idxs] + + idxs = [ _HybridIndex(idx=i, root_idx=i) for i in idxs ] # idxs == root_idxs for the root layer self.update_index_mapping() group = Group() + # the user pruning operation root_node = self.module2node[module] group.add_dep( - Dependency(pruning_fn, pruning_fn, - source=root_node, target=root_node), idxs + dep=Dependency(pruning_fn, pruning_fn, source=root_node, target=root_node), + idxs=idxs, ) + visited_node = set() - def _fix_dependency_graph_non_recursive(dep, idxs): + def _fix_dependency_graph_non_recursive(dep, idxs, *args): processing_stack = [(dep, idxs)] while len(processing_stack) > 0: dep, idxs = processing_stack.pop(-1) @@ -475,8 +469,8 @@ def _fix_dependency_graph_non_recursive(dep, idxs): for mapping in new_dep.index_mapping: if mapping is not None: new_indices = mapping(new_indices) - #print(new_dep, new_dep.index_mapping) - #print(len(new_indices), new_indices) + + #print(len(new_indices)) #print() if len(new_indices) == 0: continue @@ -497,11 +491,18 @@ def _fix_dependency_graph_non_recursive(dep, idxs): for dep, idxs in group.items: merged_group.add_and_merge(dep, idxs) merged_group._DG = self + for i in range(len(merged_group)): + hybrid_idxs = merged_group[i].idxs + idxs = _helpers.to_plain_idxs(hybrid_idxs) + root_idxs = _helpers.to_root_idxs(hybrid_idxs) + merged_group[i] = GroupItem(merged_group[i].dep, idxs) # transform _HybridIndex to plain index + merged_group[i].root_idxs = root_idxs return merged_group def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, ops.TORCH_LINEAR)): visited_layers = [] ignored_layers = ignored_layers+self.IGNORED_LAYERS + for m in list(self.module2node.keys()): if m in ignored_layers: continue @@ -519,6 +520,7 @@ def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, o layer_channels = pruner.get_out_channels(m) group = self.get_pruning_group( m, pruner.prune_out_channels, list(range(layer_channels))) + prunable_group = True for dep, _ in group: module = dep.target.module @@ -530,10 +532,10 @@ def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, o if prunable_group: yield group - def get_pruner_of_module(self, module): - p = self.CUSTOMIZED_PRUNERS.get(module.__class__, None) + def get_pruner_of_module(self, module: nn.Module): + p = self.CUSTOMIZED_PRUNERS.get(module.__class__, None) # customized pruners for a specific layer type if p is None: - p = self.REGISTERED_PRUNERS.get(ops.module2type(module), None) + p = self.REGISTERED_PRUNERS.get(ops.module2type(module), None) # standard pruners return p def get_out_channels(self, module_or_node): @@ -602,7 +604,46 @@ def _infer_in_channels_recursively(self, node: Node): if ch == 0: return None return ch - + + def _detect_unwrapped_parameters(self, unwrapped_parameters): + # Detect wrapped nn.Parameters + wrapped_parameters = [] + prunable_module_types = self.REGISTERED_PRUNERS.keys() + for m in self.model.modules(): + op_type = ops.module2type(m) + if ( op_type in prunable_module_types and op_type!=ops.OPTYPE.ELEMENTWISE ) or m.__class__ in self.CUSTOMIZED_PRUNERS.keys() or m in self.CUSTOMIZED_PRUNERS.keys(): + wrapped_parameters.extend(list(m.parameters())) + + # Detect unwrapped nn.Parameters + unwrapped_detected = [] + _param_to_name = {} + for name, p in self.model.named_parameters(): + is_wrapped = False + for p_wrapped in wrapped_parameters: + if p is p_wrapped: + is_wrapped = True + break + if not is_wrapped: + unwrapped_detected.append(p) + _param_to_name[p] = name + if unwrapped_parameters is None: + unwrapped_parameters = [] + unwrapped_detected = list( set(unwrapped_detected) - set([p for (p, _) in unwrapped_parameters]) ) + if len(unwrapped_detected)>0 and self.verbose: + warning_str = "Unwrapped parameters detected: {}.\n Torch-Pruning will prune the last non-singleton dimension of a parameter. If you wish to customize this behavior, please provide an unwrapped_parameters argument.".format([_param_to_name[p] for p in unwrapped_detected]) + warnings.warn(warning_str) + + # set default pruning dim for unwrapped parameters + for p in unwrapped_detected: + # get the last dimension that >1 + def last_non_singleton_dim(tensor): + non_singleton_dims = [i for i, s in enumerate(tensor.shape) if s > 1] + return non_singleton_dims[-1] if non_singleton_dims else None + pruning_dim = last_non_singleton_dim(p) + if pruning_dim is not None: + unwrapped_parameters.append( UnwrappedParameters(parameters=p, pruning_dim=pruning_dim) ) # prune the last non-singleton dim by daufault + return _param_to_name, unwrapped_parameters + def _build_dependency(self, module2node): for _, node in module2node.items(): @@ -654,6 +695,7 @@ def _record_grad_fn(module, inputs, outputs): outputs = outputs.data gradfn2module[outputs.grad_fn] = module + # Register hooks for prunable modules registered_types = tuple(ops.type2class( t) for t in self.REGISTERED_PRUNERS.keys()) + tuple(self.CUSTOMIZED_PRUNERS.keys()) hooks = [ @@ -661,8 +703,8 @@ def _record_grad_fn(module, inputs, outputs): for m in model.modules() if (isinstance(m, registered_types) and m not in self.IGNORED_LAYERS) ] - - # Feed forward and record gradient functions of prunable modules + + # Feed forward to record gradient functions of prunable modules if forward_fn is not None: out = forward_fn(model, example_inputs) elif isinstance(example_inputs, dict): @@ -672,17 +714,16 @@ def _record_grad_fn(module, inputs, outputs): out = model(*example_inputs) except: out = model(example_inputs) - for hook in hooks: hook.remove() + # for recursive models or layers reused = [m for (m, count) in visited.items() if count > 1] - # build graph + # Graph tracing if output_transform is not None: out = output_transform(out) - - module2node = {} + module2node = {} # create a mapping from nn.Module to tp.dependency.Node for o in utils.flatten_as_list(out): self._trace_computational_graph( module2node, o.grad_fn, gradfn2module, reused) @@ -832,13 +873,14 @@ def _init_shape_information(self): if node.type == ops.OPTYPE.SPLIT: grad_fn = node.grad_fn if hasattr(grad_fn, '_saved_self_sizes'): + MAX_LEGAL_DIM = 100 if hasattr(grad_fn, '_saved_split_sizes') and hasattr(grad_fn, '_saved_dim') : - if grad_fn._saved_dim != 1: + if grad_fn._saved_dim != 1 and grad_fn._saved_dim < MAX_LEGAL_DIM: # a temp fix for pytorch==1.11, where the _saved_dim is an uninitialized value like 118745347895359 continue chs = list(grad_fn._saved_split_sizes) node.module.split_sizes = chs elif hasattr(grad_fn, '_saved_split_size') and hasattr(grad_fn, '_saved_dim'): - if grad_fn._saved_dim != 1: + if grad_fn._saved_dim != 1 and grad_fn._saved_dim < MAX_LEGAL_DIM: # a temp fix for pytorch==1.11, where the _saved_dim is an uninitialized value like 118745347895359 continue chs = [grad_fn._saved_split_size for _ in range(len(node.outputs))] node.module.split_sizes = chs @@ -956,7 +998,7 @@ def _update_concat_index_mapping(self, cat_node: Node): else: chs = [] for n in cat_node.inputs: - chs.append(self.infer_channels(n, cat_node)) + chs.append(self.infer_channels_between(n, cat_node)) cat_node.module.concat_sizes = chs offsets = [0] @@ -1025,7 +1067,7 @@ def _update_split_index_mapping(self, split_node: Node): addressed_dep.append(dep) break - def infer_channels(self, node_1, node_2): + def infer_channels_between(self, node_1, node_2): if node_1.type == ops.OPTYPE.SPLIT: for i, n in enumerate(node_1.outputs): if n == node_2: diff --git a/torch_pruning/ops.py b/torch_pruning/ops.py index 0215e6e..acba6b4 100644 --- a/torch_pruning/ops.py +++ b/torch_pruning/ops.py @@ -128,6 +128,9 @@ class ReshapePruner(DummyPruner): class ElementWisePruner(DummyPruner): pass +class CustomizedPruner(DummyPruner): + pass + # Standard Modules TORCH_CONV = nn.modules.conv._ConvNd diff --git a/torch_pruning/pruner/__init__.py b/torch_pruning/pruner/__init__.py index d8419b8..1fcd556 100644 --- a/torch_pruning/pruner/__init__.py +++ b/torch_pruning/pruner/__init__.py @@ -1,2 +1,3 @@ from .function import * -from .algorithms import * \ No newline at end of file +from .algorithms import * +from . import importance \ No newline at end of file diff --git a/torch_pruning/pruner/algorithms/__init__.py b/torch_pruning/pruner/algorithms/__init__.py index 203adb2..af20c98 100644 --- a/torch_pruning/pruner/algorithms/__init__.py +++ b/torch_pruning/pruner/algorithms/__init__.py @@ -1,4 +1,5 @@ from .metapruner import MetaPruner from .magnitude_based_pruner import MagnitudePruner from .batchnorm_scale_pruner import BNScalePruner -from .group_norm_pruner import GroupNormPruner \ No newline at end of file +from .group_norm_pruner import GroupNormPruner +from .growing_reg_pruner import GrowingRegPruner \ No newline at end of file diff --git a/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py b/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py index 31a9d79..dd28a2d 100644 --- a/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py +++ b/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py @@ -4,6 +4,9 @@ from .scheduler import linear_scheduler import torch import torch.nn as nn +import math + +from ..importance import MagnitudeImportance class BNScalePruner(MetaPruner): def __init__( @@ -12,6 +15,7 @@ def __init__( example_inputs, importance, reg=1e-5, + group_lasso=False, iterative_steps=1, iterative_sparsity_scheduler: Callable = linear_scheduler, ch_sparsity=0.5, @@ -41,8 +45,25 @@ def __init__( output_transform=output_transform, ) self.reg = reg + self._groups = list(self.DG.get_all_groups()) + self.group_lasso = True + if self.group_lasso: + self._l2_imp = MagnitudeImportance(p=2, group_reduction='mean', normalizer=None, target_types=[nn.modules.batchnorm._BatchNorm]) + + def regularize(self, model, reg=None): + if reg is None: + reg = self.reg # use the default reg - def regularize(self, model): - for m in model.modules(): - if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) and m.affine==True and m not in self.ignored_layers: - m.weight.grad.data.add_(self.reg*torch.sign(m.weight.data)) + if self.group_lasso==False: + for m in model.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) and m.affine==True and m not in self.ignored_layers: + m.weight.grad.data.add_(reg*torch.sign(m.weight.data)) + else: + for group in self._groups: + group_l2norm_sq, group_size = self._l2_imp(group, return_group_size=True) + if group_l2norm_sq is None: + continue + for dep, _ in group: + layer = dep.layer + if isinstance(layer, nn.modules.batchnorm._BatchNorm) and layer.affine==True and layer not in self.ignored_layers: + layer.weight.grad.data.add_(reg * math.sqrt(group_size) * (1 / group_l2norm_sq.sqrt()) * layer.weight.data) # Group Lasso https://tibshirani.su.domains/ftp/sparse-grlasso.pdf \ No newline at end of file diff --git a/torch_pruning/pruner/algorithms/group_norm_pruner.py b/torch_pruning/pruner/algorithms/group_norm_pruner.py index 299e4a1..0309807 100644 --- a/torch_pruning/pruner/algorithms/group_norm_pruner.py +++ b/torch_pruning/pruner/algorithms/group_norm_pruner.py @@ -7,6 +7,8 @@ class GroupNormPruner(MetaPruner): + """ Only for reproducing our results in the paper. Not recommended for practical use. Please refer to MagnitudePruner for a general implementation of magnitude-based pruning. + """ def __init__( self, model, @@ -158,7 +160,7 @@ def regularize(self, model, base=16): function.prune_linear_in_channels, ]: gn = group_norm - if hasattr(dep.target, 'index_transform') and isinstance(dep.target.index_transform, _FlattenIndexTransform): + if hasattr(dep.target, 'index_transform') and isinstance(dep.target.index_transform, _FlattenIndexMapping): gn = group_norm.repeat_interleave(w.shape[1]//group_norm.shape[0]) # regularize input channels if prune_fn==function.prune_conv_in_channels and layer.groups>1: diff --git a/torch_pruning/pruner/algorithms/growing_reg_pruner.py b/torch_pruning/pruner/algorithms/growing_reg_pruner.py new file mode 100644 index 0000000..272b544 --- /dev/null +++ b/torch_pruning/pruner/algorithms/growing_reg_pruner.py @@ -0,0 +1,89 @@ +from .metapruner import MetaPruner +from .scheduler import linear_scheduler +import typing +import torch +import torch.nn as nn + +from ..importance import MagnitudeImportance, GroupNormImportance +from .. import function +import math + +class GrowingRegPruner(MetaPruner): + def __init__( + self, + model, + example_inputs, + importance, + reg=1e-5, + delta_reg = 1e-5, + iterative_steps=1, + iterative_sparsity_scheduler: typing.Callable = linear_scheduler, + ch_sparsity=0.5, + ch_sparsity_dict=None, + global_pruning=False, + max_ch_sparsity=1.0, + round_to=None, + ignored_layers=None, + customized_pruners=None, + unwrapped_parameters=None, + output_transform=None, + target_types=[nn.modules.conv._ConvNd, nn.Linear, nn.modules.batchnorm._BatchNorm], + ): + super(GrowingRegPruner, self).__init__( + model=model, + example_inputs=example_inputs, + importance=importance, + iterative_steps=iterative_steps, + iterative_sparsity_scheduler=iterative_sparsity_scheduler, + ch_sparsity=ch_sparsity, + ch_sparsity_dict=ch_sparsity_dict, + global_pruning=global_pruning, + max_ch_sparsity=max_ch_sparsity, + round_to=round_to, + ignored_layers=ignored_layers, + customized_pruners=customized_pruners, + unwrapped_parameters=unwrapped_parameters, + output_transform=output_transform, + ) + self.base_reg = reg + self._groups = list(self.DG.get_all_groups()) + self.group_lasso = True + self._l2_imp = GroupNormImportance() + + group_reg = {} + for group in self._groups: + group_reg[group] = torch.ones( len(group[0].idxs) ) * self.base_reg + self.group_reg = group_reg + self.delta_reg = delta_reg + + def update_reg(self): + for group in self._groups: + group_l2norm_sq = self._l2_imp(group) + if group_l2norm_sq is None: + continue + reg = self.group_reg[group] + standarized_imp = (group_l2norm_sq.max() - group_l2norm_sq) / (group_l2norm_sq.max() - group_l2norm_sq.min() + 1e-8) + reg = reg + self.delta_reg * standarized_imp.to(reg.device) + self.group_reg[group] = reg + + def regularize(self, model): + for i, group in enumerate(self._groups): + group_l2norm_sq = self._l2_imp(group) + if group_l2norm_sq is None: + continue + + reg = self.group_reg[group] + for dep, idxs in group: + layer = dep.layer + pruning_fn = dep.pruning_fn + if isinstance(layer, nn.modules.batchnorm._BatchNorm) and layer.affine==True and layer not in self.ignored_layers: + layer.weight.grad.data.add_(reg.to(layer.weight.device) * layer.weight.data) + elif isinstance(layer, (nn.modules.conv._ConvNd, nn.Linear)) and layer not in self.ignored_layers: + if pruning_fn in [function.prune_conv_out_channels, function.prune_linear_out_channels]: + w = layer.weight.data[idxs] + g = w * reg.to(layer.weight.device).view( -1, *([1]*(len(w.shape)-1)) ) #/ group_norm.view( -1, *([1]*(len(w.shape)-1)) ) * group_size #group_size #* scale.view( -1, *([1]*(len(w.shape)-1)) ) + layer.weight.grad.data[idxs]+= g + elif pruning_fn in [function.prune_conv_in_channels, function.prune_linear_in_channels]: + w = layer.weight.data[:, idxs] + g = w * reg.to(layer.weight.device).view( 1, -1, *([1]*(len(w.shape)-2)) ) #/ gn.view( 1, -1, *([1]*(len(w.shape)-2)) ) * group_size #* scale.view( 1, -1, *([1]*(len(w.shape)-2)) ) + layer.weight.grad.data[:, idxs]+=g \ No newline at end of file diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py index e991d83..c225b19 100644 --- a/torch_pruning/pruner/algorithms/metapruner.py +++ b/torch_pruning/pruner/algorithms/metapruner.py @@ -9,54 +9,53 @@ class MetaPruner: """ - Meta Pruner for structural pruning. + Meta pruner for structural pruning. Args: - model (nn.Module): A to-be-pruned model - example_inputs (torch.Tensor or List): dummy inputs for graph tracing. - importance (Callable): importance estimator. - global_pruning (bool): enable global pruning. - ch_sparsity (float): global channel sparisty. - ch_sparsity_dict (Dict[nn.Module, float]): layer-specific sparsity. - iterative_steps (int): number of steps for iterative pruning. - iterative_sparsity_scheduler (Callable): scheduler for iterative pruning. - max_ch_sparsity (float): maximum channel sparsity. - ignored_layers (List[nn.Module]): ignored modules. - round_to (int): channel rounding. - customized_pruners (dict): a dict containing module-pruner pairs. - unwrapped_parameters (list): nn.Parameter that does not belong to any supported layerss. - root_module_types (list): types of prunable modules. - output_transform (Callable): A function to transform network outputs. + # Basic + * model (nn.Module): A to-be-pruned model + * example_inputs (torch.Tensor or List): dummy inputs for graph tracing. + * importance (Callable): importance estimator. + * global_pruning (bool): enable global pruning. Default: False. + * ch_sparsity (float): global channel sparisty. Default: 0.5. + * ch_sparsity_dict (Dict[nn.Module, float]): layer-specific sparsity. Will cover ch_sparsity if specified. Default: None. + * max_ch_sparsity (float): maximum channel sparsity. Default: 1.0. + * iterative_steps (int): number of steps for iterative pruning. Default: 1. + * iterative_sparsity_scheduler (Callable): scheduler for iterative pruning. Default: linear_scheduler. + * ignored_layers (List[nn.Module | typing.Type]): ignored modules. Default: None. + * round_to (int): channel rounding. E.g., round_to=8 means channels will be rounded to 8x. Default: None. + + # Adavanced + * customized_pruners (dict): a dict containing module-pruner pairs. Default: None. + * unwrapped_parameters (dict): a dict containing unwrapped parameters & pruning dims. Default: None. + * root_module_types (list): types of prunable modules. Default: [nn.Conv2d, nn.Linear, nn.LSTM]. + * forward_fn (Callable): A function to execute model.forward. Default: None. + * output_transform (Callable): A function to transform network outputs. Default: None. """ def __init__( self, # Basic - model: nn.Module, - example_inputs: torch.Tensor, - importance: typing.Callable, - # https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#global-pruning. - global_pruning: bool = False, + model: nn.Module, # a simple pytorch model + example_inputs: torch.Tensor, # a dummy input for graph tracing. Should be on the same + importance: typing.Callable, # tp.importance.Importance for group importance estimation + global_pruning: bool = False, # https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#global-pruning. ch_sparsity: float = 0.5, # channel/dim sparsity - ch_sparsity_dict: typing.Dict[nn.Module, float] = None, - max_ch_sparsity: float = 1.0, + ch_sparsity_dict: typing.Dict[nn.Module, float] = None, # layer-specific sparsity, will cover ch_sparsity if specified + max_ch_sparsity: float = 1.0, # maximum sparsity. useful if over-pruning happens. iterative_steps: int = 1, # for iterative pruning - iterative_sparsity_scheduler: typing.Callable = linear_scheduler, - ignored_layers: typing.List[nn.Module] = None, + iterative_sparsity_scheduler: typing.Callable = linear_scheduler, # scheduler for iterative pruning. + ignored_layers: typing.List[nn.Module] = None, # ignored layers + round_to: int = None, # round channels to a multiple of round_to # Advanced - round_to: int = None, # round channels to 8x, 16x, ... - # for grouped channels. - channel_groups: typing.Dict[nn.Module, int] = dict(), - # pruners for customized layers - customized_pruners: typing.Dict[typing.Any, - function.BasePruningFunc] = None, - # unwrapped nn.Parameters like ViT.pos_emb - unwrapped_parameters: typing.List[nn.Parameter] = None, - root_module_types: typing.List = [ - ops.TORCH_CONV, ops.TORCH_LINEAR, ops.TORCH_LSTM], # root module for each group - output_transform: typing.Callable = None, + channel_groups: typing.Dict[nn.Module, int] = dict(), # channel groups for layers like group convs & group norms + customized_pruners: typing.Dict[typing.Any, function.BasePruningFunc] = None, # pruners for customized layers. E.g., {nn.Linear: my_linear_pruner} + unwrapped_parameters: typing.Dict[nn.Parameter, int] = None, # unwrapped nn.Parameters & pruning_dims. For example, {ViT.pos_emb: 0} + root_module_types: typing.List = [ops.TORCH_CONV, ops.TORCH_LINEAR, ops.TORCH_LSTM], # root module for each group + forward_fn: typing.Callable = None, # a function to execute model.forward + output_transform: typing.Callable = None, # a function to transform network outputs ): self.model = model self.importance = importance @@ -64,7 +63,6 @@ def __init__( self.ch_sparsity_dict = ch_sparsity_dict if ch_sparsity_dict is not None else {} self.max_ch_sparsity = max_ch_sparsity self.global_pruning = global_pruning - self.channel_groups = channel_groups self.root_module_types = root_module_types self.round_to = round_to @@ -73,21 +71,26 @@ def __init__( self.DG = dependency.DependencyGraph().build_dependency( model, example_inputs=example_inputs, + forward_fn=forward_fn, output_transform=output_transform, unwrapped_parameters=unwrapped_parameters, customized_pruners=customized_pruners, ) + # Ignored layers self.ignored_layers = [] if ignored_layers: for layer in ignored_layers: self.ignored_layers.extend(list(layer.modules())) + # Iterative pruning + # The pruner will prune the model iteratively for several steps to achieve the target sparsity + # E.g., if iterative_steps=5, ch_sparsity=0.5, the sparsity of each step will be [0.1, 0.2, 0.3, 0.4, 0.5] self.iterative_steps = iterative_steps self.iterative_sparsity_scheduler = iterative_sparsity_scheduler self.current_step = 0 - # Record initial status + # initial channels/dims for each layer self.layer_init_out_ch = {} self.layer_init_in_ch = {} for m in self.DG.module2node.keys(): @@ -95,12 +98,12 @@ def __init__( self.layer_init_out_ch[m] = self.DG.get_out_channels(m) self.layer_init_in_ch[m] = self.DG.get_in_channels(m) - # global channel sparsity for each iterative step + # channel sparsity for each iterative step self.per_step_ch_sparsity = self.iterative_sparsity_scheduler( self.ch_sparsity, self.iterative_steps ) - # The customized channel sparsity for different layers + # The layer-specific sparsity will cover the global sparsity if specified self.ch_sparsity_dict = {} if ch_sparsity_dict is not None: for module in ch_sparsity_dict: @@ -112,7 +115,7 @@ def __init__( self.ch_sparsity_dict[submodule] = self.iterative_sparsity_scheduler( sparsity, self.iterative_steps ) - + # detect group convs & group norms for m in self.model.modules(): if isinstance(m, ops.TORCH_CONV) \ @@ -122,6 +125,7 @@ def __init__( if isinstance(m, ops.TORCH_GROUPNORM): self.channel_groups[m] = m.num_groups + # count the number of total channels at initialization if self.global_pruning: initial_total_channels = 0 for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types): @@ -131,44 +135,37 @@ def __init__( group[0][0].target.module) // ch_groups) self.initial_total_channels = initial_total_channels - def pruning_history(self): + def pruning_history(self) -> typing.List[typing.Tuple[str, bool, typing.Union[list, tuple]]]: return self.DG.pruning_history() - def load_pruning_history(self, pruning_history): + def load_pruning_history(self, pruning_history) -> None: self.DG.load_pruning_history(pruning_history) - def get_target_sparsity(self, module): - s = self.ch_sparsity_dict.get(module, self.per_step_ch_sparsity)[ - self.current_step] + def get_target_sparsity(self, module) -> float: + s = self.ch_sparsity_dict.get(module, self.per_step_ch_sparsity)[self.current_step] return min(s, self.max_ch_sparsity) - def reset(self): + def reset(self) -> None: self.current_step = 0 - def regularize(self, model, loss): + def regularize(self, model, loss) -> typing.Any: """ Model regularizor """ pass - def step(self, interactive=False): + def step(self, interactive=False)-> typing.Union[typing.Generator, None]: self.current_step += 1 - if self.global_pruning: - if interactive: - return self.prune_global() - else: - for group in self.prune_global(): - group.prune() + pruning_fn = self.prune_global if self.global_pruning else self.prune_local + if interactive: # yield groups for interactive pruning + return pruning_fn() else: - if interactive: - return self.prune_local() - else: - for group in self.prune_local(): - group.prune() + for group in pruning_fn(): + group.prune() - def estimate_importance(self, group, ch_groups=1): + def estimate_importance(self, group, ch_groups=1) -> torch.Tensor: return self.importance(group, ch_groups=ch_groups) - def _check_sparsity(self, group): + def _check_sparsity(self, group) -> bool: for dep, _ in group: module = dep.target.module pruning_fn = dep.handler @@ -192,7 +189,7 @@ def _check_sparsity(self, group): return False return True - def get_channel_groups(self, group): + def get_channel_groups(self, group) -> int: if isinstance(self.channel_groups, int): return self.channel_groups for dep, _ in group: @@ -201,12 +198,13 @@ def get_channel_groups(self, group): return self.channel_groups[module] return 1 # no channel grouping - def prune_local(self): + def prune_local(self) -> typing.Generator: if self.current_step > self.iterative_steps: return for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types): # check pruning rate if self._check_sparsity(group): + module = group[0][0].target.module pruning_fn = group[0][0].handler @@ -221,8 +219,12 @@ def prune_local(self): ) if self.round_to: - n_pruned = n_pruned - (n_pruned % self.round_to) - + rounded_channels = current_channels - n_pruned + # round to the nearest multiple of round_to + rounded_channels = rounded_channels - \ + (rounded_channels % self.round_to) + n_pruned = current_channels - rounded_channels + if n_pruned <= 0: continue if ch_groups > 1: @@ -235,10 +237,11 @@ def prune_local(self): [pruning_idxs+group_size*i for i in range(ch_groups)], 0) group = self.DG.get_pruning_group( module, pruning_fn, pruning_idxs.tolist()) + if self.DG.check_pruning_group(group): yield group - def prune_global(self): + def prune_global(self) -> typing.Generator: if self.current_step > self.iterative_steps: return global_importance = [] diff --git a/torch_pruning/pruner/function.py b/torch_pruning/pruner/function.py index d2c9f6f..c92d132 100644 --- a/torch_pruning/pruner/function.py +++ b/torch_pruning/pruner/function.py @@ -178,6 +178,7 @@ def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]) -> nn.Module layer.num_features = layer.num_features-len(idxs) layer.running_mean = layer.running_mean.data[keep_idxs] layer.running_var = layer.running_var.data[keep_idxs] + if layer.affine: layer.weight = self._prune_parameter_and_grad(layer.weight, keep_idxs, 0) layer.bias = self._prune_parameter_and_grad(layer.bias, keep_idxs, 0) diff --git a/torch_pruning/importance.py b/torch_pruning/pruner/importance.py similarity index 55% rename from torch_pruning/importance.py rename to torch_pruning/pruner/importance.py index fb6c404..32dd832 100644 --- a/torch_pruning/importance.py +++ b/torch_pruning/pruner/importance.py @@ -3,25 +3,42 @@ import torch.nn as nn import typing -from .pruner import function -from ._helpers import _FlattenIndexMapping -from . import ops +from . import function +from ..dependency import Group +from .._helpers import _FlattenIndexMapping +from .. import ops import math class Importance(abc.ABC): - """ estimate the importance of a Pruning Group, and return an 1-D per-channel importance score. + """ Estimate the importance of a tp.Dependency.Group, and return an 1-D per-channel importance score. + + It should accept a group and a ch_groups as inputs, and return a 1-D tensor with the same length as the number of channels. + ch_groups refer to the number of internal groups, e.g., for a 64-channel **group conv** with groups=ch_groups=4, each group has 16 channels. + All groups must be pruned simultaneously and thus their importance should be accumulated across channel groups. + Just ignore the ch_groups if you are not familar with grouping. + + Example: + ```python + DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224)) + group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] ) + scorer = MagnitudeImportance() + imp_score = scorer(group, ch_groups=1) + #imp_score is a 1-D tensor with length 3 for channels [2, 6, 9] + min_score = imp_score.min() + ``` """ @abc.abstractclassmethod - def __call__(self, group) -> torch.Tensor: + def __call__(self, group: Group, ch_groups: int=1) -> torch.Tensor: raise NotImplementedError class MagnitudeImportance(Importance): - def __init__(self, p=2, group_reduction="mean", normalizer='mean'): + def __init__(self, p=2, group_reduction="mean", normalizer='mean', target_types=[nn.modules.conv._ConvNd, nn.Linear, nn.modules.batchnorm._BatchNorm]): self.p = p self.group_reduction = group_reduction self.normalizer = normalizer + self.target_types = target_types def _normalize(self, group_importance, normalizer): if normalizer is None: @@ -41,33 +58,56 @@ def _normalize(self, group_importance, normalizer): else: raise NotImplementedError - def _reduce(self, group_imp): - if self.group_reduction == "sum": - group_imp = group_imp.sum(dim=0) - elif self.group_reduction == "mean": - group_imp = group_imp.mean(dim=0) - elif self.group_reduction == "max": - group_imp = group_imp.max(dim=0)[0] - elif self.group_reduction == "prod": - group_imp = torch.prod(group_imp, dim=0) - elif self.group_reduction == 'first': - group_imp = group_imp[0] - elif self.group_reduction is None: - group_imp = group_imp + def _reduce(self, group_imp: typing.List[torch.Tensor], group_idxs: typing.List[typing.List[int]]): + if len(group_imp) == 0: return group_imp + if self.group_reduction == 'prod': + reduced_imp = torch.ones_like(group_imp[0]) + elif self.group_reduction == 'max': + reduced_imp = torch.ones_like(group_imp[0]) * -99999 else: - raise NotImplementedError - return group_imp - + reduced_imp = torch.zeros_like(group_imp[0]) + + for i, (imp, root_idxs) in enumerate(zip(group_imp, group_idxs)): + if self.group_reduction == "sum" or self.group_reduction == "mean": + reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=imp.device), imp) # accumulated importance + elif self.group_reduction == "max": # keep the max importance + selected_imp = torch.index_select(reduced_imp, 0, torch.tensor(root_idxs, device=imp.device)) + selected_imp = torch.maximum(input=selected_imp, other=imp) + reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), selected_imp) + elif self.group_reduction == "prod": # product of importance + selected_imp = torch.index_select(reduced_imp, 0, torch.tensor(root_idxs, device=imp.device)) + torch.mul(selected_imp, imp, out=selected_imp) + reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), selected_imp) + elif self.group_reduction == 'first': + if i == 0: + reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), imp) + elif self.group_reduction == 'gate': + if i == len(group_imp)-1: + reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), imp) + elif self.group_reduction is None: + reduced_imp = torch.stack(group_imp, dim=0) # no reduction + else: + raise NotImplementedError + + if self.group_reduction == "mean": + reduced_imp /= len(group_imp) + return reduced_imp + @torch.no_grad() - def __call__(self, group, ch_groups=1): + def __call__(self, group: Group, ch_groups: int=1, return_group_size=False): group_imp = [] - # Get group norm - # print(group.details()) - for dep, idxs in group: - idxs.sort() - layer = dep.target.module - prune_fn = dep.handler - # Conv out_channels + group_idxs = [] + group_size = 0 + # Iterate over all groups and estimate group importance + for i, (dep, idxs) in enumerate(group): + layer = dep.layer + prune_fn = dep.pruning_fn + root_idxs = group[i].root_idxs + if not isinstance(layer, tuple(self.target_types)): + continue + #################### + # Conv/Linear Output + #################### if prune_fn in [ function.prune_conv_out_channels, function.prune_linear_out_channels, @@ -76,54 +116,63 @@ def __call__(self, group, ch_groups=1): w = layer.weight.data.transpose(1, 0)[idxs].flatten(1) else: w = layer.weight.data[idxs].flatten(1) - local_norm = w.abs().pow(self.p).sum(1) + local_imp = w.abs().pow(self.p).sum(1) + group_size += w.shape[1] if ch_groups > 1: - local_norm = local_norm.view(ch_groups, -1).sum(0) - local_norm = local_norm.repeat(ch_groups) - group_imp.append(local_norm) + local_imp = local_imp.view(ch_groups, -1).sum(0) + local_imp = local_imp.repeat(ch_groups) + group_imp.append(local_imp) + group_idxs.append(root_idxs) - # Conv in_channels + #################### + # Conv/Linear Input + #################### elif prune_fn in [ function.prune_conv_in_channels, function.prune_linear_in_channels, ]: - is_conv_flatten_linear = False if hasattr(layer, "transposed") and layer.transposed: - w = (layer.weight).flatten(1) + w = (layer.weight.data).flatten(1) else: - w = (layer.weight).transpose(0, 1).flatten(1) + w = (layer.weight.data).transpose(0, 1).flatten(1) + group_size += w.shape[1] if ch_groups > 1 and prune_fn == function.prune_conv_in_channels and layer.groups == 1: - # non-grouped conv and group convs - w = w.view(w.shape[0] // group_imp[0].shape[0], - group_imp[0].shape[0], w.shape[1]).transpose(0, 1).flatten(1) - local_norm = w.abs().pow(self.p).sum(1) + # non-grouped conv followed by a group conv + w = w.view(w.shape[0] // group_imp[0].shape[0], group_imp[0].shape[0], w.shape[1]).transpose(0, 1).flatten(1) + + local_imp = w.abs().pow(self.p).sum(1) if ch_groups > 1: - if len(local_norm) == len(group_imp[0]): - local_norm = local_norm.view(ch_groups, -1).sum(0) - local_norm = local_norm.repeat(ch_groups) - local_norm = local_norm[idxs] - group_imp.append(local_norm) - # BN + if len(local_imp) == len(group_imp[0]): + local_imp = local_imp.view(ch_groups, -1).sum(0) + local_imp = local_imp.repeat(ch_groups) + local_imp = local_imp[idxs] + group_imp.append(local_imp) + group_idxs.append(root_idxs) + + #################### + # BatchNorm + #################### elif prune_fn == function.prune_batchnorm_out_channels: # regularize BN if layer.affine: w = layer.weight.data[idxs] - local_norm = w.abs().pow(self.p) + local_imp = w.abs().pow(self.p) + group_size += 1 if ch_groups > 1: - local_norm = local_norm.view(ch_groups, -1).sum(0) - local_norm = local_norm.repeat(ch_groups) - # print(local_norm.shape) - group_imp.append(local_norm) - if len(group_imp) == 0: + local_imp = local_imp.view(ch_groups, -1).sum(0) + local_imp = local_imp.repeat(ch_groups) + group_imp.append(local_imp) + group_idxs.append(root_idxs) + #elif prune_fn == function.prune_multihead_attention_out_channels: + + if len(group_imp) == 0: # skip groups without parameterized layers + if return_group_size: + return None, 0 return None - imp_size = len(group_imp[0]) - aligned_group_imp = [] - for imp in group_imp: - if len(imp) == imp_size: - aligned_group_imp.append(imp) - group_imp = torch.stack(aligned_group_imp, dim=0) - group_imp = self._reduce(group_imp) + group_imp = self._reduce(group_imp, group_idxs) group_imp = self._normalize(group_imp, self.normalizer) + if return_group_size: + return group_imp, group_size return group_imp @@ -137,18 +186,21 @@ def __init__(self, group_reduction='mean', normalizer='mean'): def __call__(self, group, ch_groups=1): group_imp = [] - for dep, _ in group: - module = dep.target.module - if isinstance(module, (ops.TORCH_BATCHNORM)) and module.affine: - local_imp = torch.abs(module.weight.data) + group_idxs = [] + + for i, (dep, idxs) in enumerate(group): + layer = dep.layer + root_idxs = group[i].root_idxs + if isinstance(layer, (ops.TORCH_BATCHNORM)) and layer.affine: + local_imp = torch.abs(layer.weight.data)[idxs] if ch_groups > 1: local_imp = local_imp.view(ch_groups, -1).mean(0) local_imp = local_imp.repeat(ch_groups) group_imp.append(local_imp) + group_idxs.append(root_idxs) if len(group_imp) == 0: return None - group_imp = torch.stack(group_imp, dim=0) - group_imp = self._reduce(group_imp) + group_imp = self._reduce(group_imp, group_idxs) group_imp = self._normalize(group_imp, self.normalizer) return group_imp @@ -162,53 +214,8 @@ def __init__(self, p=2, group_reduction="mean", normalizer='mean'): super().__init__(p=p, group_reduction=group_reduction, normalizer=normalizer) @torch.no_grad() - def __call__(self, group, **kwargs): - group_imp = [] - for dep, idxs in group: - layer = dep.target.module - prune_fn = dep.handler - - if prune_fn in [ - function.prune_conv_out_channels, - function.prune_linear_out_channels, - ]: - if hasattr(layer, "transposed") and layer.transposed: - w = (layer.weight)[:, idxs].transpose(0, 1) - else: - w = (layer.weight)[idxs] - local_imp = torch.norm( - torch.flatten(w, 1), dim=1, p=self.p) - group_imp.append(local_imp) - - elif prune_fn in [ - function.prune_conv_in_channels, - function.prune_linear_in_channels, - ]: - if hasattr(layer, "transposed") and layer.transposed: - w = (layer.weight)[idxs].flatten(1) - else: - w = (layer.weight)[:, idxs].transpose(0, 1).flatten(1) - if ( - w.shape[0] != group_imp[0].shape[0] - ): # for conv-flatten-linear without global pooling - w = w.view( - group_imp[0].shape[0], - w.shape[0] // group_imp[0].shape[0], - w.shape[1], - ).flatten(1) - local_imp = torch.norm(w, dim=1, p=self.p) - group_imp.append(local_imp) - - elif prune_fn == function.prune_batchnorm_out_channels: - if layer.affine is not None: - w = (layer.weight)[idxs].view(-1, 1) - local_imp = torch.norm(w, dim=1, p=self.p) - group_imp.append(local_imp) - if len(group_imp) == 0: - return None - group_imp = torch.stack(group_imp, dim=0) - group_imp = self._reduce(group_imp) - group_imp = self._normalize(group_imp, self.normalizer) + def __call__(self, group, ch_groups=1): + group_imp = super().__call__(group, ch_groups) return self.lamp(group_imp) def lamp(self, imp): @@ -254,14 +261,14 @@ def __call__(self, group, ch_groups=1): w = layer.weight.data.transpose(1, 0)[idxs].flatten(1) else: w = layer.weight.data[idxs].flatten(1) - local_norm = w.abs().pow(self.p).sum(1) - #print(local_norm.shape, layer, idxs, ch_groups) + local_imp = w.abs().pow(self.p).sum(1) + #print(local_imp.shape, layer, idxs, ch_groups) if ch_groups > 1: - local_norm = local_norm.view(ch_groups, -1).sum(0) - local_norm = local_norm.repeat(ch_groups) - if group_norm is None: group_norm = local_norm - elif group_norm.shape[0] == local_norm.shape[0]: - group_norm += local_norm + local_imp = local_imp.view(ch_groups, -1).sum(0) + local_imp = local_imp.repeat(ch_groups) + if group_norm is None: group_norm = local_imp + elif group_norm.shape[0] == local_imp.shape[0]: + group_norm += local_imp # if layer.bias is not None: # group_norm += layer.bias.data[idxs].pow(2) # Conv in_channels @@ -287,33 +294,33 @@ def __call__(self, group, ch_groups=1): # non-grouped conv with group convs w = w.view(w.shape[0] // group_norm.shape[0], group_norm.shape[0], w.shape[1]).transpose(0, 1).flatten(1) - local_norm = w.abs().pow(self.p).sum(1) + local_imp = w.abs().pow(self.p).sum(1) if ch_groups > 1: - if len(local_norm) == len(group_norm): - local_norm = local_norm.view(ch_groups, -1).sum(0) - local_norm = local_norm.repeat(ch_groups) + if len(local_imp) == len(group_norm): + local_imp = local_imp.view(ch_groups, -1).sum(0) + local_imp = local_imp.repeat(ch_groups) if not is_conv_flatten_linear: - local_norm = local_norm[idxs] - if group_norm is None: group_norm = local_norm - elif group_norm.shape[0] == local_norm.shape[0]: - group_norm += local_norm + local_imp = local_imp[idxs] + if group_norm is None: group_norm = local_imp + elif group_norm.shape[0] == local_imp.shape[0]: + group_norm += local_imp # BN elif prune_fn == function.prune_batchnorm_out_channels: # regularize BN if layer.affine: w = layer.weight.data[idxs] - local_norm = w.abs().pow(self.p) + local_imp = w.abs().pow(self.p) if ch_groups > 1: - local_norm = local_norm.view(ch_groups, -1).sum(0) - local_norm = local_norm.repeat(ch_groups) - if group_norm is None: group_norm = local_norm - elif group_norm.shape[0] == local_norm.shape[0]: - group_norm += local_norm + local_imp = local_imp.view(ch_groups, -1).sum(0) + local_imp = local_imp.repeat(ch_groups) + if group_norm is None: group_norm = local_imp + elif group_norm.shape[0] == local_imp.shape[0]: + group_norm += local_imp elif prune_fn == function.prune_lstm_out_channels: _idxs = torch.tensor(idxs) - local_norm = 0 - local_norm_reverse = 0 + local_imp = 0 + local_imp_reverse = 0 num_layers = layer.num_layers expanded_idxs = torch.cat( [_idxs+i*layer.hidden_size for i in range(4)], dim=0) @@ -322,42 +329,42 @@ def __call__(self, group, ch_groups=1): else: postfix = [''] - local_norm += getattr(layer, 'weight_hh_l0')[expanded_idxs].abs().pow( + local_imp += getattr(layer, 'weight_hh_l0')[expanded_idxs].abs().pow( self.p).sum(1).view(4, -1).sum(0) - local_norm += getattr(layer, + local_imp += getattr(layer, 'weight_hh_l0')[:, _idxs].abs().pow(self.p).sum(0) - local_norm += getattr(layer, 'weight_ih_l0')[expanded_idxs].abs().pow( + local_imp += getattr(layer, 'weight_ih_l0')[expanded_idxs].abs().pow( self.p).sum(1).view(4, -1).sum(0) if layer.bidirectional: - local_norm_reverse += getattr(layer, 'weight_hh_l0')[ + local_imp_reverse += getattr(layer, 'weight_hh_l0')[ expanded_idxs].abs().pow(self.p).sum(1).view(4, -1).sum(0) - local_norm_reverse += getattr(layer, 'weight_hh_l0')[ + local_imp_reverse += getattr(layer, 'weight_hh_l0')[ :, _idxs].abs().pow(self.p).sum(0) - local_norm_reverse += getattr(layer, 'weight_ih_l0')[ + local_imp_reverse += getattr(layer, 'weight_ih_l0')[ expanded_idxs].abs().pow(self.p).sum(1).view(4, -1).sum(0) - local_norm = torch.cat( - [local_norm, local_norm_reverse], dim=0) - if group_norm is None: group_norm = local_norm - elif group_norm.shape[0] == local_norm.shape[0]: - group_norm += local_norm + local_imp = torch.cat( + [local_imp, local_imp_reverse], dim=0) + if group_norm is None: group_norm = local_imp + elif group_norm.shape[0] == local_imp.shape[0]: + group_norm += local_imp elif prune_fn == function.prune_lstm_in_channels: - local_norm = getattr(layer, 'weight_ih_l0')[ + local_imp = getattr(layer, 'weight_ih_l0')[ :, idxs].abs().pow(self.p).sum(0) if layer.bidirectional: - local_norm_reverse += getattr(layer, 'weight_ih_l0_reverse')[ + local_imp_reverse += getattr(layer, 'weight_ih_l0_reverse')[ :, idxs].abs().pow(self.p).sum(0) - local_norm = torch.cat( - [local_norm, local_norm_reverse], dim=0) - if group_norm is None: group_norm = local_norm - elif group_norm.shape[0] == local_norm.shape[0]: - group_norm += local_norm + local_imp = torch.cat( + [local_imp, local_imp_reverse], dim=0) + if group_norm is None: group_norm = local_imp + elif group_norm.shape[0] == local_imp.shape[0]: + group_norm += local_imp group_imp = group_norm**(1/self.p) group_imp = self._normalize(group_imp, self.normalizer) return group_imp -class TaylorImportance(Importance): +class TaylorImportance(MagnitudeImportance): def __init__(self, group_reduction="mean", normalizer='mean', multivariable=False): self.group_reduction = group_reduction self.normalizer = normalizer @@ -381,30 +388,15 @@ def _normalize(self, group_importance, normalizer): else: raise NotImplementedError - def _reduce(self, group_imp): - if self.group_reduction == "sum": - group_imp = group_imp.sum(dim=0) - elif self.group_reduction == "mean": - group_imp = group_imp.mean(dim=0) - elif self.group_reduction == "max": - group_imp = group_imp.max(dim=0)[0] - elif self.group_reduction == "prod": - group_imp = torch.prod(group_imp, dim=0) - elif self.group_reduction == 'first': - group_imp = group_imp[0] - elif self.group_reduction is None: - group_imp = group_imp - else: - raise NotImplementedError - return group_imp - @torch.no_grad() def __call__(self, group, ch_groups=1): group_imp = [] - for dep, idxs in group: + group_idxs = [] + for i, (dep, idxs) in enumerate(group): idxs.sort() layer = dep.target.module prune_fn = dep.handler + root_idxs = group[i].root_idxs if prune_fn in [ function.prune_conv_out_channels, @@ -422,6 +414,8 @@ def __call__(self, group, ch_groups=1): else: local_imp = (w * dw).abs().sum(1) group_imp.append(local_imp) + group_idxs.append(root_idxs) + # Conv in_channels elif prune_fn in [ function.prune_conv_in_channels, @@ -438,6 +432,8 @@ def __call__(self, group, ch_groups=1): else: local_imp = (w * dw).abs().sum(1) group_imp.append(local_imp) + group_idxs.append(root_idxs) + # BN elif prune_fn == function.prune_groupnorm_out_channels: # regularize BN @@ -446,14 +442,8 @@ def __call__(self, group, ch_groups=1): dw = layer.weight.grad.data[idxs] local_imp = (w*dw).abs() group_imp.append(local_imp) - if len(group_imp) == 0: - return None - imp_size = len(group_imp[0]) - aligned_group_imp = [] - for imp in group_imp: - if len(imp) == imp_size: - aligned_group_imp.append(imp) - group_imp = torch.stack(aligned_group_imp, dim=0) - group_imp = self._reduce(group_imp) + group_idxs.append(root_idxs) + + group_imp = self._reduce(group_imp, group_idxs) group_imp = self._normalize(group_imp, self.normalizer) return group_imp diff --git a/torch_pruning/utils/op_counter.py b/torch_pruning/utils/op_counter.py index 7ccda94..1bda1ca 100644 --- a/torch_pruning/utils/op_counter.py +++ b/torch_pruning/utils/op_counter.py @@ -316,7 +316,7 @@ def accumulate_flops(self): def get_model_parameters_number(model): - params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + params_num = sum(p.numel() for p in model.parameters()) return params_num