diff --git a/flair/embeddings.py b/flair/embeddings.py index a390c77e9..5316c2975 100644 --- a/flair/embeddings.py +++ b/flair/embeddings.py @@ -462,13 +462,16 @@ def embedding_length(self) -> int: def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: - cache_path = '{}-tmp-cache.sqllite'.format(self.name) if self.cache_directory is None else os.path.join( - self.cache_directory, '{}-tmp-cache.sqllite'.format(os.path.basename(self.name))) - - # by default, use_cache is false (for older pre-trained models TODO: remove in version 0.4) - if 'cache' not in self.__dict__ or 'cache_directory' not in self.__dict__ or not os.path.exists(cache_path): + # this whole block is for compatibility with older serialized models TODO: remove in version 0.4 + if 'cache' not in self.__dict__ or 'cache_directory' not in self.__dict__: self.use_cache = False self.cache_directory = None + else: + cache_path = '{}-tmp-cache.sqllite'.format(self.name) if not self.cache_directory else os.path.join( + self.cache_directory, '{}-tmp-cache.sqllite'.format(os.path.basename(self.name))) + if not os.path.exists(cache_path): + self.use_cache = False + self.cache_directory = None # if cache is used, try setting embeddings from cache first if self.use_cache: @@ -553,15 +556,14 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: class DocumentMeanEmbeddings(DocumentEmbeddings): - def __init__(self, word_embeddings: List[TokenEmbeddings]): + def __init__(self, token_embeddings: List[TokenEmbeddings]): """The constructor takes a list of embeddings to be combined.""" super().__init__() - self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=word_embeddings) + self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=token_embeddings) self.name: str = 'document_mean' - self.__embedding_length: int = 0 - self.__embedding_length = self.embeddings.embedding_length + self.__embedding_length: int = self.embeddings.embedding_length if torch.cuda.is_available(): self.cuda() @@ -631,18 +633,12 @@ def __init__(self, """ super().__init__() - self.embeddings: List[TokenEmbeddings] = token_embeddings - - # IMPORTANT: add embeddings as torch modules - for i, embedding in enumerate(self.embeddings): - self.add_module('token_embedding_{}'.format(i), embedding) + self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=token_embeddings) self.reproject_words = reproject_words self.bidirectional = bidirectional - self.length_of_all_token_embeddings = 0 - for token_embedding in self.embeddings: - self.length_of_all_token_embeddings += token_embedding.embedding_length + self.length_of_all_token_embeddings: int = self.embeddings.embedding_length self.name = 'document_lstm' self.static_embeddings = False @@ -691,8 +687,7 @@ def embed(self, sentences: Union[List[Sentence], Sentence]): sentences.sort(key=lambda x: len(x), reverse=True) - for token_embedding in self.embeddings: - token_embedding.embed(sentences) + self.embeddings.embed(sentences) # first, sort sentences by number of tokens longest_token_sequence_in_batch: int = len(sentences[0]) diff --git a/tests/test_model_integration.py b/tests/test_model_integration.py new file mode 100644 index 000000000..3ddd9419e --- /dev/null +++ b/tests/test_model_integration.py @@ -0,0 +1,252 @@ +import os +import shutil + +from flair.data import Sentence +from flair.data_fetcher import NLPTaskDataFetcher, NLPTask +from flair.embeddings import WordEmbeddings, CharLMEmbeddings, DocumentLSTMEmbeddings, TokenEmbeddings +from flair.models import SequenceTagger, TextClassifier +from flair.trainers import SequenceTaggerTrainer, TextClassifierTrainer + + +def test_train_load_use_tagger(): + + corpus = NLPTaskDataFetcher.fetch_data(NLPTask.FASHION) + tag_dictionary = corpus.make_tag_dictionary('ner') + + embeddings = WordEmbeddings('glove') + + tagger: SequenceTagger = SequenceTagger(hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_type='ner', + use_crf=False) + + # initialize trainer + trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True) + + trainer.train('./results', learning_rate=0.1, mini_batch_size=2, max_epochs=3) + + loaded_model: SequenceTagger = SequenceTagger.load_from_file('./results/final-model.pt') + + sentence = Sentence('I love Berlin') + sentence_empty = Sentence(' ') + + loaded_model.predict(sentence) + loaded_model.predict([sentence, sentence_empty]) + loaded_model.predict([sentence_empty]) + + # clean up results directory + shutil.rmtree('./results') + + +def test_train_charlm_load_use_tagger(): + + corpus = NLPTaskDataFetcher.fetch_data(NLPTask.FASHION) + tag_dictionary = corpus.make_tag_dictionary('ner') + + embeddings = CharLMEmbeddings('news-forward-fast') + + tagger: SequenceTagger = SequenceTagger(hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_type='ner', + use_crf=False) + + # initialize trainer + trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True) + + trainer.train('./results', learning_rate=0.1, mini_batch_size=2, max_epochs=3) + + loaded_model: SequenceTagger = SequenceTagger.load_from_file('./results/final-model.pt') + + sentence = Sentence('I love Berlin') + sentence_empty = Sentence(' ') + + loaded_model.predict(sentence) + loaded_model.predict([sentence, sentence_empty]) + loaded_model.predict([sentence_empty]) + + # clean up results directory + shutil.rmtree('./results') + + +def test_train_charlm_changed_chache_load_use_tagger(): + + corpus = NLPTaskDataFetcher.fetch_data(NLPTask.FASHION) + tag_dictionary = corpus.make_tag_dictionary('ner') + + # make a temporary cache directory that we remove afterwards + os.makedirs('./results/cache/', exist_ok=True) + embeddings = CharLMEmbeddings('news-forward-fast', cache_directory='./results/cache/') + + tagger: SequenceTagger = SequenceTagger(hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_type='ner', + use_crf=False) + + # initialize trainer + trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True) + + trainer.train('./results', learning_rate=0.1, mini_batch_size=2, max_epochs=3) + + # remove the cache directory + shutil.rmtree('./results/cache') + + loaded_model: SequenceTagger = SequenceTagger.load_from_file('./results/final-model.pt') + + sentence = Sentence('I love Berlin') + sentence_empty = Sentence(' ') + + loaded_model.predict(sentence) + loaded_model.predict([sentence, sentence_empty]) + loaded_model.predict([sentence_empty]) + + # clean up results directory + shutil.rmtree('./results') + + +def test_train_charlm_nochache_load_use_tagger(): + + corpus = NLPTaskDataFetcher.fetch_data(NLPTask.FASHION) + tag_dictionary = corpus.make_tag_dictionary('ner') + + embeddings = CharLMEmbeddings('news-forward-fast', use_cache=False) + + tagger: SequenceTagger = SequenceTagger(hidden_size=256, + embeddings=embeddings, + tag_dictionary=tag_dictionary, + tag_type='ner', + use_crf=False) + + # initialize trainer + trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True) + + trainer.train('./results', learning_rate=0.1, mini_batch_size=2, max_epochs=3) + + loaded_model: SequenceTagger = SequenceTagger.load_from_file('./results/final-model.pt') + + sentence = Sentence('I love Berlin') + sentence_empty = Sentence(' ') + + loaded_model.predict(sentence) + loaded_model.predict([sentence, sentence_empty]) + loaded_model.predict([sentence_empty]) + + # clean up results directory + shutil.rmtree('./results') + + +def test_load_use_serialized_tagger(): + + loaded_model: SequenceTagger = SequenceTagger.load('ner') + + sentence = Sentence('I love Berlin') + sentence_empty = Sentence(' ') + + loaded_model.predict(sentence) + loaded_model.predict([sentence, sentence_empty]) + loaded_model.predict([sentence_empty]) + + +def test_train_load_use_classifier(): + corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB) + label_dict = corpus.make_label_dictionary() + + glove_embedding: WordEmbeddings = WordEmbeddings('en-glove') + document_embeddings: DocumentLSTMEmbeddings = DocumentLSTMEmbeddings([glove_embedding], 128, 1, False, 64, False, + False) + + model = TextClassifier(document_embeddings, label_dict, False) + + trainer = TextClassifierTrainer(model, corpus, label_dict, False) + trainer.train('./results', max_epochs=2) + + sentence = Sentence("Berlin is a really nice city.") + + for s in model.predict(sentence): + for l in s.labels: + assert (l.value is not None) + assert (0.0 <= l.score <= 1.0) + assert (type(l.score) is float) + + loaded_model = TextClassifier.load_from_file('./results/final-model.pt') + + sentence = Sentence('I love Berlin') + sentence_empty = Sentence(' ') + + loaded_model.predict(sentence) + loaded_model.predict([sentence, sentence_empty]) + loaded_model.predict([sentence_empty]) + + # clean up results directory + shutil.rmtree('./results') + + +def test_train_charlm_load_use_classifier(): + corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB) + label_dict = corpus.make_label_dictionary() + + glove_embedding: TokenEmbeddings = CharLMEmbeddings('news-forward-fast') + document_embeddings: DocumentLSTMEmbeddings = DocumentLSTMEmbeddings([glove_embedding], 128, 1, False, 64, False, + False) + + model = TextClassifier(document_embeddings, label_dict, False) + + trainer = TextClassifierTrainer(model, corpus, label_dict, False) + trainer.train('./results', max_epochs=2) + + sentence = Sentence("Berlin is a really nice city.") + + for s in model.predict(sentence): + for l in s.labels: + assert (l.value is not None) + assert (0.0 <= l.score <= 1.0) + assert (type(l.score) is float) + + loaded_model = TextClassifier.load_from_file('./results/final-model.pt') + + sentence = Sentence('I love Berlin') + sentence_empty = Sentence(' ') + + loaded_model.predict(sentence) + loaded_model.predict([sentence, sentence_empty]) + loaded_model.predict([sentence_empty]) + + # clean up results directory + shutil.rmtree('./results') + + +def test_train_charlm__nocache_load_use_classifier(): + corpus = NLPTaskDataFetcher.fetch_data(NLPTask.IMDB) + label_dict = corpus.make_label_dictionary() + + glove_embedding: TokenEmbeddings = CharLMEmbeddings('news-forward-fast', use_cache=False) + document_embeddings: DocumentLSTMEmbeddings = DocumentLSTMEmbeddings([glove_embedding], 128, 1, False, 64, + False, + False) + + model = TextClassifier(document_embeddings, label_dict, False) + + trainer = TextClassifierTrainer(model, corpus, label_dict, False) + trainer.train('./results', max_epochs=2) + + sentence = Sentence("Berlin is a really nice city.") + + for s in model.predict(sentence): + for l in s.labels: + assert (l.value is not None) + assert (0.0 <= l.score <= 1.0) + assert (type(l.score) is float) + + loaded_model = TextClassifier.load_from_file('./results/final-model.pt') + + sentence = Sentence('I love Berlin') + sentence_empty = Sentence(' ') + + loaded_model.predict(sentence) + loaded_model.predict([sentence, sentence_empty]) + loaded_model.predict([sentence_empty]) + + # clean up results directory + shutil.rmtree('./results') \ No newline at end of file