diff --git a/litstudy/nlp.py b/litstudy/nlp.py index b939e6e..0c88ea5 100644 --- a/litstudy/nlp.py +++ b/litstudy/nlp.py @@ -274,31 +274,40 @@ def best_topic_for_documents(self) -> List[int]: return np.argmax(self.doc2topic, axis=1) -def train_nmf_model(corpus: Corpus, num_topics: int, seed=0, max_iter=500) -> TopicModel: - """Train a topic model using NMF. +def train_nmf_model( + corpus: Corpus, num_topics: int, seed=0, max_iter=500, filename=None +) -> TopicModel: + """Train a topic model using NMF and save unless given file exists. :param num_topics: The number of topics to train. :param seed: The seed used for random number generation. :param max_iter: The maximum number of iterations to use for training. More iterations mean better results, but longer training times. + :param filename: Name of gensim model to save, or to load if file exists. """ import gensim.models.nmf + from os.path import isfile dic = corpus.dictionary freqs = corpus.frequencies - tfidf = gensim.models.tfidfmodel.TfidfModel(dictionary=dic) - model = gensim.models.nmf.Nmf( - list(tfidf[freqs]), - num_topics=num_topics, - passes=max_iter, - random_state=seed, - w_stop_condition=1e-9, - h_stop_condition=1e-9, - w_max_iter=50, - h_max_iter=50, - ) + if filename == None or isfile(filename) == False: + tfidf = gensim.models.tfidfmodel.TfidfModel(dictionary=dic) + model = gensim.models.nmf.Nmf( + list(tfidf[freqs]), + num_topics=num_topics, + passes=max_iter, + random_state=seed, + w_stop_condition=1e-9, + h_stop_condition=1e-9, + w_max_iter=50, + h_max_iter=50, + ) + if filename != None: + model.save(filename) + else: + model = gensim.models.nmf.Nmf.load(filename) doc2topic = corpus2dense(model[freqs], num_topics).T topic2token = model.get_topics() @@ -306,11 +315,12 @@ def train_nmf_model(corpus: Corpus, num_topics: int, seed=0, max_iter=500) -> To return TopicModel(dic, doc2topic, topic2token) -def train_lda_model(corpus: Corpus, num_topics, seed=0, **kwargs) -> TopicModel: +def train_lda_model(corpus: Corpus, num_topics, seed=0, filename=None, **kwargs) -> TopicModel: """Train a topic model using LDA. :param num_topics: The number of topics to train. :param seed: The seed used for random number generation. + :param filename: Name of gensim model to save, or to load if file exists. :param kwargs: Arguments passed to `gensim.models.lda.LdaModel` (gensim3) or `gensim.models.ldamodel.LdaModel` (gensim4). """ @@ -319,17 +329,26 @@ def train_lda_model(corpus: Corpus, num_topics, seed=0, **kwargs) -> TopicModel: freqs = corpus.frequencies from importlib.metadata import version + from os.path import isfile gensim_mayor = int(version("gensim").split(".")[0]) - if gensim_mayor == 3: from gensim.models.lda import LdaModel - model = LdaModel(list(corpus), **kwargs) + if filename == None or isfile(filename) == False: + model = LdaModel(list(corpus), **kwargs) + if filename != None: + model.save(filename) elif gensim_mayor == 4: from gensim.models.ldamodel import LdaModel - model = LdaModel(freqs, id2word=dic, num_topics=num_topics, **kwargs) + if filename == None or isfile(filename) == False: + model = LdaModel(freqs, id2word=dic, num_topics=num_topics, **kwargs) + if filename != None: + model.save(filename) + else: + model = LdaModel.load(filename) + else: from sys import exit @@ -341,19 +360,22 @@ def train_lda_model(corpus: Corpus, num_topics, seed=0, **kwargs) -> TopicModel: return TopicModel(dic, doc2topic, topic2token) -def train_elda_model(corpus: Corpus, num_topics, num_models=4, seed=0, **kwargs) -> TopicModel: +def train_elda_model( + corpus: Corpus, num_topics, num_models=4, seed=0, filename=None, **kwargs +) -> TopicModel: """Train a topic model using ensemble LDA. :param num_topics: The number of topics to train. :param num_models: The number of models to train. :param seed: The seed used for random number generation. + :param filename: Name of gensim model to save, or to load if file exists. :param kwargs: Arguments passed to `gensim.models.ensemblelda.EnsembleLda` (gensim4). """ from importlib.metadata import version + from os.path import isfile gensim_mayor = int(version("gensim").split(".")[0]) - if gensim_mayor <= 3: from sys import exit @@ -364,14 +386,21 @@ def train_elda_model(corpus: Corpus, num_topics, num_models=4, seed=0, **kwargs) from gensim.models.ensemblelda import EnsembleLda - model = EnsembleLda( - topic_model_class="ldamulticore", - corpus=freqs, - id2word=dic, - num_topics=num_topics, - num_models=num_models, - **kwargs - ) + if filename == None or isfile(filename) == False: + model = EnsembleLda( + topic_model_class="ldamulticore", + corpus=freqs, + id2word=dic, + num_topics=num_topics, + num_models=num_models, + **kwargs + ) + if filename != None: + model.save(filename) + else: + model = EnsembleLda.load(filename) + + model = model.generate_gensim_representation() doc2topic = corpus2dense(model[freqs], num_topics).T topic2token = model.get_topics()