
Commit

add tqdm to predict
LiyuanLucasLiu committed Sep 19, 2017
1 parent e887f04 commit c695bfb
Showing 2 changed files with 9 additions and 3 deletions.
model/predictor.py (5 changes: 4 additions & 1 deletion)
@@ -9,6 +9,8 @@
 import torch.autograd as autograd
 import numpy as np
 import itertools
+import sys
+from tqdm import tqdm
 
 from model.crf import CRFDecode_vb
 from model.utils import *
@@ -123,7 +125,8 @@ def output_batch(self, ner_model, features, fout):
         """
         f_len = len(features)
 
-        for ind in range(0, f_len, self.batch_size):
+        for ind in tqdm(range(0, f_len, self.batch_size), mininterval=1,
+                desc=' - Process', leave=False, file=sys.stdout):
             eind = min(f_len, ind + self.batch_size)
             labels = self.apply_model(ner_model, features[ind: eind])
             labels = torch.unbind(labels, 1)
seq_wc.py (7 changes: 5 additions & 2 deletions)
@@ -30,6 +30,7 @@
 parser.add_argument('--output_file', default='output.txt', help='path to output file')
 args = parser.parse_args()
 
+print('loading dictionary')
 with open(args.load_arg, 'r') as f:
     jd = json.load(f)
 jd = jd['args']
@@ -42,15 +43,16 @@
 if args.gpu >= 0:
     torch.cuda.set_device(args.gpu)
 
-# load corpus
+# loading corpus
+print('loading corpus')
 with codecs.open(args.input_file, 'r', 'utf-8') as f:
     lines = f.readlines()
 
 # converting format
 
 features = utils.read_features(lines)
 
-# build model
+print('loading model')
 ner_model = LM_LSTM_CRF(len(l_map), len(c_map), jd['char_dim'], jd['char_hidden'], jd['char_layers'], jd['word_dim'], jd['word_hidden'], jd['word_layers'], len(f_map), jd['drop_out'], large_CRF=jd['small_crf'], if_highway=jd['high_way'], in_doc_words=in_doc_words, highway_layers = jd['highway_layers'])
 
 ner_model.load_state_dict(checkpoint_file['state_dict'])
@@ -67,5 +69,6 @@
 decode_label = (args.decode_type == 'label')
 predictor = predict_wc(if_cuda, f_map, c_map, l_map, f_map['<eof>'], c_map['\n'], l_map['<pad>'], l_map['<start>'], decode_label, args.batch_size, jd['caseless'])
 
+print('annotating')
 with open(args.output_file, 'w') as fout:
     predictor.output_batch(ner_model, features, fout)
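As context for the change above: tqdm wraps any iterable and draws a text progress bar as the loop advances. Below is a minimal standalone sketch of the same pattern used in output_batch, with hypothetical f_len and batch_size values that are not taken from the repository.

import sys
from tqdm import tqdm

# Hypothetical sizes for illustration; in predictor.py these come from
# len(features) and self.batch_size.
f_len = 1000
batch_size = 50

for ind in tqdm(range(0, f_len, batch_size), mininterval=1,
                desc=' - Process', leave=False, file=sys.stdout):
    # Each iteration covers the batch of indices [ind, eind).
    eind = min(f_len, ind + batch_size)
    # A real run would apply the model to features[ind:eind] here.

Here mininterval=1 throttles bar redraws to at most once per second, desc labels the bar, leave=False clears it once the loop finishes, and file=sys.stdout sends it to standard output instead of tqdm's default stderr.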
