From 53c2c87fa9af9ae073d6bceb7c3125bd7bcde846 Mon Sep 17 00:00:00 2001 From: emiliecowen <123681156+emiliecowen@users.noreply.github.com> Date: Mon, 23 Sep 2024 12:52:38 -0400 Subject: [PATCH] feat/thread-chats-only (#241) - use conversation threads for user chats only https://github.com/MIT-AI-Accelerator/c3po-model-server/issues/240 - add acars_text to info_type to capture air crew "chats" https://github.com/orgs/MIT-AI-Accelerator/projects/3/views/3?pane=issue&itemId=80096759 - add user/persona context to summarization text https://github.com/orgs/MIT-AI-Accelerator/projects/3/views/3?pane=issue&itemId=80614000 - requires DB init and ppg-common build - confirmed end-to-end and pytests passing --- .../bertopic/ai_services/basic_inference.py | 14 +++- app/aimodels/bertopic/routers/train.py | 20 +++++ app/initial_data.py | 3 +- app/mattermost/crud/crud_mattermost.py | 69 ++++++++++++------ app/mattermost/models/mattermost_documents.py | 6 +- app/mattermost/router.py | 73 ++++++++++++++----- poetry.lock | 2 +- .../mattermost/mattermost_documents.py | 8 +- ppg-common/setup.py | 2 +- tests/mattermost/test_mattermost_crud.py | 37 +++++++--- tests/mattermost/test_mattermost_router.py | 11 ++- 11 files changed, 182 insertions(+), 63 deletions(-) diff --git a/app/aimodels/bertopic/ai_services/basic_inference.py b/app/aimodels/bertopic/ai_services/basic_inference.py index 3e1bff0..d76c21c 100644 --- a/app/aimodels/bertopic/ai_services/basic_inference.py +++ b/app/aimodels/bertopic/ai_services/basic_inference.py @@ -179,6 +179,7 @@ class TopicDocumentData(BaseModel): document_channels: list[str] document_links: list[str] document_metadata: list[dict] + document_summarization_messages: list[str] embeddings: np.ndarray class Config: @@ -244,6 +245,7 @@ def get_document_info(self, topic_model, topic_document_data: TopicDocumentData, document_info['Channel'] = topic_document_data.document_channels document_info['Link'] = topic_document_data.document_links document_info['Metadata'] = topic_document_data.document_metadata + document_info['Summarization_Message'] = topic_document_data.document_summarization_messages return document_info def train_bertopic_on_documents(self, db, documents, precalculated_embeddings, num_topics, document_df, seed_topic_list=None, num_related_docs=DEFAULT_N_REPR_DOCS, trends_only=False, trend_depth=DEFAULT_TREND_DEPTH_DAYS, train_percent=DEFAULT_TRAIN_PERCENT) -> BasicInferenceOutputs: @@ -267,6 +269,7 @@ def train_bertopic_on_documents(self, db, documents, precalculated_embeddings, n document_channels = list(document_df['channel_name'].values), document_links = list(document_df['mm_link'].values), document_metadata = list(document_df['mm_metadata'].values), + document_summarization_messages = list(document_df['summarization_message'].values), embeddings = embeddings) topic_model, topic_document_data_train, topic_document_data_test = self.build_topic_model( @@ -381,7 +384,8 @@ def build_topic_model(self, topic_document_data: TopicDocumentData, num_topics, 'nickname': topic_document_data.document_nicknames, 'channel': topic_document_data.document_channels, 'link': topic_document_data.document_links, - 'metadata': topic_document_data.document_metadata}) + 'metadata': topic_document_data.document_metadata, + 'summarization_message': topic_document_data.document_summarization_messages}) if self.weak_learner_obj: l_test = self.label_applier.apply( @@ -414,6 +418,7 @@ def build_topic_model(self, topic_document_data: TopicDocumentData, num_topics, document_channels = list(data_train['channel']), document_links = list(data_train['link']), document_metadata = list(data_train['metadata']), + document_summarization_messages = list(data_train['summarization_message']), embeddings = topic_document_data.embeddings[:train_len-1]) topic_document_data_test = TopicDocumentData(document_text_list = list(data_test['document']), document_messages = list(data_test['message']), @@ -422,7 +427,8 @@ def build_topic_model(self, topic_document_data: TopicDocumentData, num_topics, document_nicknames = list(data_test['nickname']), document_channels= list(data_test['channel']), document_links = list(data_test['link']), - document_metadata = list(data_test['metadata']), + document_metadata = list(data_test['metadata']), + document_summarization_messages = list(data_test['summarization_message']), embeddings = topic_document_data.embeddings[train_len:]) umap_model = UMAP(n_neighbors = DEFAULT_UMAP_NEIGHBORS, @@ -473,7 +479,7 @@ def create_topic_visualizations(self, document_info_train, topic_model, document summary_text = 'topic summarization disabled' if self.topic_summarizer: - summary_text = self.topic_summarizer.get_summary(topic_docs['Message'].to_list()) + summary_text = self.topic_summarizer.get_summary(topic_docs['Summarization_Message'].to_list()) # topic-level timeline visualization topic_timeline_visualization_list = topic_timeline_visualization_list + [topic_model.visualize_topics_over_time( @@ -484,7 +490,7 @@ def create_topic_visualizations(self, document_info_train, topic_model, document topic_id=row['Topic'], name=row['Name'], top_n_words=topic_docs['Top_n_words'].unique()[0], - top_n_documents=topic_docs.rename(columns={'Document': 'Lowercase', 'Message': 'Document'})[[ + top_n_documents=topic_docs.rename(columns={'Document': 'Lowercase', 'Summarization_Message': 'Document'})[[ 'Document', 'Timestamp', 'User', 'Nickname', 'Channel', 'Link', 'Probability']].to_dict(), summary=summary_text, is_trending=row['is_trending'])] diff --git a/app/aimodels/bertopic/routers/train.py b/app/aimodels/bertopic/routers/train.py index d62dc08..08557a9 100644 --- a/app/aimodels/bertopic/routers/train.py +++ b/app/aimodels/bertopic/routers/train.py @@ -8,6 +8,7 @@ from ppg.schemas.bertopic.bertopic_trained import BertopicTrained, BertopicTrainedCreate, BertopicTrainedUpdate from ppg.schemas.bertopic.bertopic_visualization import BertopicVisualizationCreate from ppg.schemas.bertopic.topic import TopicSummaryUpdate +from app.core.logging import logger from app.core.minio import pickle_and_upload_object_to_minio from ..ai_services.basic_inference import BasicInference, MIN_BERTOPIC_DOCUMENTS, DEFAULT_TRAIN_PERCENT from app.dependencies import get_db, get_minio @@ -32,6 +33,7 @@ class TrainModelRequest(BaseModel): weak_learner_id: UUID4 | None summarization_model_id: UUID4 | None document_ids: list[UUID4] = [] + summarization_document_ids: list[UUID4] = [] num_topics: int = 2 seed_topics: list[list] = [] stop_words: list[str] = [] @@ -65,6 +67,9 @@ def train_bertopic_post(request: TrainModelRequest, db: Session = Depends(get_db if len(request.document_ids) < MIN_BERTOPIC_DOCUMENTS: raise HTTPException( status_code=400, detail="must have at least 7 documents to find topics") + elif len(request.summarization_document_ids) == 0 or len(request.document_ids) != len(request.summarization_document_ids): + logger.warning("reusing document_ids for summarization, length mismatch") + request.summarization_document_ids = request.document_ids # validate train percent if request.train_percent < 0.0 or request.train_percent > 1.0: @@ -89,6 +94,21 @@ def train_bertopic_post(request: TrainModelRequest, db: Session = Depends(get_db documents, precalculated_embeddings = get_documents_and_embeddings(db, request.document_ids, request.sentence_transformer_id) document_df = crud_mattermost.mattermost_documents.get_document_dataframe(db, document_uuids=request.document_ids) + summarization_document_df = crud_mattermost.mattermost_documents.get_document_dataframe(db, document_uuids=request.summarization_document_ids) + + if all(document_df['message_id'].str.len() > 0): + # merge original messages with summarization text + document_df = pd.merge(document_df, + summarization_document_df[['message_id', 'message']], + on='message_id', + how='left', + validate='1:1').rename(columns={"message_x": "message", + "message_y": "summarization_message"}) + fix_mask = document_df['summarization_message'].isnull() + document_df.summarization_message[fix_mask] = document_df.message[fix_mask] + else: + logger.warning("reusing document_ids for summarization, missing merge criteria") + document_df['summarization_message'] = document_df['message'] # train the model basic_inference = BasicInference(bertopic_sentence_transformer_obj, diff --git a/app/initial_data.py b/app/initial_data.py index 98fa6f3..e57e5d4 100644 --- a/app/initial_data.py +++ b/app/initial_data.py @@ -17,6 +17,7 @@ from ppg.schemas.bertopic.bertopic_embedding_pretrained import BertopicEmbeddingPretrainedCreate, BertopicEmbeddingPretrainedUpdate from ppg.schemas.gpt4all.llm_pretrained import LlmPretrainedCreate, LlmPretrainedUpdate from ppg.schemas.bertopic.document import DocumentCreate +from ppg.schemas.mattermost.mattermost_documents import ThreadTypeEnum from app.aimodels.bertopic.models.bertopic_embedding_pretrained import BertopicEmbeddingPretrainedModel, EmbeddingModelTypeEnum from app.aimodels.bertopic.models.document import DocumentModel @@ -438,7 +439,7 @@ def init_mattermost_documents(db:Session, bot_obj: MattermostUserModel) -> None: adf = adf[~adf.id.isin(existing_ids)].drop_duplicates(subset='id') adf.rename(columns={'id': 'message_id'}, inplace=True) - return crud_mattermost.mattermost_documents.create_all_using_df(db, ddf=adf, is_thread=False) + return crud_mattermost.mattermost_documents.create_all_using_df(db, ddf=adf, thread_type=ThreadTypeEnum.MESSAGE) ########## large object uploads ################ diff --git a/app/mattermost/crud/crud_mattermost.py b/app/mattermost/crud/crud_mattermost.py index 7a6f364..219c465 100644 --- a/app/mattermost/crud/crud_mattermost.py +++ b/app/mattermost/crud/crud_mattermost.py @@ -5,7 +5,7 @@ from fastapi import HTTPException from sqlalchemy.orm import Session from ppg.schemas.mattermost.mattermost_channels import MattermostChannelCreate -from ppg.schemas.mattermost.mattermost_documents import MattermostDocumentCreate, InfoTypeEnum +from ppg.schemas.mattermost.mattermost_documents import MattermostDocumentCreate, InfoTypeEnum, ThreadTypeEnum from ppg.schemas.mattermost.mattermost_users import MattermostUserCreate, MattermostUserUpdate from ppg.schemas.bertopic.document import DocumentCreate import ppg.services.mattermost_utils as mattermost_utils @@ -53,14 +53,14 @@ def get_by_user_name(self, db: Session, *, user_name: str) -> Union[MattermostUs class CRUDMattermostDocument(CRUDBase[MattermostDocumentModel, MattermostDocumentCreate, MattermostDocumentCreate]): - def get_by_message_id(self, db: Session, *, message_id: str, is_thread = False) -> Union[MattermostDocumentModel, None]: + def get_by_message_id(self, db: Session, *, message_id: str, thread_type = ThreadTypeEnum.MESSAGE) -> Union[MattermostDocumentModel, None]: if not message_id: return None # each mattermost document is allowed a single conversation thread - return db.query(self.model).filter(self.model.message_id == message_id, self.model.is_thread == is_thread).first() + return db.query(self.model).filter(self.model.message_id == message_id, self.model.thread_type == thread_type).first() - def get_all_by_message_id(self, db: Session, *, message_id: str, is_thread = False) -> Union[MattermostDocumentModel, None]: + def get_all_by_message_id(self, db: Session, *, message_id: str) -> Union[MattermostDocumentModel, None]: if not message_id: return None @@ -84,7 +84,7 @@ def get_all_channel_documents(self, db: Session, channels: list[str], history_de documents += sum([db.query(self.model).join(DocumentModel).filter(self.model.channel == cuuid, DocumentModel.original_created_time >= stime, DocumentModel.original_created_time <= ctime, - self.model.is_thread == False, + self.model.thread_type == ThreadTypeEnum.MESSAGE, self.model.info_type == itype).all() for cuuid in channels], []) return documents @@ -104,6 +104,8 @@ def get_mm_document_dataframe(self, db: Session, *, mm_document_uuids: list[str] 'document': document[0][1].id, 'user_id': document[0][2].user_id, 'user_uuid': document[0][2].id, + 'user_name': document[0][2].user_name, + 'nickname': document[0][2].nickname, 'channel_id': document[0][3].channel_id, 'channel_uuid': document[0][3].id, 'create_at': document[0][1].original_created_time, @@ -112,7 +114,8 @@ def get_mm_document_dataframe(self, db: Session, *, mm_document_uuids: list[str] 'props': document[0][0].props, 'metadata': document[0][0].doc_metadata, 'document_id': document[0][1].id, - 'is_thread': document[0][0].is_thread}])], + 'info_type': document[0][0].info_type, + 'thread_type': document[0][0].thread_type}])], ignore_index=True) return ddf @@ -188,7 +191,7 @@ def get_document_dataframe(self, db: Session, *, document_uuids: list[str]) -> U return ddf - def create_all_using_df(self, db: Session, *, ddf: pd.DataFrame, is_thread = False) -> Union[MattermostDocumentModel, None]: + def create_all_using_df(self, db: Session, *, ddf: pd.DataFrame, thread_type = ThreadTypeEnum.MESSAGE) -> Union[MattermostDocumentModel, None]: mattermost_documents = [] for key, row in ddf.iterrows(): @@ -220,7 +223,7 @@ def create_all_using_df(self, db: Session, *, ddf: pd.DataFrame, is_thread = Fal channel=row['channel'], user=row['user'], document=document_obj.id, - is_thread=is_thread, + thread_type=thread_type, info_type=info_type)] return self.create_all_using_id(db, obj_in_list=mattermost_documents) @@ -399,7 +402,7 @@ def populate_mm_document_info(db: Session, *, document_df: pd.DataFrame): udf.rename(columns={"id": "message_id"}, inplace=True) # create new document objects in db - new_mattermost_docs = new_mattermost_docs + mattermost_documents.create_all_using_df(db, ddf=udf, is_thread=False) + new_mattermost_docs = new_mattermost_docs + mattermost_documents.create_all_using_df(db, ddf=udf, thread_type=ThreadTypeEnum.MESSAGE) return new_mattermost_docs @@ -410,32 +413,50 @@ def convert_conversation_threads(df: pd.DataFrame): df['root_id'] = df['root_id'].fillna('') df['message'] = df['message'].fillna('') + df['thread'] = df['message'].fillna('') + df['thread_speaker'] = df['message'].fillna('') + df['thread_speaker_persona'] = df['message'].fillna('') threads = {} + threads_speaker = {} + threads_speaker_persona = {} threads_row = {} + for index, row in df.iterrows(): thread = row['root_id'] utterance = row['message'] + utterance.replace("\n", " ") + speaker = row['user_name'] + persona = row['nickname'] + utterance_speaker = speaker + ': ' + utterance + utterance_speaker_persona = speaker + ' (' + persona + '): ' + utterance p_id = row['message_id'] + if utterance.find("added to the channel") < 0 and utterance.find("joined the channel") < 0 and utterance.find("left the channel") < 0: if len(thread) > 0: if thread not in threads: - threads[thread] = [utterance.replace("\n", " ")] + threads[thread] = [utterance] + threads_speaker[thread] = [utterance_speaker] + threads_speaker_persona[thread] = [utterance_speaker_persona] else: - threads[thread].append(utterance.replace("\n", " ")) + threads[thread].append(utterance) + threads_speaker[thread].append(utterance_speaker) + threads_speaker_persona[thread].append(utterance_speaker_persona) else: - t = [] - t.append(utterance.replace("\n", " ")) - threads[p_id] = t + threads[p_id] = [utterance] + threads_speaker[p_id] = [utterance_speaker] + threads_speaker_persona[p_id] = [utterance_speaker_persona] threads_row[p_id] = row keys = set(sorted(threads.keys())).intersection(threads_row.keys()) - data = [] + conversations = [] for index, key in enumerate(keys): row = threads_row[key] - row['message'] = "\n".join(threads[key]) - data.append(row) + row['thread'] = "\n".join(threads[key]) + row['thread_speaker'] = "\n".join(threads_speaker[key]) + row['thread_speaker_persona'] = "\n".join(threads_speaker_persona[key]) + conversations.append(row) - return pd.DataFrame(data, columns=df.columns) + return pd.DataFrame(conversations, columns=df.columns) def parse_props(jobj: dict): @@ -443,7 +464,8 @@ def parse_props(jobj: dict): author_name = jobj['author_name'] title = jobj['title'] - msg = '[%s] %s' % (jobj['title'], jobj['text']) + fallback = jobj['fallback'] + msg = '[%s] %s' % (title, jobj['text']) if 'Dataminr' in author_name: info_type = InfoTypeEnum.DATAMINR @@ -454,7 +476,10 @@ def parse_props(jobj: dict): info_type = InfoTypeEnum.ARINC elif 'ACARS' in title: info_type = InfoTypeEnum.ACARS - msg = parse_props_acars(jobj) + msg = parse_props_acars(jobj, title=title) + elif 'ACARS Free Text' in fallback: + info_type = InfoTypeEnum.ACARS_TEXT + msg = parse_props_acars(jobj, title='ACARS Free Text') elif 'NOTAM' in title: info_type = InfoTypeEnum.NOTAM msg = parse_props_notam(jobj) @@ -485,8 +510,8 @@ def parse_props_notam(jobj: dict): return msg -def parse_props_acars(jobj: dict): - msg = '[%s] %s' % (jobj['title'], jobj['text']) +def parse_props_acars(jobj: dict, title: str): + msg = '[%s] %s' % (title, jobj['text']) if jobj['fields'] is not None: tail_num = '' diff --git a/app/mattermost/models/mattermost_documents.py b/app/mattermost/models/mattermost_documents.py index 5fba3df..305a03e 100644 --- a/app/mattermost/models/mattermost_documents.py +++ b/app/mattermost/models/mattermost_documents.py @@ -3,7 +3,7 @@ from sqlalchemy import Column, UUID, String, ForeignKey, Enum, JSON, Boolean, UniqueConstraint from sqlalchemy.ext.mutable import MutableDict from ppg.core.config import OriginationEnum -from ppg.schemas.mattermost.mattermost_documents import InfoTypeEnum +from ppg.schemas.mattermost.mattermost_documents import InfoTypeEnum, ThreadTypeEnum from app.db.base_class import Base from app.core.config import get_originated_from @@ -26,11 +26,11 @@ class MattermostDocumentModel(Base): has_reactions = Column(Boolean(), default=False) props = Column(MutableDict.as_mutable(JSON)) doc_metadata = Column(MutableDict.as_mutable(JSON)) - is_thread = Column(Boolean(), default=False) + thread_type = Column(Enum(ThreadTypeEnum), default=ThreadTypeEnum.MESSAGE) info_type = Column(Enum(InfoTypeEnum), default=InfoTypeEnum.CHAT) originated_from = Column(Enum(OriginationEnum), default=get_originated_from) # mattermost message IDs must be unique, # allow for a single conversation thread for each message - __table_args__ = (UniqueConstraint('message_id', 'is_thread', name='_messageid_isthread_uc'),) + __table_args__ = (UniqueConstraint('message_id', 'thread_type', name='_messageid_threadtype_uc'),) diff --git a/app/mattermost/router.py b/app/mattermost/router.py index ced6f14..a33f671 100644 --- a/app/mattermost/router.py +++ b/app/mattermost/router.py @@ -6,7 +6,7 @@ import pandas as pd from sqlalchemy.orm import Session from ppg.schemas.bertopic.document import DocumentUpdate -from ppg.schemas.mattermost.mattermost_documents import MattermostDocument, MattermostDocumentUpdate, InfoTypeEnum +from ppg.schemas.mattermost.mattermost_documents import MattermostDocument, MattermostDocumentUpdate, InfoTypeEnum, ThreadTypeEnum from ppg.schemas.mattermost.mattermost_users import MattermostUser import ppg.services.mattermost_utils as mattermost_utils from app.core.config import settings @@ -140,7 +140,7 @@ async def upload_mm_channel_docs(request: UploadDocumentRequest, db: Session = D adf = adf[~adf.id.isin(existing_ids)].drop_duplicates(subset='id') adf.rename(columns={'id': 'message_id'}, inplace=True) - crud_mattermost.mattermost_documents.create_all_using_df(db, ddf=adf, is_thread=False) + crud_mattermost.mattermost_documents.create_all_using_df(db, ddf=adf, thread_type=ThreadTypeEnum.MESSAGE) return crud_mattermost.mattermost_documents.get_all_channel_documents(db, channels=channel_uuids, @@ -184,16 +184,20 @@ async def get_mm_channel_docs(team_name: str, channel_name: str, class ConversationThreadRequest(BaseModel): mattermost_document_ids: list[UUID4] = [] +class ConversationThreadResponse(BaseModel): + threads: list[MattermostDocument] = [] + threads_speaker: list[MattermostDocument] = [] + threads_speaker_persona: list[MattermostDocument] = [] @router.post( "/mattermost/conversation_threads", - response_model=Union[list[MattermostDocument], HTTPValidationError], + response_model=Union[ConversationThreadResponse, HTTPValidationError], responses={'422': {'model': HTTPValidationError}}, summary="Retrieve Mattermost conversation documents", response_description="Retrieved Mattermost conversation documents") async def convert_conversation_threads(request: ConversationThreadRequest, db: Session = Depends(get_db)) -> ( - Union[list[MattermostDocument], HTTPValidationError] + Union[ConversationThreadResponse, HTTPValidationError] ): """ Retrieve Mattermost conversation documents @@ -206,21 +210,54 @@ async def convert_conversation_threads(request: ConversationThreadRequest, if document_df.empty: raise HTTPException(status_code=422, detail="Mattermost documents not found") + # only thread user chat messages. other types, such as ACARS, NOTAMS should remain unthreaded + chat_df = document_df[document_df['info_type'] == InfoTypeEnum.CHAT] + other_df = document_df[document_df['info_type'] != InfoTypeEnum.CHAT] + # convert message utterances to conversation threads - conversation_df = crud_mattermost.convert_conversation_threads(df=document_df) + conversation_df = crud_mattermost.convert_conversation_threads(df=chat_df) conversation_df.rename(columns={'user_uuid': 'user','channel_uuid': 'channel'}, inplace=True) - document_objs = [] - new_threads_df = pd.DataFrame() + other_mm_doc_objs = [crud_mattermost.mattermost_documents.get_by_message_id(db, message_id=row['message_id']) + for key, row in other_df.iterrows()] + if not other_df.empty and (len(other_mm_doc_objs) != len(other_df)): + raise HTTPException(status_code=422, detail="Unable to find non chat documents") + + thread_document_objs = ConversationThreadResponse() + thread_document_objs.threads = create_conversation_objects(db=db, + thread_type=ThreadTypeEnum.THREAD, + conversation_df=conversation_df) + other_mm_doc_objs + + thread_document_objs.threads_speaker = create_conversation_objects(db=db, + thread_type=ThreadTypeEnum.THREAD_USER, + conversation_df=conversation_df) + other_mm_doc_objs + thread_document_objs.threads_speaker_persona = create_conversation_objects(db=db, + thread_type=ThreadTypeEnum.THREAD_USER_PERSONA, + conversation_df=conversation_df) + other_mm_doc_objs + + return thread_document_objs + +def create_conversation_objects(db: Session, thread_type: ThreadTypeEnum, conversation_df: pd.DataFrame) -> list[MattermostDocument]: + + thread_document_objs = [] + thread_df = pd.DataFrame() + for _, row in conversation_df.iterrows(): - mm_document_obj = crud_mattermost.mattermost_documents.get_by_message_id(db, message_id=row['message_id'], is_thread=True) + + thread_str = row['thread'] + if thread_type == ThreadTypeEnum.THREAD_USER: + thread_str = row['thread_speaker'] + if thread_type == ThreadTypeEnum.THREAD_USER_PERSONA: + thread_str = row['thread_speaker_persona'] + + mm_document_obj = crud_mattermost.mattermost_documents.get_by_message_id(db, message_id=row['message_id'], thread_type=thread_type) # update existing thread if mm_document_obj: document_obj = crud_document.document.get(db, id=row['document_id']) crud_document.document.update(db, db_obj=document_obj, - obj_in=DocumentUpdate(text=row['message'], + obj_in=DocumentUpdate(text=thread_str, original_created_time=document_obj.original_created_time)) updated_mm_doc_obj = crud_mattermost.mattermost_documents.update(db, db_obj=mm_document_obj, @@ -235,23 +272,23 @@ async def convert_conversation_threads(request: ConversationThreadRequest, channel=mm_document_obj.channel, user=mm_document_obj.user, document=mm_document_obj.document, - is_thread=True, + thread_type=thread_type, info_type=mm_document_obj.info_type)) - document_objs = document_objs + [updated_mm_doc_obj] + thread_document_objs = thread_document_objs + [updated_mm_doc_obj] else: - new_threads_df = pd.concat([new_threads_df, pd.DataFrame([row])]) + row['message'] = thread_str + thread_df = pd.concat([thread_df, pd.DataFrame([row])]) # create new thread objects - if not new_threads_df.empty: - new_mm_doc_objs = crud_mattermost.mattermost_documents.create_all_using_df(db, ddf=new_threads_df, is_thread=True) - document_objs = document_objs + new_mm_doc_objs + if not thread_df.empty: + new_mm_doc_objs = crud_mattermost.mattermost_documents.create_all_using_df(db, ddf=thread_df, thread_type=thread_type) + thread_document_objs = thread_document_objs + new_mm_doc_objs - if len(document_objs) != len(conversation_df): + if len(thread_document_objs) != len(conversation_df): raise HTTPException(status_code=422, detail="Unable to create conversation threads") - return document_objs - + return thread_document_objs class SubstringUploadRequest(BaseModel): team_id: str diff --git a/poetry.lock b/poetry.lock index 5d1922b..2658205 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3011,7 +3011,7 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "ppg-common" -version = "1.7.0" +version = "1.9.0" description = "A library for PPG common code" optional = false python-versions = "*" diff --git a/ppg-common/ppg/schemas/mattermost/mattermost_documents.py b/ppg-common/ppg/schemas/mattermost/mattermost_documents.py index f44da61..2f4363f 100644 --- a/ppg-common/ppg/schemas/mattermost/mattermost_documents.py +++ b/ppg-common/ppg/schemas/mattermost/mattermost_documents.py @@ -9,11 +9,17 @@ class InfoTypeEnum(str, enum.Enum): NOTAM = "notam" DATAMINR = "dataminr" ACARS = "acars" + ACARS_TEXT = "acars_text" ENVISION = "envision" CAMPS = "camps" ARINC = "arinc" UDL = "udl" +class ThreadTypeEnum(str, enum.Enum): + MESSAGE = "message" + THREAD = "thread" + THREAD_USER = "thread_user" + THREAD_USER_PERSONA = "thread_user_persona" # Shared properties class MattermostDocumentBase(BaseModel): @@ -27,7 +33,7 @@ class MattermostDocumentBase(BaseModel): has_reactions: bool props: dict doc_metadata: dict - is_thread: bool + thread_type: ThreadTypeEnum info_type: InfoTypeEnum class MattermostDocumentCreate(MattermostDocumentBase): diff --git a/ppg-common/setup.py b/ppg-common/setup.py index 80be210..5312726 100644 --- a/ppg-common/setup.py +++ b/ppg-common/setup.py @@ -1,7 +1,7 @@ from setuptools import find_packages, setup setup(name='ppg-common', - version='1.7.0', + version='1.9.0', description='A library for PPG common code', url='--', author='MIT Lincoln Laboratory', diff --git a/tests/mattermost/test_mattermost_crud.py b/tests/mattermost/test_mattermost_crud.py index 148ad5e..f23ec2d 100644 --- a/tests/mattermost/test_mattermost_crud.py +++ b/tests/mattermost/test_mattermost_crud.py @@ -8,7 +8,7 @@ from app.aimodels.bertopic.models.document import DocumentModel from app.aimodels.bertopic.crud import crud_document from ppg.core.config import OriginationEnum -from ppg.schemas.mattermost.mattermost_documents import InfoTypeEnum +from ppg.schemas.mattermost.mattermost_documents import InfoTypeEnum, ThreadTypeEnum def test_crud_mattermost(db: Session): @@ -142,8 +142,11 @@ def test_crud_mattermost(db: Session): 'props': {'leggo': 'myeggo'}, 'metadata': {'cuckoo': 'forcocoapuffs'}, }]) - cdf_is_thread = cdf.loc[0, 'root_id'] == db_obj.root_message_id - mmdocs = crud.mattermost_documents.create_all_using_df(db, ddf=cdf, is_thread=cdf_is_thread) + if cdf.loc[0, 'root_id'] == db_obj.root_message_id: + cdf_thread_type = ThreadTypeEnum.THREAD + else: + cdf_thread_type = ThreadTypeEnum.MESSAGE + mmdocs = crud.mattermost_documents.create_all_using_df(db, ddf=cdf, thread_type=cdf_thread_type) assert len(mmdocs) == 1 mmdoc = mmdocs[0] newdoc = crud_document.document.get(db, mmdoc.document) @@ -157,7 +160,7 @@ def test_crud_mattermost(db: Session): assert mmdoc.has_reactions == cdf.loc[0, 'has_reactions'] assert mmdoc.props == cdf.loc[0, 'props'] assert mmdoc.doc_metadata == cdf.loc[0, 'metadata'] - assert mmdoc.is_thread == cdf_is_thread + assert mmdoc.thread_type == cdf_thread_type assert mmdoc.originated_from == settings.originated_from @@ -205,18 +208,30 @@ def test_convert_conversation_threads(): msg1 = 'message 1.' msg2 = 'message 2.' + usr1 = 'user_a' + usr2 = 'user_b' # construct message data frame with reply and convert to conversation thread document_df = pd.DataFrame() - document_df = pd.concat([document_df, pd.DataFrame( - [{'message_id': '1', 'message': msg1, 'root_id': ''}])]) - document_df = pd.concat([document_df, pd.DataFrame( - [{'message_id': '2', 'message': msg2, 'root_id': '1'}])]) + document_df = pd.concat([document_df, pd.DataFrame([{'message_id': '1', + 'message': msg1, + 'root_id': '', + 'user_name': usr1, + 'nickname': usr1, + 'info_type': InfoTypeEnum.CHAT}])]) + document_df = pd.concat([document_df, pd.DataFrame([{'message_id': '2', + 'message': msg2, + 'root_id': '1', + 'user_name': usr2, + 'nickname': usr2, + 'info_type': InfoTypeEnum.CHAT}])]) conversation_df = crud.convert_conversation_threads(document_df) assert len(conversation_df) == (len(document_df) - 1) - assert conversation_df['message'].iloc[0] == '%s\n%s' % (msg1, msg2) + assert conversation_df['thread'].iloc[0] == '%s\n%s' % (msg1, msg2) + assert conversation_df['thread_speaker'].iloc[0] == '%s: %s\n%s: %s' % (usr1, msg1, usr2, msg2) + assert conversation_df['thread_speaker_persona'].iloc[0] == '%s (%s): %s\n%s (%s): %s' % (usr1, usr1, msg1, usr2, usr2, msg2) def test_parse_props(): @@ -229,6 +244,7 @@ def test_parse_props(): 'author_name': aname, 'title': ittl, 'text': imsg, + 'fallback': '', 'fields': []}]} itype, omsg = crud.parse_props(jobj) emsg = '[%s] %s' % (ittl, imsg) @@ -275,6 +291,7 @@ def test_parse_props_notam(): 'author_name': '', 'title': ittl, 'text': imsg, + 'fallback': '', 'fields': [{'title': 'Location', 'value': 'KCAT', 'short': True}, {'title': 'Valid', 'value': '4149/0409Z - 4201/2359Z', 'short': True}]}]} itype, omsg = crud.parse_props(jobj) @@ -293,6 +310,7 @@ def test_parse_props_acars(): 'author_name': '', 'title': ittl, 'text': imsg, + 'fallback': '', 'fields': [{'title': 'Tail #', 'value': '8675309', 'short': True}, {'title': 'Mission #', 'value': '8675309', 'short': True}, {'title': 'Callsign', 'value': 'CAT123', 'short': True}]}]} @@ -311,6 +329,7 @@ def test_parse_props_dataminr(): 'author_name': 'Dataminr', 'title': imsg, 'text': '', + 'fallback': '', 'fields': [{'title': 'Alert Type', 'value': 'Urgent', 'short': False}, {'title': 'Event Time', 'value': '26/06/2024 18:08:19', 'short': False}, {'title': 'Event Location', 'value': 'Lexington, MA USA\n', 'short': False}, diff --git a/tests/mattermost/test_mattermost_router.py b/tests/mattermost/test_mattermost_router.py index f3280b1..ee3c2c6 100644 --- a/tests/mattermost/test_mattermost_router.py +++ b/tests/mattermost/test_mattermost_router.py @@ -13,6 +13,7 @@ from app.mattermost.models.mattermost_documents import MattermostDocumentModel from app.aimodels.bertopic.crud import crud_document from ppg.schemas.bertopic.document import DocumentCreate +from ppg.schemas.mattermost.mattermost_documents import ThreadTypeEnum @pytest.fixture(scope='module') def channel_db_obj(db: Session): @@ -72,7 +73,7 @@ def mm_db_obj_thread(channel_db_obj: MattermostChannelModel, hashtags='', props=dict(), doc_metadata=dict(), - is_thread=True) + thread_type=ThreadTypeEnum.THREAD) return crud_mattermost.mattermost_documents.create(db, obj_in=mm_doc_obj_in) # returns 422 @@ -335,7 +336,9 @@ def test_mattermost_conversation_thread_no_thread(mm_db_obj: MattermostDocumentM mm_docs = response.json() assert response.status_code == 200 - assert str(mm_db_obj.message_id) in [mm_doc['message_id'] for mm_doc in mm_docs] + assert str(mm_db_obj.message_id) in [mm_doc['message_id'] for mm_doc in mm_docs['threads']] + assert str(mm_db_obj.message_id) in [mm_doc['message_id'] for mm_doc in mm_docs['threads_speaker']] + assert str(mm_db_obj.message_id) in [mm_doc['message_id'] for mm_doc in mm_docs['threads_speaker_persona']] def test_mattermost_conversation_thread_thread(mm_db_obj_thread: MattermostDocumentModel, client: TestClient): @@ -346,4 +349,6 @@ def test_mattermost_conversation_thread_thread(mm_db_obj_thread: MattermostDocum mm_docs = response.json() assert response.status_code == 200 - assert str(mm_db_obj_thread.message_id) in [mm_doc['message_id'] for mm_doc in mm_docs] + assert str(mm_db_obj_thread.message_id) in [mm_doc['message_id'] for mm_doc in mm_docs['threads']] + assert str(mm_db_obj_thread.message_id) in [mm_doc['message_id'] for mm_doc in mm_docs['threads_speaker']] + assert str(mm_db_obj_thread.message_id) in [mm_doc['message_id'] for mm_doc in mm_docs['threads_speaker_persona']]