diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index 6406b2d80..8059d5b5e 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -11,6 +11,8 @@ from typing import List, Tuple, Union +from flair.training_utils import clear_embeddings + START_TAG: str = '' STOP_TAG: str = '' @@ -359,6 +361,9 @@ def predict(self, sentences: Union[List[Sentence], Sentence], mini_batch_size=32 if type(sentences) is Sentence: sentences = [sentences] + # remove previous embeddings + clear_embeddings(sentences) + # make mini-batches batches = [sentences[x:x + mini_batch_size] for x in range(0, len(sentences), mini_batch_size)] diff --git a/setup.py b/setup.py index dffec2d5c..a89d4fc79 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='flair', - version='0.2.0', + version='0.2.0.post1', description='A very simple framework for state-of-the-art NLP', long_description=open("README.md").read(), long_description_content_type="text/markdown", diff --git a/tests/test_sequence_tagger.py b/tests/test_sequence_tagger.py new file mode 100644 index 000000000..dfa8ea736 --- /dev/null +++ b/tests/test_sequence_tagger.py @@ -0,0 +1,17 @@ +from flair.data import Sentence +from flair.models import SequenceTagger + + +def test_tag_sentence(): + + # test tagging + sentence = Sentence('I love Berlin') + + tagger = SequenceTagger.load('ner') + + tagger.predict(sentence) + + # test re-tagging + tagger = SequenceTagger.load('pos') + + tagger.predict(sentence)