diff --git a/optim_factory.py b/optim_factory.py index ccf66cd..1842eac 100644 --- a/optim_factory.py +++ b/optim_factory.py @@ -101,7 +101,8 @@ def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=N for name, param in model.named_parameters(): if not param.requires_grad: continue # frozen weights - if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: + if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or \ + name.endswith(".gamma") or name.endswith(".beta"): group_name = "no_decay" this_weight_decay = 0. else: