feat/thread-chats-only (#241)
- use conversation threads for user chats only
#240
- add acars_text to info_type to capture air crew "chats"
https://github.com/orgs/MIT-AI-Accelerator/projects/3/views/3?pane=issue&itemId=80096759
- add user/persona context to summarization text
https://github.com/orgs/MIT-AI-Accelerator/projects/3/views/3?pane=issue&itemId=80614000
- requires DB init and ppg-common build
- confirmed end-to-end and pytests passing
emiliecowen authored Sep 23, 2024
1 parent 752260a commit 53c2c87
Showing 11 changed files with 182 additions and 63 deletions.
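Note: `ThreadTypeEnum` and the new `InfoTypeEnum.ACARS_TEXT` member live in ppg-common (hence the required DB init and ppg-common build) and are not part of this diff. A minimal sketch of what they presumably look like; only `MESSAGE` and the info types referenced in the diff below are confirmed, the remaining thread-type member names are guesses inferred from the new `thread`, `thread_speaker`, and `thread_speaker_persona` columns:

```python
# Sketch of the ppg-common schema pieces this commit assumes (not shown in the diff).
from enum import Enum

class ThreadTypeEnum(str, Enum):
    MESSAGE = "message"                          # single post, no rollup (replaces is_thread=False)
    THREAD = "thread"                            # conversation thread rolled into one document (guess)
    THREAD_USER = "thread_user"                  # thread rendered as "user: message" lines (guess)
    THREAD_USER_PERSONA = "thread_user_persona"  # thread rendered as "user (nickname): message" lines (guess)

class InfoTypeEnum(str, Enum):
    CHAT = "chat"
    DATAMINR = "dataminr"
    ARINC = "arinc"
    ACARS = "acars"
    ACARS_TEXT = "acars_text"                    # new in this change: air crew free-text "chats"
    NOTAM = "notam"
```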
14 changes: 10 additions & 4 deletions app/aimodels/bertopic/ai_services/basic_inference.py
@@ -179,6 +179,7 @@ class TopicDocumentData(BaseModel):
document_channels: list[str]
document_links: list[str]
document_metadata: list[dict]
document_summarization_messages: list[str]
embeddings: np.ndarray

class Config:
@@ -244,6 +245,7 @@ def get_document_info(self, topic_model, topic_document_data: TopicDocumentData,
document_info['Channel'] = topic_document_data.document_channels
document_info['Link'] = topic_document_data.document_links
document_info['Metadata'] = topic_document_data.document_metadata
document_info['Summarization_Message'] = topic_document_data.document_summarization_messages
return document_info

def train_bertopic_on_documents(self, db, documents, precalculated_embeddings, num_topics, document_df, seed_topic_list=None, num_related_docs=DEFAULT_N_REPR_DOCS, trends_only=False, trend_depth=DEFAULT_TREND_DEPTH_DAYS, train_percent=DEFAULT_TRAIN_PERCENT) -> BasicInferenceOutputs:
@@ -267,6 +269,7 @@ def train_bertopic_on_documents(self, db, documents, precalculated_embeddings, n
document_channels = list(document_df['channel_name'].values),
document_links = list(document_df['mm_link'].values),
document_metadata = list(document_df['mm_metadata'].values),
document_summarization_messages = list(document_df['summarization_message'].values),
embeddings = embeddings)

topic_model, topic_document_data_train, topic_document_data_test = self.build_topic_model(
@@ -381,7 +384,8 @@ def build_topic_model(self, topic_document_data: TopicDocumentData, num_topics,
'nickname': topic_document_data.document_nicknames,
'channel': topic_document_data.document_channels,
'link': topic_document_data.document_links,
'metadata': topic_document_data.document_metadata})
'metadata': topic_document_data.document_metadata,
'summarization_message': topic_document_data.document_summarization_messages})

if self.weak_learner_obj:
l_test = self.label_applier.apply(
@@ -414,6 +418,7 @@ def build_topic_model(self, topic_document_data: TopicDocumentData, num_topics,
document_channels = list(data_train['channel']),
document_links = list(data_train['link']),
document_metadata = list(data_train['metadata']),
document_summarization_messages = list(data_train['summarization_message']),
embeddings = topic_document_data.embeddings[:train_len-1])
topic_document_data_test = TopicDocumentData(document_text_list = list(data_test['document']),
document_messages = list(data_test['message']),
@@ -422,7 +427,8 @@ def build_topic_model(self, topic_document_data: TopicDocumentData, num_topics,
document_nicknames = list(data_test['nickname']),
document_channels= list(data_test['channel']),
document_links = list(data_test['link']),
document_metadata = list(data_test['metadata']),
document_metadata = list(data_test['metadata']),
document_summarization_messages = list(data_test['summarization_message']),
embeddings = topic_document_data.embeddings[train_len:])

umap_model = UMAP(n_neighbors = DEFAULT_UMAP_NEIGHBORS,
@@ -473,7 +479,7 @@ def create_topic_visualizations(self, document_info_train, topic_model, document

summary_text = 'topic summarization disabled'
if self.topic_summarizer:
summary_text = self.topic_summarizer.get_summary(topic_docs['Message'].to_list())
summary_text = self.topic_summarizer.get_summary(topic_docs['Summarization_Message'].to_list())

# topic-level timeline visualization
topic_timeline_visualization_list = topic_timeline_visualization_list + [topic_model.visualize_topics_over_time(
@@ -484,7 +490,7 @@
topic_id=row['Topic'],
name=row['Name'],
top_n_words=topic_docs['Top_n_words'].unique()[0],
top_n_documents=topic_docs.rename(columns={'Document': 'Lowercase', 'Message': 'Document'})[[
top_n_documents=topic_docs.rename(columns={'Document': 'Lowercase', 'Summarization_Message': 'Document'})[[
'Document', 'Timestamp', 'User', 'Nickname', 'Channel', 'Link', 'Probability']].to_dict(),
summary=summary_text,
is_trending=row['is_trending'])]
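The effect of the Summarization_Message plumbing above, as a hedged sketch (column names come from this diff, the data is invented, and the summarizer call is simplified): topics are still derived from the plain message text, while per-topic summaries now read the speaker/persona-annotated column.

```python
import pandas as pd

# Minimal illustration of the new per-topic summarization input (sketch only).
document_info = pd.DataFrame({
    'Topic': [0, 0],
    'Message': ['engine check complete', 'copy, thanks'],
    'Summarization_Message': ['jsmith (Maintenance Lead): engine check complete',
                              'adoe (Pilot): copy, thanks'],
})

for topic_id, topic_docs in document_info.groupby('Topic'):
    texts = topic_docs['Summarization_Message'].to_list()
    # summary_text = self.topic_summarizer.get_summary(texts)  # as in create_topic_visualizations
    print(topic_id, texts)
```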
20 changes: 20 additions & 0 deletions app/aimodels/bertopic/routers/train.py
@@ -8,6 +8,7 @@
from ppg.schemas.bertopic.bertopic_trained import BertopicTrained, BertopicTrainedCreate, BertopicTrainedUpdate
from ppg.schemas.bertopic.bertopic_visualization import BertopicVisualizationCreate
from ppg.schemas.bertopic.topic import TopicSummaryUpdate
from app.core.logging import logger
from app.core.minio import pickle_and_upload_object_to_minio
from ..ai_services.basic_inference import BasicInference, MIN_BERTOPIC_DOCUMENTS, DEFAULT_TRAIN_PERCENT
from app.dependencies import get_db, get_minio
@@ -32,6 +33,7 @@ class TrainModelRequest(BaseModel):
weak_learner_id: UUID4 | None
summarization_model_id: UUID4 | None
document_ids: list[UUID4] = []
summarization_document_ids: list[UUID4] = []
num_topics: int = 2
seed_topics: list[list] = []
stop_words: list[str] = []
@@ -65,6 +67,9 @@ def train_bertopic_post(request: TrainModelRequest, db: Session = Depends(get_db
if len(request.document_ids) < MIN_BERTOPIC_DOCUMENTS:
raise HTTPException(
status_code=400, detail="must have at least 7 documents to find topics")
elif len(request.summarization_document_ids) == 0 or len(request.document_ids) != len(request.summarization_document_ids):
logger.warning("reusing document_ids for summarization, length mismatch")
request.summarization_document_ids = request.document_ids

# validate train percent
if request.train_percent < 0.0 or request.train_percent > 1.0:
@@ -89,6 +94,21 @@ def train_bertopic_post(request: TrainModelRequest, db: Session = Depends(get_db

documents, precalculated_embeddings = get_documents_and_embeddings(db, request.document_ids, request.sentence_transformer_id)
document_df = crud_mattermost.mattermost_documents.get_document_dataframe(db, document_uuids=request.document_ids)
summarization_document_df = crud_mattermost.mattermost_documents.get_document_dataframe(db, document_uuids=request.summarization_document_ids)

if all(document_df['message_id'].str.len() > 0):
# merge original messages with summarization text
document_df = pd.merge(document_df,
summarization_document_df[['message_id', 'message']],
on='message_id',
how='left',
validate='1:1').rename(columns={"message_x": "message",
"message_y": "summarization_message"})
fix_mask = document_df['summarization_message'].isnull()
document_df.loc[fix_mask, 'summarization_message'] = document_df.loc[fix_mask, 'message']
else:
logger.warning("reusing document_ids for summarization, missing merge criteria")
document_df['summarization_message'] = document_df['message']

# train the model
basic_inference = BasicInference(bertopic_sentence_transformer_obj,
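A hedged sketch of how a client might exercise the new field (the endpoint path and IDs are placeholders; only the request shape comes from TrainModelRequest above): document_ids drive embedding and topic modeling, summarization_document_ids supply the matching speaker/persona-annotated text, and the router falls back to document_ids when the second list is empty or mismatched.

```python
import uuid
import requests

# Illustrative request only -- the URL and IDs are placeholders, not values from this repo.
doc_ids = [str(uuid.uuid4()) for _ in range(7)]   # at least MIN_BERTOPIC_DOCUMENTS (7) ids
sum_ids = [str(uuid.uuid4()) for _ in range(7)]   # matching speaker/persona-annotated documents

payload = {
    "sentence_transformer_id": str(uuid.uuid4()),
    "document_ids": doc_ids,                # plain thread text used for embeddings/topics
    "summarization_document_ids": sum_ids,  # same threads in "user (nickname): message" form
    "num_topics": 2,
    "train_percent": 0.7,
}
resp = requests.post("http://localhost:8080/aimodels/bertopic/train", json=payload, timeout=120)
resp.raise_for_status()
```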
3 changes: 2 additions & 1 deletion app/initial_data.py
@@ -17,6 +17,7 @@
from ppg.schemas.bertopic.bertopic_embedding_pretrained import BertopicEmbeddingPretrainedCreate, BertopicEmbeddingPretrainedUpdate
from ppg.schemas.gpt4all.llm_pretrained import LlmPretrainedCreate, LlmPretrainedUpdate
from ppg.schemas.bertopic.document import DocumentCreate
from ppg.schemas.mattermost.mattermost_documents import ThreadTypeEnum

from app.aimodels.bertopic.models.bertopic_embedding_pretrained import BertopicEmbeddingPretrainedModel, EmbeddingModelTypeEnum
from app.aimodels.bertopic.models.document import DocumentModel
@@ -438,7 +439,7 @@ def init_mattermost_documents(db:Session, bot_obj: MattermostUserModel) -> None:
adf = adf[~adf.id.isin(existing_ids)].drop_duplicates(subset='id')

adf.rename(columns={'id': 'message_id'}, inplace=True)
return crud_mattermost.mattermost_documents.create_all_using_df(db, ddf=adf, is_thread=False)
return crud_mattermost.mattermost_documents.create_all_using_df(db, ddf=adf, thread_type=ThreadTypeEnum.MESSAGE)

########## large object uploads ################

69 changes: 47 additions & 22 deletions app/mattermost/crud/crud_mattermost.py
@@ -5,7 +5,7 @@
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ppg.schemas.mattermost.mattermost_channels import MattermostChannelCreate
from ppg.schemas.mattermost.mattermost_documents import MattermostDocumentCreate, InfoTypeEnum
from ppg.schemas.mattermost.mattermost_documents import MattermostDocumentCreate, InfoTypeEnum, ThreadTypeEnum
from ppg.schemas.mattermost.mattermost_users import MattermostUserCreate, MattermostUserUpdate
from ppg.schemas.bertopic.document import DocumentCreate
import ppg.services.mattermost_utils as mattermost_utils
@@ -53,14 +53,14 @@ def get_by_user_name(self, db: Session, *, user_name: str) -> Union[MattermostUs


class CRUDMattermostDocument(CRUDBase[MattermostDocumentModel, MattermostDocumentCreate, MattermostDocumentCreate]):
def get_by_message_id(self, db: Session, *, message_id: str, is_thread = False) -> Union[MattermostDocumentModel, None]:
def get_by_message_id(self, db: Session, *, message_id: str, thread_type = ThreadTypeEnum.MESSAGE) -> Union[MattermostDocumentModel, None]:
if not message_id:
return None

# each mattermost document is allowed a single conversation thread
return db.query(self.model).filter(self.model.message_id == message_id, self.model.is_thread == is_thread).first()
return db.query(self.model).filter(self.model.message_id == message_id, self.model.thread_type == thread_type).first()

def get_all_by_message_id(self, db: Session, *, message_id: str, is_thread = False) -> Union[MattermostDocumentModel, None]:
def get_all_by_message_id(self, db: Session, *, message_id: str) -> Union[MattermostDocumentModel, None]:
if not message_id:
return None

@@ -84,7 +84,7 @@ def get_all_channel_documents(self, db: Session, channels: list[str], history_de
documents += sum([db.query(self.model).join(DocumentModel).filter(self.model.channel == cuuid,
DocumentModel.original_created_time >= stime,
DocumentModel.original_created_time <= ctime,
self.model.is_thread == False,
self.model.thread_type == ThreadTypeEnum.MESSAGE,
self.model.info_type == itype).all() for cuuid in channels], [])

return documents
@@ -104,6 +104,8 @@ def get_mm_document_dataframe(self, db: Session, *, mm_document_uuids: list[str]
'document': document[0][1].id,
'user_id': document[0][2].user_id,
'user_uuid': document[0][2].id,
'user_name': document[0][2].user_name,
'nickname': document[0][2].nickname,
'channel_id': document[0][3].channel_id,
'channel_uuid': document[0][3].id,
'create_at': document[0][1].original_created_time,
@@ -112,7 +114,8 @@ def get_mm_document_dataframe(self, db: Session, *, mm_document_uuids: list[str]
'props': document[0][0].props,
'metadata': document[0][0].doc_metadata,
'document_id': document[0][1].id,
'is_thread': document[0][0].is_thread}])],
'info_type': document[0][0].info_type,
'thread_type': document[0][0].thread_type}])],
ignore_index=True)

return ddf
@@ -188,7 +191,7 @@ def get_document_dataframe(self, db: Session, *, document_uuids: list[str]) -> U

return ddf

def create_all_using_df(self, db: Session, *, ddf: pd.DataFrame, is_thread = False) -> Union[MattermostDocumentModel, None]:
def create_all_using_df(self, db: Session, *, ddf: pd.DataFrame, thread_type = ThreadTypeEnum.MESSAGE) -> Union[MattermostDocumentModel, None]:

mattermost_documents = []
for key, row in ddf.iterrows():
@@ -220,7 +223,7 @@ def create_all_using_df(self, db: Session, *, ddf: pd.DataFrame, is_thread = Fal
channel=row['channel'],
user=row['user'],
document=document_obj.id,
is_thread=is_thread,
thread_type=thread_type,
info_type=info_type)]

return self.create_all_using_id(db, obj_in_list=mattermost_documents)
@@ -399,7 +402,7 @@ def populate_mm_document_info(db: Session, *, document_df: pd.DataFrame):
udf.rename(columns={"id": "message_id"}, inplace=True)

# create new document objects in db
new_mattermost_docs = new_mattermost_docs + mattermost_documents.create_all_using_df(db, ddf=udf, is_thread=False)
new_mattermost_docs = new_mattermost_docs + mattermost_documents.create_all_using_df(db, ddf=udf, thread_type=ThreadTypeEnum.MESSAGE)

return new_mattermost_docs

@@ -410,40 +413,59 @@ def convert_conversation_threads(df: pd.DataFrame):

df['root_id'] = df['root_id'].fillna('')
df['message'] = df['message'].fillna('')
df['thread'] = df['message'].fillna('')
df['thread_speaker'] = df['message'].fillna('')
df['thread_speaker_persona'] = df['message'].fillna('')
threads = {}
threads_speaker = {}
threads_speaker_persona = {}
threads_row = {}

for index, row in df.iterrows():
thread = row['root_id']
utterance = row['message']
utterance = utterance.replace("\n", " ")
speaker = row['user_name']
persona = row['nickname']
utterance_speaker = speaker + ': ' + utterance
utterance_speaker_persona = speaker + ' (' + persona + '): ' + utterance
p_id = row['message_id']

if utterance.find("added to the channel") < 0 and utterance.find("joined the channel") < 0 and utterance.find("left the channel") < 0:
if len(thread) > 0:
if thread not in threads:
threads[thread] = [utterance.replace("\n", " ")]
threads[thread] = [utterance]
threads_speaker[thread] = [utterance_speaker]
threads_speaker_persona[thread] = [utterance_speaker_persona]
else:
threads[thread].append(utterance.replace("\n", " "))
threads[thread].append(utterance)
threads_speaker[thread].append(utterance_speaker)
threads_speaker_persona[thread].append(utterance_speaker_persona)
else:
t = []
t.append(utterance.replace("\n", " "))
threads[p_id] = t
threads[p_id] = [utterance]
threads_speaker[p_id] = [utterance_speaker]
threads_speaker_persona[p_id] = [utterance_speaker_persona]
threads_row[p_id] = row
keys = set(sorted(threads.keys())).intersection(threads_row.keys())

data = []
conversations = []
for index, key in enumerate(keys):
row = threads_row[key]
row['message'] = "\n".join(threads[key])
data.append(row)
row['thread'] = "\n".join(threads[key])
row['thread_speaker'] = "\n".join(threads_speaker[key])
row['thread_speaker_persona'] = "\n".join(threads_speaker_persona[key])
conversations.append(row)

return pd.DataFrame(data, columns=df.columns)
return pd.DataFrame(conversations, columns=df.columns)
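A small worked example of the conversion above (invented posts; convert_conversation_threads is assumed importable from this module): replies sharing a root_id collapse into one row whose new thread, thread_speaker, and thread_speaker_persona columns carry the joined conversation.

```python
import pandas as pd
from app.mattermost.crud.crud_mattermost import convert_conversation_threads  # assumed import path

posts = pd.DataFrame([
    {'message_id': 'p1', 'root_id': '',   'user_name': 'jsmith', 'nickname': 'Crew Chief',
     'message': 'bird is down for maintenance'},
    {'message_id': 'p2', 'root_id': 'p1', 'user_name': 'adoe',   'nickname': 'Pilot',
     'message': 'copy, what is the ETIC?'},
])

threads = convert_conversation_threads(posts)
# One row remains (keyed by root message p1), with:
#   thread                 -> 'bird is down for maintenance\ncopy, what is the ETIC?'
#   thread_speaker         -> 'jsmith: bird is down for maintenance\nadoe: copy, what is the ETIC?'
#   thread_speaker_persona -> 'jsmith (Crew Chief): bird is down for maintenance\nadoe (Pilot): copy, what is the ETIC?'
```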


def parse_props(jobj: dict):
jobj = jobj['attachments'][0]

author_name = jobj['author_name']
title = jobj['title']
msg = '[%s] %s' % (jobj['title'], jobj['text'])
fallback = jobj['fallback']
msg = '[%s] %s' % (title, jobj['text'])

if 'Dataminr' in author_name:
info_type = InfoTypeEnum.DATAMINR
@@ -454,7 +476,10 @@ def parse_props(jobj: dict):
info_type = InfoTypeEnum.ARINC
elif 'ACARS' in title:
info_type = InfoTypeEnum.ACARS
msg = parse_props_acars(jobj)
msg = parse_props_acars(jobj, title=title)
elif 'ACARS Free Text' in fallback:
info_type = InfoTypeEnum.ACARS_TEXT
msg = parse_props_acars(jobj, title='ACARS Free Text')
elif 'NOTAM' in title:
info_type = InfoTypeEnum.NOTAM
msg = parse_props_notam(jobj)
@@ -485,8 +510,8 @@ def parse_props_notam(jobj: dict):
return msg


def parse_props_acars(jobj: dict):
msg = '[%s] %s' % (jobj['title'], jobj['text'])
def parse_props_acars(jobj: dict, title: str):
msg = '[%s] %s' % (title, jobj['text'])

if jobj['fields'] is not None:
tail_num = ''
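A hedged illustration of the post props the new ACARS Free Text branch is meant to catch: the attachment values are invented, parse_props is assumed importable from this module, and its return shape is not shown in this hunk.

```python
from app.mattermost.crud.crud_mattermost import parse_props  # assumed import path

# Invented Mattermost bot-post props for an ACARS free-text message.
props = {
    'attachments': [{
        'author_name': 'Flight Ops Bot',
        'title': 'Flight ABC123',                 # no 'ACARS' here, so the earlier elif is skipped
        'fallback': 'ACARS Free Text',
        'text': 'REQ DESCENT FL240 DUE TURBULENCE',
        'fields': None,
    }]
}

result = parse_props(props)
# Expected (assuming none of the collapsed branches above match first):
#   info_type -> InfoTypeEnum.ACARS_TEXT
#   msg       -> '[ACARS Free Text] REQ DESCENT FL240 DUE TURBULENCE'
```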
6 changes: 3 additions & 3 deletions app/mattermost/models/mattermost_documents.py
@@ -3,7 +3,7 @@
from sqlalchemy import Column, UUID, String, ForeignKey, Enum, JSON, Boolean, UniqueConstraint
from sqlalchemy.ext.mutable import MutableDict
from ppg.core.config import OriginationEnum
from ppg.schemas.mattermost.mattermost_documents import InfoTypeEnum
from ppg.schemas.mattermost.mattermost_documents import InfoTypeEnum, ThreadTypeEnum
from app.db.base_class import Base
from app.core.config import get_originated_from

@@ -26,11 +26,11 @@ class MattermostDocumentModel(Base):
has_reactions = Column(Boolean(), default=False)
props = Column(MutableDict.as_mutable(JSON))
doc_metadata = Column(MutableDict.as_mutable(JSON))
is_thread = Column(Boolean(), default=False)
thread_type = Column(Enum(ThreadTypeEnum), default=ThreadTypeEnum.MESSAGE)
info_type = Column(Enum(InfoTypeEnum), default=InfoTypeEnum.CHAT)
originated_from = Column(Enum(OriginationEnum),
default=get_originated_from)

# mattermost message IDs must be unique,
# allow for a single conversation thread for each message
__table_args__ = (UniqueConstraint('message_id', 'is_thread', name='_messageid_isthread_uc'),)
__table_args__ = (UniqueConstraint('message_id', 'thread_type', name='_messageid_threadtype_uc'),)
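One way to read the widened constraint, as a self-contained sketch against an in-memory table (column set reduced to what matters here; the THREAD member is assumed, see the enum note near the top): a message_id may now be stored once per thread type, while duplicate (message_id, thread_type) pairs are still rejected.

```python
from enum import Enum
from sqlalchemy import create_engine, Column, Integer, String, Enum as SAEnum, UniqueConstraint
from sqlalchemy.orm import declarative_base, Session

class ThreadTypeEnum(str, Enum):   # assumed members
    MESSAGE = "message"
    THREAD = "thread"

Base = declarative_base()

class MMDocSketch(Base):
    __tablename__ = "mattermost_documents_sketch"
    id = Column(Integer, primary_key=True)
    message_id = Column(String())
    thread_type = Column(SAEnum(ThreadTypeEnum), default=ThreadTypeEnum.MESSAGE)
    __table_args__ = (UniqueConstraint('message_id', 'thread_type', name='_messageid_threadtype_uc'),)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as db:
    db.add_all([MMDocSketch(message_id="abc123", thread_type=ThreadTypeEnum.MESSAGE),
                MMDocSketch(message_id="abc123", thread_type=ThreadTypeEnum.THREAD)])
    db.commit()   # OK: one row per (message_id, thread_type)
    db.add(MMDocSketch(message_id="abc123", thread_type=ThreadTypeEnum.MESSAGE))
    try:
        db.commit()
    except Exception as exc:
        print(type(exc).__name__)  # IntegrityError: duplicate (message_id, thread_type)
```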