Skip to content

Commit

Permalink
add depth export (PaddlePaddle#853)
Browse files: browse the repository at this point in the history
  • Loading branch information
ceci3 authored Aug 6, 2021
1 parent 5e92215 commit bd727d0
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion examples/model_compression/ofa/export_model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,6 +15,7 @@
import argparse
import logging
import os
import math
import random
import time
import json
Expand Down Expand Up @@ -85,6 +86,11 @@ def parse_args():
type=float,
default=1.0,
help="width mult you want to export")
parser.add_argument(
'--depth_mult',
type=float,
default=1.0,
help="depth mult you want to export")
args = parser.parse_args()
return args

Expand All @@ -106,6 +112,18 @@ def do_train(args):
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config_path = os.path.join(args.model_name_or_path, 'model_config.json')
cfg_dict = dict(json.loads(open(config_path).read()))

if args.depth_mult < 1.0:
depth = round(cfg_dict["init_args"][0]['num_hidden_layers'] * args.depth_mult)
cfg_dict["init_args"][0]['num_hidden_layers'] = depth
kept_layers_index = {}
for idx, i in enumerate(range(1, depth+1)):
kept_layers_index[idx] = math.floor(i / args.depth_mult) - 1

os.rename(config_path, config_path+'_bak')
with open(config_path, "w", encoding="utf-8") as f:
f.write(json.dumps(cfg_dict, ensure_ascii=False))

num_labels = cfg_dict['num_classes']

model = model_class.from_pretrained(
Expand All @@ -114,14 +132,24 @@ def do_train(args):
origin_model = model_class.from_pretrained(
args.model_name_or_path, num_classes=num_labels)

os.rename(config_path+'_bak', config_path)

sp_config = supernet(expand_ratio=[1.0, args.width_mult])
model = Convert(sp_config).convert(model)

ofa_model = OFA(model)

sd = paddle.load(
os.path.join(args.model_name_or_path, 'model_state.pdparams'))
ofa_model.model.set_state_dict(sd)

for name, params in ofa_model.model.named_parameters():
if 'encoder' not in name:
params.set_value(sd[name])
else:
idx = int(name.strip().split('.')[3])
mapping_name = name.replace('.'+str(idx)+'.', '.'+str(kept_layers_index[idx])+'.')
params.set_value(sd[mapping_name])

best_config = utils.dynabert_config(ofa_model, args.width_mult)
ofa_model.export(
best_config,
Expand Down

0 comments on commit bd727d0

Please sign in to comment.