From 9c48c1a5cfe92624f08eb3ae92d4419b02595ba4 Mon Sep 17 00:00:00 2001 From: "Lars O. Grobe" Date: Mon, 12 Feb 2024 00:16:17 +0100 Subject: [PATCH 1/4] Added suport for persistent gensim models. Use filename-parameter when training. Existing models will not be overwritten and bypass training. --- litstudy/nlp.py | 73 +++++++++++++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/litstudy/nlp.py b/litstudy/nlp.py index b939e6e..d1c150f 100644 --- a/litstudy/nlp.py +++ b/litstudy/nlp.py @@ -274,31 +274,39 @@ def best_topic_for_documents(self) -> List[int]: return np.argmax(self.doc2topic, axis=1) -def train_nmf_model(corpus: Corpus, num_topics: int, seed=0, max_iter=500) -> TopicModel: - """Train a topic model using NMF. +def train_nmf_model( + corpus: Corpus, num_topics: int, seed=0, max_iter=500, filename=None +) -> TopicModel: + """Train a topic model using NMF and save unless given file exists. :param num_topics: The number of topics to train. :param seed: The seed used for random number generation. :param max_iter: The maximum number of iterations to use for training. More iterations mean better results, but longer training times. + :param filename: Name of gensim model to save, or to load if file exists. """ import gensim.models.nmf dic = corpus.dictionary freqs = corpus.frequencies - tfidf = gensim.models.tfidfmodel.TfidfModel(dictionary=dic) - model = gensim.models.nmf.Nmf( - list(tfidf[freqs]), - num_topics=num_topics, - passes=max_iter, - random_state=seed, - w_stop_condition=1e-9, - h_stop_condition=1e-9, - w_max_iter=50, - h_max_iter=50, - ) + if filename == None or os.path.isfile(filename) == False: + tfidf = gensim.models.tfidfmodel.TfidfModel(dictionary=dic) + model = gensim.models.nmf.Nmf( + list(tfidf[freqs]), + num_topics=num_topics, + passes=max_iter, + random_state=seed, + w_stop_condition=1e-9, + h_stop_condition=1e-9, + w_max_iter=50, + h_max_iter=50, + ) + if filename != None: + model.save(filename) + else: + model = gensim.models.nmf.Nmf.load(filename) doc2topic = corpus2dense(model[freqs], num_topics).T topic2token = model.get_topics() @@ -306,11 +314,12 @@ def train_nmf_model(corpus: Corpus, num_topics: int, seed=0, max_iter=500) -> To return TopicModel(dic, doc2topic, topic2token) -def train_lda_model(corpus: Corpus, num_topics, seed=0, **kwargs) -> TopicModel: +def train_lda_model(corpus: Corpus, num_topics, seed=0, filename=None, **kwargs) -> TopicModel: """Train a topic model using LDA. :param num_topics: The number of topics to train. :param seed: The seed used for random number generation. + :param filename: Name of gensim model to save, or to load if file exists. :param kwargs: Arguments passed to `gensim.models.lda.LdaModel` (gensim3) or `gensim.models.ldamodel.LdaModel` (gensim4). """ @@ -321,7 +330,6 @@ def train_lda_model(corpus: Corpus, num_topics, seed=0, **kwargs) -> TopicModel: from importlib.metadata import version gensim_mayor = int(version("gensim").split(".")[0]) - if gensim_mayor == 3: from gensim.models.lda import LdaModel @@ -329,7 +337,13 @@ def train_lda_model(corpus: Corpus, num_topics, seed=0, **kwargs) -> TopicModel: elif gensim_mayor == 4: from gensim.models.ldamodel import LdaModel - model = LdaModel(freqs, id2word=dic, num_topics=num_topics, **kwargs) + if filename == None or os.path.isfile(filename) == False: + model = LdaModel(freqs, id2word=dic, num_topics=num_topics, **kwargs) + if filename != None: + model.save(filename) + else: + model = LdaModel.load(filename) + else: from sys import exit @@ -341,19 +355,21 @@ def train_lda_model(corpus: Corpus, num_topics, seed=0, **kwargs) -> TopicModel: return TopicModel(dic, doc2topic, topic2token) -def train_elda_model(corpus: Corpus, num_topics, num_models=4, seed=0, **kwargs) -> TopicModel: +def train_elda_model( + corpus: Corpus, num_topics, num_models=4, seed=0, filename=None, **kwargs +) -> TopicModel: """Train a topic model using ensemble LDA. :param num_topics: The number of topics to train. :param num_models: The number of models to train. :param seed: The seed used for random number generation. + :param filename: Name of gensim model to save, or to load if file exists. :param kwargs: Arguments passed to `gensim.models.ensemblelda.EnsembleLda` (gensim4). """ from importlib.metadata import version gensim_mayor = int(version("gensim").split(".")[0]) - if gensim_mayor <= 3: from sys import exit @@ -364,14 +380,19 @@ def train_elda_model(corpus: Corpus, num_topics, num_models=4, seed=0, **kwargs) from gensim.models.ensemblelda import EnsembleLda - model = EnsembleLda( - topic_model_class="ldamulticore", - corpus=freqs, - id2word=dic, - num_topics=num_topics, - num_models=num_models, - **kwargs - ) + if filename == None or os.path.isfile(filename) == False: + model = EnsembleLda( + topic_model_class="ldamulticore", + corpus=freqs, + id2word=dic, + num_topics=num_topics, + num_models=num_models, + **kwargs + ) + if filename != None: + model.save(filename) + else: + model = LdaEnsembleLda.load(filename) doc2topic = corpus2dense(model[freqs], num_topics).T topic2token = model.get_topics() From 63d763dc1141bba65042e866098dba27f0c434de Mon Sep 17 00:00:00 2001 From: "Lars O. Grobe" Date: Mon, 12 Feb 2024 00:21:07 +0100 Subject: [PATCH 2/4] Fixed import of os.path.isfile(). --- litstudy/nlp.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/litstudy/nlp.py b/litstudy/nlp.py index d1c150f..2004d13 100644 --- a/litstudy/nlp.py +++ b/litstudy/nlp.py @@ -287,11 +287,12 @@ def train_nmf_model( :param filename: Name of gensim model to save, or to load if file exists. """ import gensim.models.nmf + from os.path import isfile dic = corpus.dictionary freqs = corpus.frequencies - if filename == None or os.path.isfile(filename) == False: + if filename == None or isfile(filename) == False: tfidf = gensim.models.tfidfmodel.TfidfModel(dictionary=dic) model = gensim.models.nmf.Nmf( list(tfidf[freqs]), @@ -328,16 +329,20 @@ def train_lda_model(corpus: Corpus, num_topics, seed=0, filename=None, **kwargs) freqs = corpus.frequencies from importlib.metadata import version + from os.path import isfile gensim_mayor = int(version("gensim").split(".")[0]) if gensim_mayor == 3: from gensim.models.lda import LdaModel - model = LdaModel(list(corpus), **kwargs) + if filename == None or isfile(filename) == False: + model = LdaModel(list(corpus), **kwargs) + if filename != None: + model.save(filename) elif gensim_mayor == 4: from gensim.models.ldamodel import LdaModel - if filename == None or os.path.isfile(filename) == False: + if filename == None or isfile(filename) == False: model = LdaModel(freqs, id2word=dic, num_topics=num_topics, **kwargs) if filename != None: model.save(filename) @@ -368,6 +373,7 @@ def train_elda_model( """ from importlib.metadata import version + from os.path import isfile gensim_mayor = int(version("gensim").split(".")[0]) if gensim_mayor <= 3: @@ -380,7 +386,7 @@ def train_elda_model( from gensim.models.ensemblelda import EnsembleLda - if filename == None or os.path.isfile(filename) == False: + if filename == None or isfile(filename) == False: model = EnsembleLda( topic_model_class="ldamulticore", corpus=freqs, From 13634dad483079331681e12ef85b3dc7570a1403 Mon Sep 17 00:00:00 2001 From: "Lars O. Grobe" Date: Mon, 12 Feb 2024 17:37:53 +0100 Subject: [PATCH 3/4] Fixed typo. --- litstudy/nlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litstudy/nlp.py b/litstudy/nlp.py index 2004d13..a595ed3 100644 --- a/litstudy/nlp.py +++ b/litstudy/nlp.py @@ -398,7 +398,7 @@ def train_elda_model( if filename != None: model.save(filename) else: - model = LdaEnsembleLda.load(filename) + model = EnsembleLda.load(filename) doc2topic = corpus2dense(model[freqs], num_topics).T topic2token = model.get_topics() From b9ec0cb291e935e673e71fe059c2d6bd67829275 Mon Sep 17 00:00:00 2001 From: "Lars O. Grobe" Date: Mon, 12 Feb 2024 18:12:10 +0100 Subject: [PATCH 4/4] Convert ensembla LDA to LDA. --- litstudy/nlp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litstudy/nlp.py b/litstudy/nlp.py index a595ed3..0c88ea5 100644 --- a/litstudy/nlp.py +++ b/litstudy/nlp.py @@ -400,6 +400,8 @@ def train_elda_model( else: model = EnsembleLda.load(filename) + model = model.generate_gensim_representation() + doc2topic = corpus2dense(model[freqs], num_topics).T topic2token = model.get_topics()