Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
MaartenGr authored Aug 21, 2024
1 parent 74bbc4f commit 0b4265a
Showing 1 changed file with 75 additions and 63 deletions.
138 changes: 75 additions & 63 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,12 @@ def fit_transform(
documents, embeddings, assigned_documents, assigned_embeddings = self._zeroshot_topic_modeling(
documents, embeddings
)

# Filter UMAP embeddings to only non-assigned embeddings to be used for clustering
umap_embeddings = self.umap_model.transform(embeddings)
if len(documents) > 0:
umap_embeddings = self.umap_model.transform(embeddings)

if len(documents) > 0: # No zero-shot topics matched
if len(documents) > 0:
# Cluster reduced embeddings
documents, probabilities = self._cluster_embeddings(umap_embeddings, documents, y=y)
if self._is_zeroshot() and len(assigned_documents) > 0:
Expand All @@ -467,7 +469,6 @@ def fit_transform(
# All documents matches zero-shot topics
documents = assigned_documents
embeddings = assigned_embeddings
topics_before_reduction = self.topics_

# Sort and Map Topic IDs by their frequency
if not self.nr_topics:
Expand Down Expand Up @@ -505,17 +506,11 @@ def fit_transform(
sim_matrix = cosine_similarity(embeddings, np.array(self.topic_embeddings_))

if self.calculate_probabilities:
probabilities = sim_matrix
self.probabilities_ = sim_matrix
else:
# Use `topics_before_reduction` because `self.topics_` may have already been updated from
# reducing topics, and the original probabilities are needed for `self._map_probabilities()`
probabilities = sim_matrix[
np.arange(len(documents)),
np.array(topics_before_reduction) + self._outliers,
]

# Resulting output
self.probabilities_ = self._map_probabilities(probabilities, original_topics=True)
self.probabilities_ = np.max(sim_matrix, axis=1)
else:
self.probabilities_ = self._map_probabilities(probabilities, original_topics=True)
predictions = documents.Topic.to_list()

return predictions, self.probabilities_
Expand Down Expand Up @@ -2148,7 +2143,7 @@ def merge_topics(

# Update topics
documents.Topic = documents.Topic.map(mapping)
self.topic_mapper_.add_mappings(mapping)
self.topic_mapper_.add_mappings(mapping, topic_model=self)
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, mappings=mappings)
self._update_topic_size(documents)
Expand Down Expand Up @@ -3841,19 +3836,24 @@ def _zeroshot_topic_modeling(

# Check that if a number of topics was specified, it exceeds the number of zeroshot topics matched
num_zeroshot_topics = len(assigned_documents["Topic"].unique())
if self.nr_topics and not self.nr_topics > num_zeroshot_topics:
raise ValueError(
f"The set nr_topics ({self.nr_topics}) must exceed the number of matched zero-shot topics "
f"({num_zeroshot_topics}). Consider raising nr_topics or raising the "
f"zeroshot_min_similarity ({self.zeroshot_min_similarity})."
)
if self.nr_topics != "auto":
if self.nr_topics and not self.nr_topics > num_zeroshot_topics:
raise ValueError(
f"The set nr_topics ({self.nr_topics}) must exceed the number of matched zero-shot topics "
f"({num_zeroshot_topics}). Consider raising nr_topics or raising the "
f"zeroshot_min_similarity ({self.zeroshot_min_similarity})."
)

# Select non-assigned topics to be clustered
documents = documents.iloc[non_assigned_ids]
documents["Old_ID"] = documents["ID"].copy()
documents["ID"] = range(len(documents))
embeddings = embeddings[non_assigned_ids]

if len(documents) == 0:
self.topics_ = assigned_documents["Topic"].values.tolist()
self.topic_mapper_ = TopicMapper(self.topics_)

logger.info("Zeroshot Step 1 - Completed \u2713")
return documents, embeddings, assigned_documents, assigned_embeddings

Expand Down Expand Up @@ -3914,6 +3914,7 @@ def _combine_zeroshot_topics(
# Combine the clustered documents/embeddings with assigned documents/embeddings in the original order
documents = pd.concat([documents, assigned_documents])
embeddings = np.vstack([embeddings, assigned_embeddings])
documents.ID = documents.Old_ID
sorted_indices = documents.Old_ID.argsort()
documents = documents.iloc[sorted_indices]
embeddings = embeddings[sorted_indices]
Expand Down Expand Up @@ -4407,50 +4408,12 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)
# Map topics
documents.Topic = new_topics
self._update_topic_size(documents)
self.topic_mapper_.add_mappings(mapped_topics)
self.topic_mapper_.add_mappings(mapped_topics, topic_model=self)

# Update representations
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, mappings=mappings)

# When zero-shot topic(s) are present in the topics to merge,
# determine whether to take one of the zero-shot topic labels
# or use a calculated representation.
if self._is_zeroshot():
new_topic_id_to_zeroshot_topic_idx = {}
topics_to_map = {
topic_mapping[0]: topic_mapping[1] for topic_mapping in np.array(self.topic_mapper_.mappings_)[:, -2:]
}

for topic_to, topics_from in basic_mappings.items():
# When extracting topics, the reduced topics were reordered.
# Must get the updated topic_to.
topic_to = topics_to_map[topic_to]

# which of the original topics are zero-shot
zeroshot_topic_ids = [
topic_id for topic_id in topics_from if topic_id in self._topic_id_to_zeroshot_topic_idx
]
if len(zeroshot_topic_ids) == 0:
continue

# If any of the original topics are zero-shot, take the best fitting zero-shot label
# if the cosine similarity with the new topic exceeds the zero-shot threshold
zeroshot_labels = [
self.zeroshot_topic_list[self._topic_id_to_zeroshot_topic_idx[topic_id]]
for topic_id in zeroshot_topic_ids
]
zeroshot_embeddings = self._extract_embeddings(zeroshot_labels)
cosine_similarities = cosine_similarity(
zeroshot_embeddings, [self.topic_embeddings_[topic_to]]
).flatten()
best_zeroshot_topic_idx = np.argmax(cosine_similarities)
best_cosine_similarity = cosine_similarities[best_zeroshot_topic_idx]
if best_cosine_similarity >= self.zeroshot_min_similarity:
new_topic_id_to_zeroshot_topic_idx[topic_to] = zeroshot_topic_ids[best_zeroshot_topic_idx]

self._topic_id_to_zeroshot_topic_idx = new_topic_id_to_zeroshot_topic_idx

self._update_topic_size(documents)
return documents

Expand Down Expand Up @@ -4503,7 +4466,7 @@ def _auto_reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)
}

# Update documents and topics
self.topic_mapper_.add_mappings(mapped_topics)
self.topic_mapper_.add_mappings(mapped_topics, topic_model=self)
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, mappings=mappings)
self._update_topic_size(documents)
Expand Down Expand Up @@ -4533,13 +4496,17 @@ def _sort_mappings_by_frequency(self, documents: pd.DataFrame) -> pd.DataFrame:
documents: Updated dataframe with documents and the mapped
and re-ordered topic ids
"""
self._update_topic_size(documents)
# No need to sort if it's the first pass of zero-shot topic modeling
nr_zeroshot = len(self._topic_id_to_zeroshot_topic_idx)
if self._is_zeroshot and not self.nr_topics and nr_zeroshot > 0:
return documents

# Map topics based on frequency
self._update_topic_size(documents)
df = pd.DataFrame(self.topic_sizes_.items(), columns=["Old_Topic", "Size"]).sort_values("Size", ascending=False)
df = df[df.Old_Topic != -1]
sorted_topics = {**{-1: -1}, **dict(zip(df.Old_Topic, range(len(df))))}
self.topic_mapper_.add_mappings(sorted_topics)
self.topic_mapper_.add_mappings(sorted_topics, topic_model=self)

# Map documents
documents.Topic = documents.Topic.map(sorted_topics).fillna(documents.Topic).astype(int)
Expand Down Expand Up @@ -4729,11 +4696,12 @@ def get_mappings(self, original_topics: bool = True) -> Mapping[int, int]:
mappings = dict(zip(mappings[:, 0], mappings[:, 1]))
return mappings

def add_mappings(self, mappings: Mapping[int, int]):
def add_mappings(self, mappings: Mapping[int, int], topic_model: BERTopic):
"""Add new column(s) of topic mappings.
Arguments:
mappings: The mappings to add
topic_model: The topic model this TopicMapper belongs to
"""
for topics in self.mappings_:
topic = topics[-1]
Expand All @@ -4742,6 +4710,50 @@ def add_mappings(self, mappings: Mapping[int, int]):
else:
topics.append(-1)

# When zero-shot topic(s) are present in the topics to merge,
# determine whether to take one of the zero-shot topic labels
# or use a calculated representation.
if topic_model._is_zeroshot() and len(topic_model._topic_id_to_zeroshot_topic_idx) > 0:
new_topic_id_to_zeroshot_topic_idx = {}
topics_to_map = {
topic_mapping[0]: topic_mapping[1]
for topic_mapping in np.array(topic_model.topic_mapper_.mappings_)[:, -2:]
}

# Map topic_to to topics_from
mapping = defaultdict(list)
for key, value in topics_to_map.items():
mapping[value].append(key)

for topic_to, topics_from in mapping.items():
# which of the original topics are zero-shot
zeroshot_topic_ids = [
topic_id for topic_id in topics_from if topic_id in topic_model._topic_id_to_zeroshot_topic_idx
]
if len(zeroshot_topic_ids) == 0:
continue

# If any of the original topics are zero-shot, take the best fitting zero-shot label
# if the cosine similarity with the new topic exceeds the zero-shot threshold
zeroshot_labels = [
topic_model.zeroshot_topic_list[topic_model._topic_id_to_zeroshot_topic_idx[topic_id]]
for topic_id in zeroshot_topic_ids
]
zeroshot_embeddings = topic_model._extract_embeddings(zeroshot_labels)
cosine_similarities = cosine_similarity(
zeroshot_embeddings, [topic_model.topic_embeddings_[topic_to]]
).flatten()
best_zeroshot_topic_idx = np.argmax(cosine_similarities)
best_cosine_similarity = cosine_similarities[best_zeroshot_topic_idx]

if best_cosine_similarity >= topic_model.zeroshot_min_similarity:
# Using the topic ID from before mapping, get the idx into the zeroshot topic list
new_topic_id_to_zeroshot_topic_idx[topic_to] = topic_model._topic_id_to_zeroshot_topic_idx[
zeroshot_topic_ids[best_zeroshot_topic_idx]
]

topic_model._topic_id_to_zeroshot_topic_idx = new_topic_id_to_zeroshot_topic_idx

def add_new_topics(self, mappings: Mapping[int, int]):
"""Add new row(s) of topic mappings.
Expand Down

0 comments on commit 0b4265a

Please sign in to comment.