Skip to content

Commit

Permalink
Enable to predict when doing fine-tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
massongit committed Nov 27, 2017
1 parent c1d488d commit 4c81907
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions deepcrf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,27 +304,28 @@ def parse_to_tag_ids(sentences):
'n_vocab_add': n_vocab_add,
'use_cudnn': args['use_cudnn']}

deepcrf.util.write_vocab(save_name + '.model_attr', net_attrs)
model_filename = args['model_filename'] if args['model_filename'] else None

if args['initial_links_filename']:
net_attrs['initial_model'] = BiLSTM_CNN_CRF(
**deepcrf.util.load_vocab(args['model_attr_filename']))
serializers.load_hdf5(model_filename, net_attrs['initial_model'])
fine_tuning_links = set(deepcrf.util.load_vocab(
args['initial_links_filename']))
model_links = set()

for link in net_attrs['initial_model'].namedlinks():
link = link[0].split('/')[1]
if link:
model_links.add(link)

for link in model_links - fine_tuning_links:
setattr(net_attrs['initial_model'], link, None)

deepcrf.util.write_vocab(save_name + '.fine-tuning_links',
list(model_links & fine_tuning_links))
if is_train:
deepcrf.util.write_vocab(save_name + '.model_attr', net_attrs)
if args['initial_links_filename']:
net_attrs['initial_model'] = BiLSTM_CNN_CRF(
**deepcrf.util.load_vocab(args['model_attr_filename']))
serializers.load_hdf5(model_filename, net_attrs['initial_model'])
fine_tuning_links = set(deepcrf.util.load_vocab(
args['initial_links_filename']))
model_links = set()

for link in net_attrs['initial_model'].namedlinks():
link = link[0].split('/')[1]
if link:
model_links.add(link)

for link in model_links - fine_tuning_links:
setattr(net_attrs['initial_model'], link, None)

deepcrf.util.write_vocab(save_name + '.fine-tuning_links',
list(model_links & fine_tuning_links))

net = BiLSTM_CNN_CRF(**net_attrs)
my_cudnn(args['use_cudnn'])
Expand Down Expand Up @@ -398,7 +399,7 @@ def eval_loop(x_data, x_char_data, y_data, x_train_additionals=[]):

return predict_lists, sum_loss, predicted_results

if model_filename is not None and not args['initial_links_filename']:
if model_filename is not None and (is_test or (is_train and not args['initial_links_filename'])):
serializers.load_hdf5(model_filename, net)

if is_test:
Expand Down

0 comments on commit 4c81907

Please sign in to comment.