Skip to content

Commit

Permalink
feat/emoji-fix (#237)
Browse files Browse the repository at this point in the history
- move emoji removal from vectorizer to topic document preprocessing
https://github.com/orgs/MIT-AI-Accelerator/projects/3/views/3?pane=issue&itemId=72090681
- confirmed end-to-end tests and pytest suite passing
  • Loading branch information
emiliecowen authored Sep 13, 2024
1 parent d200714 commit 752260a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
11 changes: 10 additions & 1 deletion app/aimodels/bertopic/ai_services/basic_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class TopicDocumentData(BaseModel):
document_nicknames: list[str]
document_channels: list[str]
document_links: list[str]
document_metadata: list[dict]
embeddings: np.ndarray

class Config:
Expand Down Expand Up @@ -242,6 +243,7 @@ def get_document_info(self, topic_model, topic_document_data: TopicDocumentData,
document_info['Nickname'] = topic_document_data.document_nicknames
document_info['Channel'] = topic_document_data.document_channels
document_info['Link'] = topic_document_data.document_links
document_info['Metadata'] = topic_document_data.document_metadata
return document_info

def train_bertopic_on_documents(self, db, documents, precalculated_embeddings, num_topics, document_df, seed_topic_list=None, num_related_docs=DEFAULT_N_REPR_DOCS, trends_only=False, trend_depth=DEFAULT_TREND_DEPTH_DAYS, train_percent=DEFAULT_TRAIN_PERCENT) -> BasicInferenceOutputs:
Expand All @@ -264,6 +266,7 @@ def train_bertopic_on_documents(self, db, documents, precalculated_embeddings, n
document_nicknames = list(document_df['nickname'].values),
document_channels = list(document_df['channel_name'].values),
document_links = list(document_df['mm_link'].values),
document_metadata = list(document_df['mm_metadata'].values),
embeddings = embeddings)

topic_model, topic_document_data_train, topic_document_data_test = self.build_topic_model(
Expand Down Expand Up @@ -377,7 +380,8 @@ def build_topic_model(self, topic_document_data: TopicDocumentData, num_topics,
'user': topic_document_data.document_users,
'nickname': topic_document_data.document_nicknames,
'channel': topic_document_data.document_channels,
'link': topic_document_data.document_links})
'link': topic_document_data.document_links,
'metadata': topic_document_data.document_metadata})

if self.weak_learner_obj:
l_test = self.label_applier.apply(
Expand All @@ -390,6 +394,9 @@ def build_topic_model(self, topic_document_data: TopicDocumentData, num_topics,
# convert documents to lowercase prior to stopword removal
data_train['document'] = data_train['message'].str.lower()

# remove emojis from topic document text
data_train['document'] = data_train['document'].replace(to_replace =':[a-zA-Z0-9_]*:', value = '', regex = True)

# split data, train, then infer. assumes documents, embeddings
# sorted by timestamp previously (in train request)
train_len = round(len(data_train) * train_percent)
Expand All @@ -406,6 +413,7 @@ def build_topic_model(self, topic_document_data: TopicDocumentData, num_topics,
document_nicknames = list(data_train['nickname']),
document_channels = list(data_train['channel']),
document_links = list(data_train['link']),
document_metadata = list(data_train['metadata']),
embeddings = topic_document_data.embeddings[:train_len-1])
topic_document_data_test = TopicDocumentData(document_text_list = list(data_test['document']),
document_messages = list(data_test['message']),
Expand All @@ -414,6 +422,7 @@ def build_topic_model(self, topic_document_data: TopicDocumentData, num_topics,
document_nicknames = list(data_test['nickname']),
document_channels= list(data_test['channel']),
document_links = list(data_test['link']),
document_metadata = list(data_test['metadata']),
embeddings = topic_document_data.embeddings[train_len:])

umap_model = UMAP(n_neighbors = DEFAULT_UMAP_NEIGHBORS,
Expand Down
9 changes: 1 addition & 8 deletions app/aimodels/bertopic/routers/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,14 @@ def train_bertopic_post(request: TrainModelRequest, db: Session = Depends(get_db
documents, precalculated_embeddings = get_documents_and_embeddings(db, request.document_ids, request.sentence_transformer_id)
document_df = crud_mattermost.mattermost_documents.get_document_dataframe(db, document_uuids=request.document_ids)

# add emojis to stopword list
emojis = set()
df_emoji = document_df[document_df['mm_metadata'].map(lambda x: 'emojis' in x)]
for key, row in df_emoji.iterrows():
for e in row['mm_metadata']['emojis']:
emojis.add(e['name'])

# train the model
basic_inference = BasicInference(bertopic_sentence_transformer_obj,
s3,
request.prompt_template,
request.refine_template,
bertopic_weak_learner_obj,
llm_pretrained_obj,
stop_word_list=request.stop_words + list(emojis))
stop_word_list=request.stop_words)
inference_output = basic_inference.train_bertopic_on_documents(db,
documents, precalculated_embeddings=precalculated_embeddings, num_topics=request.num_topics,
document_df=document_df,
Expand Down

0 comments on commit 752260a

Please sign in to comment.