diff --git a/deepcrf/main.py b/deepcrf/main.py index 40b0a57..abeef88 100644 --- a/deepcrf/main.py +++ b/deepcrf/main.py @@ -304,27 +304,28 @@ def parse_to_tag_ids(sentences): 'n_vocab_add': n_vocab_add, 'use_cudnn': args['use_cudnn']} - deepcrf.util.write_vocab(save_name + '.model_attr', net_attrs) model_filename = args['model_filename'] if args['model_filename'] else None - if args['initial_links_filename']: - net_attrs['initial_model'] = BiLSTM_CNN_CRF( - **deepcrf.util.load_vocab(args['model_attr_filename'])) - serializers.load_hdf5(model_filename, net_attrs['initial_model']) - fine_tuning_links = set(deepcrf.util.load_vocab( - args['initial_links_filename'])) - model_links = set() - - for link in net_attrs['initial_model'].namedlinks(): - link = link[0].split('/')[1] - if link: - model_links.add(link) - - for link in model_links - fine_tuning_links: - setattr(net_attrs['initial_model'], link, None) - - deepcrf.util.write_vocab(save_name + '.fine-tuning_links', - list(model_links & fine_tuning_links)) + if is_train: + deepcrf.util.write_vocab(save_name + '.model_attr', net_attrs) + if args['initial_links_filename']: + net_attrs['initial_model'] = BiLSTM_CNN_CRF( + **deepcrf.util.load_vocab(args['model_attr_filename'])) + serializers.load_hdf5(model_filename, net_attrs['initial_model']) + fine_tuning_links = set(deepcrf.util.load_vocab( + args['initial_links_filename'])) + model_links = set() + + for link in net_attrs['initial_model'].namedlinks(): + link = link[0].split('/')[1] + if link: + model_links.add(link) + + for link in model_links - fine_tuning_links: + setattr(net_attrs['initial_model'], link, None) + + deepcrf.util.write_vocab(save_name + '.fine-tuning_links', + list(model_links & fine_tuning_links)) net = BiLSTM_CNN_CRF(**net_attrs) my_cudnn(args['use_cudnn']) @@ -398,7 +399,7 @@ def eval_loop(x_data, x_char_data, y_data, x_train_additionals=[]): return predict_lists, sum_loss, predicted_results - if model_filename is not None and not args['initial_links_filename']: + if model_filename is not None and (is_test or (is_train and not args['initial_links_filename'])): serializers.load_hdf5(model_filename, net) if is_test: