Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support gensim4 LdaModel #73

Merged
merged 15 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions litstudy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
build_corpus,
train_nmf_model,
train_lda_model,
train_elda_model,
compute_word_distribution,
calculate_embedding,
) # noqa: F401
Expand Down
61 changes: 57 additions & 4 deletions litstudy/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,16 +311,69 @@ def train_lda_model(corpus: Corpus, num_topics, seed=0, **kwargs) -> TopicModel:

:param num_topics: The number of topics to train.
:param seed: The seed used for random number generation.
:param kwargs: Arguments passed to `gensim.models.lda.LdaModel`.
:param kwargs: Arguments passed to `gensim.models.lda.LdaModel` (gensim3)
or `gensim.models.ldamodel.LdaModel` (gensim4).
"""
from gensim.models.lda import LdaModel

dic = corpus.dictionary
freqs = corpus.frequencies

model = LdaModel(list(corpus), **kwargs)
from importlib.metadata import version

doc2topic = corpus2dense(model[freqs], num_topics)
gensim_mayor = int(version("gensim").split(".")[0])

if gensim_mayor == 3:
from gensim.models.lda import LdaModel

model = LdaModel(list(corpus), **kwargs)
elif gensim_mayor == 4:
from gensim.models.ldamodel import LdaModel

model = LdaModel(freqs, id2word=dic, num_topics=num_topics, **kwargs)
else:
from sys import exit

exit("LdaModel could not be imported from gensim 3 or 4.")

doc2topic = corpus2dense(model[freqs], num_topics).T
topic2token = model.get_topics()

return TopicModel(dic, doc2topic, topic2token)


def train_elda_model(corpus: Corpus, num_topics, num_models=4, seed=0, **kwargs) -> TopicModel:
"""Train a topic model using ensemble LDA.

:param num_topics: The number of topics to train.
:param num_models: The number of models to train.
:param seed: The seed used for random number generation.
:param kwargs: Arguments passed to `gensim.models.ensemblelda.EnsembleLda` (gensim4).
"""

from importlib.metadata import version

gensim_mayor = int(version("gensim").split(".")[0])

if gensim_mayor <= 3:
from sys import exit

exit("EnsembleLda requires at least gensim 4.")

dic = corpus.dictionary
freqs = corpus.frequencies

from gensim.models.ensemblelda import EnsembleLda

model = EnsembleLda(
topic_model_class="ldamulticore",
corpus=freqs,
id2word=dic,
num_topics=num_topics,
num_models=num_models,
**kwargs
)

doc2topic = corpus2dense(model[freqs], num_topics).T
topic2token = model.get_topics()

return TopicModel(dic, doc2topic, topic2token)
Expand Down
1 change: 1 addition & 0 deletions litstudy/sources/scopus_csv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
support loading Scopus CSV export.
"""

from typing import List, Optional
from ..types import Document, Author, DocumentSet, DocumentIdentifier, Affiliation
from ..common import robust_open
Expand Down
Loading