Skip to content

Commit

Permalink
Merge branch 'v1.2' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
VainF authored Jul 21, 2023
2 parents 788d24a + 7781b88 commit 3d7c4f8
Show file tree
Hide file tree
Showing 37 changed files with 2,567 additions and 852 deletions.
66 changes: 32 additions & 34 deletions README.md

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions benchmarks/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
parser.add_argument("--max-sparsity", type=float, default=1.0)
parser.add_argument("--soft-keeping-ratio", type=float, default=0.0)
parser.add_argument("--reg", type=float, default=5e-4)
parser.add_argument("--delta_reg", type=float, default=1e-4, help='for growing regularization')
parser.add_argument("--weight-decay", type=float, default=5e-4)

parser.add_argument("--seed", type=int, default=None)
Expand Down Expand Up @@ -109,6 +110,7 @@ def train_model(
best_acc = -1
for epoch in range(epochs):
model.train()

for i, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
Expand All @@ -129,6 +131,10 @@ def train_model(
optimizer.param_groups[0]["lr"],
)
)

if pruner is not None and isinstance(pruner, tp.pruner.GrowingRegPruner):
pruner.update_reg() # increase the strength of regularization
#print(pruner.group_reg[pruner._groups[0]])

model.eval()
acc, val_loss = eval(model, test_loader, device=device)
Expand Down Expand Up @@ -171,13 +177,21 @@ def get_pruner(model, example_inputs):
args.sparsity_learning = True
imp = tp.importance.BNScaleImportance()
pruner_entry = partial(tp.pruner.BNScalePruner, reg=args.reg, global_pruning=args.global_pruning)
elif args.method == "group_slim":
args.sparsity_learning = True
imp = tp.importance.BNScaleImportance()
pruner_entry = partial(tp.pruner.BNScalePruner, reg=args.reg, global_pruning=args.global_pruning, group_lasso=True)
elif args.method == "group_norm":
imp = tp.importance.GroupNormImportance(p=2)
pruner_entry = partial(tp.pruner.GroupNormPruner, global_pruning=args.global_pruning)
elif args.method == "group_sl":
args.sparsity_learning = True
imp = tp.importance.GroupNormImportance(p=2)
pruner_entry = partial(tp.pruner.GroupNormPruner, reg=args.reg, global_pruning=args.global_pruning)
elif args.method == "growing_reg":
args.sparsity_learning = True
imp = tp.importance.GroupNormImportance(p=2)
pruner_entry = partial(tp.pruner.GrowingRegPruner, reg=args.reg, delta_reg=args.delta_reg, global_pruning=args.global_pruning)
else:
raise NotImplementedError

Expand Down
Binary file removed benchmarks/prunability/coco_image.jpg
Binary file not shown.
348 changes: 0 additions & 348 deletions benchmarks/prunability/readme.md

This file was deleted.

Loading

0 comments on commit 3d7c4f8

Please sign in to comment.