diff --git a/README.md b/README.md
index 45a75c6..ef1e0ba 100644
--- a/README.md
+++ b/README.md
@@ -3,9 +3,9 @@
[![Documentation Status](https://readthedocs.org/projects/lm-lstm-crf/badge/?version=latest)](http://lm-lstm-crf.readthedocs.io/en/latest/?badge=latest)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
-This project provides high-performance character-aware sequence labeling tools and tutorials. Model details can be accessed [here](http://arxiv.org/abs/1709.04109), and the implementation is based on the PyTorch library.
+This project provides high-performance character-aware sequence labeling tools, including [Training](#usage), [Evaluation](#evaluation), and [Prediction](#prediction).
-LM-LSTM-CRF achieves F1 score of 91.71+/-0.10 on the CoNLL 2003 NER dataset, without using any additional corpus or resource.
+Details about LM-LSTM-CRF can be accessed [here](http://arxiv.org/abs/1709.04109), and the implementation is based on the PyTorch library. Our model achieves an F1 score of 91.71+/-0.10 on the CoNLL 2003 NER dataset, without using any additional corpus or resource.
The documents would be available [here](http://lm-lstm-crf.readthedocs.io/en/latest/).
@@ -202,6 +202,14 @@ to
newcomers
Uzbekistan
.
+```
+and the corresponding output is:
+
+```
+-DOCSTART- -DOCSTART- -DOCSTART-
+
+But China saw their luck desert them in the second match of the group , crashing to a surprise 2-0 defeat to newcomers Uzbekistan .
+
```
## Reference
diff --git a/model/predictor.py b/model/predictor.py
index 6deaae2..c89bc9b 100644
--- a/model/predictor.py
+++ b/model/predictor.py
@@ -114,7 +114,7 @@ def decode_s(self, feature, label):
return chunks
- def output_batch(self, ner_model, features, fout):
+ def output_batch(self, ner_model, documents, fout):
"""
decode the whole corpus in the specific format by calling apply_model to fit specific models
@@ -123,18 +123,22 @@ def output_batch(self, ner_model, features, fout):
feature (list): list of words list
fout: output file
"""
- f_len = len(features)
+ d_len = len(documents)
- for ind in tqdm( range(0, f_len, self.batch_size), mininterval=1,
+ for d_ind in tqdm( range(0, d_len), mininterval=1,
desc=' - Process', leave=False, file=sys.stdout):
- eind = min(f_len, ind + self.batch_size)
- labels = self.apply_model(ner_model, features[ind: eind])
- labels = torch.unbind(labels, 1)
-
- for ind2 in range(ind, eind):
- f = features[ind2]
- l = labels[ind2 - ind][0: len(f) ]
- fout.write(self.decode_str(features[ind2], l) + '\n\n')
+ fout.write('-DOCSTART- -DOCSTART- -DOCSTART-\n\n')
+ features = documents[d_ind]
+ f_len = len(features)
+ for ind in range(0, f_len, self.batch_size):
+ eind = min(f_len, ind + self.batch_size)
+ labels = self.apply_model(ner_model, features[ind: eind])
+ labels = torch.unbind(labels, 1)
+
+ for ind2 in range(ind, eind):
+ f = features[ind2]
+ l = labels[ind2 - ind][0: len(f) ]
+ fout.write(self.decode_str(features[ind2], l) + '\n\n')
def apply_model(self, ner_model, features):
"""
diff --git a/model/utils.py b/model/utils.py
index 89b2b79..bcf6563 100644
--- a/model/utils.py
+++ b/model/utils.py
@@ -239,23 +239,45 @@ def read_corpus(lines):
return features, labels
-def read_features(lines):
+def read_features(lines, multi_docs = True):
"""
convert un-annotated corpus into features
"""
- features = list()
- tmp_fl = list()
- for line in lines:
- if not (line.isspace() or (len(line) > 10 and line[0:10] == '-DOCSTART-')):
- line = line.rstrip()
- tmp_fl.append(line)
- elif len(tmp_fl) > 0:
+ if multi_docs:
+ documents = list()
+ features = list()
+ tmp_fl = list()
+ for line in lines:
+ if_doc_end = (len(line) > 10 and line[0:10] == '-DOCSTART-')
+ if not (line.isspace() or if_doc_end):
+ line = line.rstrip()
+ tmp_fl.append(line)
+ else:
+ if len(tmp_fl) > 0:
+ features.append(tmp_fl)
+ tmp_fl = list()
+ if if_doc_end and len(features) > 0:
+ documents.append(features)
+ features = list()
+ if len(tmp_fl) > 0:
features.append(tmp_fl)
- tmp_fl = list()
- if len(tmp_fl) > 0:
- features.append(tmp_fl)
-
- return features
+ if len(features) >0:
+ documents.append(features)
+ return documents
+ else:
+ features = list()
+ tmp_fl = list()
+ for line in lines:
+ if not (line.isspace() or (len(line) > 10 and line[0:10] == '-DOCSTART-')):
+ line = line.rstrip()
+ tmp_fl.append(line)
+ elif len(tmp_fl) > 0:
+ features.append(tmp_fl)
+ tmp_fl = list()
+ if len(tmp_fl) > 0:
+ features.append(tmp_fl)
+
+ return features
def shrink_embedding(feature_map, word_dict, word_embedding, caseless):
"""