diff --git a/clients/ts-sdk/openapi.json b/clients/ts-sdk/openapi.json index 8dca178a9f..83e998741e 100644 --- a/clients/ts-sdk/openapi.json +++ b/clients/ts-sdk/openapi.json @@ -12,7 +12,7 @@ "name": "BSL", "url": "https://github.com/devflowinc/trieve/blob/main/LICENSE.txt" }, - "version": "0.12.0" + "version": "0.12.1" }, "servers": [ { @@ -1426,7 +1426,7 @@ "Chunk" ], "summary": "Split HTML Content into Chunks", - "description": "This endpoint receives a single html string and splits it into chunks based on the headings and\nbody content. The headings are split based on headding html tags. chunk_html has a maximum size\nof 256Kb.", + "description": "This endpoint receives a single html string and splits it into chunks based on the headings and\nbody content. The headings are split based on heading html tags. chunk_html has a maximum size\nof 256Kb.", "operationId": "split_html_content", "requestBody": { "description": "JSON request payload to perform RAG on some chunks (chunks)", @@ -4648,7 +4648,7 @@ "security": [ { "ApiKey": [ - "readonly" + "admin" ] } ] @@ -4710,7 +4710,7 @@ "security": [ { "ApiKey": [ - "readonly" + "admin" ] } ] @@ -4780,7 +4780,7 @@ "security": [ { "ApiKey": [ - "readonly" + "admin" ] } ] @@ -4849,7 +4849,68 @@ "security": [ { "ApiKey": [ - "readonly" + "admin" + ] + } + ] + } + }, + "/api/message/{message_id}": { + "get": { + "tags": [ + "Message" + ], + "summary": "Get a message by its ID", + "description": "Quickly get the full object for a given message. From the message, you can get the topic and all messages which exist on that topic.", + "operationId": "get_message_by_id", + "parameters": [ + { + "name": "TR-Dataset", + "in": "header", + "description": "The dataset id or tracking_id to use for the request. We assume you intend to use an id if the value is a valid uuid.", + "required": true, + "schema": { + "type": "string", + "format": "uuid" + } + }, + { + "name": "message_id", + "in": "path", + "description": "The ID of the message to get.", + "required": true, + "schema": { + "type": "string", + "format": "uuid" + } + } + ], + "responses": { + "200": { + "description": "Message with the given ID", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Message" + } + } + } + }, + "400": { + "description": "Service error relating to getting the message", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponseBody" + } + } + } + } + }, + "security": [ + { + "ApiKey": [ + "admin" ] } ] @@ -4861,7 +4922,7 @@ "Message" ], "summary": "Get all messages for a given topic", - "description": "Get all messages for a given topic. If the topic is a RAG topic then the response will include Chunks first on each message. The structure will look like `[chunks]||mesage`. See docs.trieve.ai for more information.", + "description": "If the topic is a RAG topic then the response will include Chunks first on each message. The structure will look like `[chunks]||mesage`. See docs.trieve.ai for more information.", "operationId": "get_all_topic_messages", "parameters": [ { @@ -13599,7 +13660,7 @@ }, "group_size": { "type": "integer", - "format": "int32", + "format": "int64", "description": "Group_size is the number of chunks to fetch for each group. The default is 3. If a group has less than group_size chunks, all chunks will be returned. If this is set to a large number, we recommend setting slim_chunks to true to avoid returning the content and chunk_html of the chunks so as to lower the amount of time required for content download and serialization.", "nullable": true, "minimum": 0 @@ -13648,6 +13709,14 @@ "description": "Set slim_chunks to true to avoid returning the content and chunk_html of the chunks. This is useful for when you want to reduce amount of data over the wire for latency improvement (typicall 10-50ms). Default is false.", "nullable": true }, + "sort_options": { + "allOf": [ + { + "$ref": "#/components/schemas/SortOptions" + } + ], + "nullable": true + }, "typo_options": { "allOf": [ { @@ -15489,6 +15558,11 @@ "type": "string", "description": "Time stamp should be an ISO 8601 combined date and time without timezone. Time_stamp is used for time window filtering and recency-biasing search results. Will be passed down to the file's chunks.", "nullable": true + }, + "use_pdf2md_ocr": { + "type": "boolean", + "description": "Parameter to use pdf2md_ocr. If true, the file will be converted to markdown using gpt-4o.\nDefault is false.", + "nullable": true } }, "example": { @@ -15512,7 +15586,8 @@ "tag2" ], "target_splits_per_chunk": 20, - "time_stamp": "2021-01-01 00:00:00.000Z" + "time_stamp": "2021-01-01 00:00:00.000Z", + "use_pdf2md_ocr": false } }, "UploadFileResult": { diff --git a/clients/ts-sdk/package.json b/clients/ts-sdk/package.json index 29b2f0e3e4..588da91b76 100644 --- a/clients/ts-sdk/package.json +++ b/clients/ts-sdk/package.json @@ -6,7 +6,7 @@ "files": [ "dist" ], - "version": "0.0.30", + "version": "0.0.31", "license": "MIT", "scripts": { "lint": "eslint 'src/**/*.ts'", diff --git a/clients/ts-sdk/src/functions/message/index.ts b/clients/ts-sdk/src/functions/message/index.ts index 95eadb43ef..4fd3bb7911 100644 --- a/clients/ts-sdk/src/functions/message/index.ts +++ b/clients/ts-sdk/src/functions/message/index.ts @@ -7,6 +7,7 @@ import { CreateMessageReqPayload, EditMessageReqPayload, GetAllTopicMessagesData, + GetMessageByIdData, RegenerateMessageReqPayload, } from "../../fetch-client"; import { TrieveSDK } from "../../sdk"; @@ -358,3 +359,31 @@ export async function getAllMessagesForTopic( signal ); } + +/** + * Quickly get the full object for a given message. From the message, you can get the topic and all messages which exist on that topic. + * + * Example: + * ```js + *const data = await trieve.getMessageById({ + messageId: "3c90c3cc-1d76-27198-8888-8dd25736052a", +}); + * ``` + */ +export async function getMessageById( + /** @hidden */ + this: TrieveSDK, + data: Omit, + signal?: AbortSignal +) { + return await this.trieve.fetch( + "/api/message/{message_id}", + "get", + { + ...data, + datasetId: this.datasetId, + }, + signal + ); +} + diff --git a/clients/ts-sdk/src/types.gen.ts b/clients/ts-sdk/src/types.gen.ts index 7546322b8c..7a9797358e 100644 --- a/clients/ts-sdk/src/types.gen.ts +++ b/clients/ts-sdk/src/types.gen.ts @@ -2396,6 +2396,7 @@ export type SearchOverGroupsReqPayload = { * Set slim_chunks to true to avoid returning the content and chunk_html of the chunks. This is useful for when you want to reduce amount of data over the wire for latency improvement (typicall 10-50ms). Default is false. */ slim_chunks?: (boolean) | null; + sort_options?: ((SortOptions) | null); typo_options?: ((TypoOptions) | null); /** * If true, quoted and - prefixed words will be parsed from the queries and used as required and negated words respectively. Default is false. @@ -3050,6 +3051,11 @@ export type UploadFileReqPayload = { * Time stamp should be an ISO 8601 combined date and time without timezone. Time_stamp is used for time window filtering and recency-biasing search results. Will be passed down to the file's chunks. */ time_stamp?: (string) | null; + /** + * Parameter to use pdf2md_ocr. If true, the file will be converted to markdown using gpt-4o. + * Default is false. + */ + use_pdf2md_ocr?: (boolean) | null; }; export type UploadFileResult = { @@ -4105,6 +4111,19 @@ export type RegenerateMessagePatchData = { export type RegenerateMessagePatchResponse = (string); +export type GetMessageByIdData = { + /** + * The ID of the message to get. + */ + messageId: string; + /** + * The dataset id or tracking_id to use for the request. We assume you intend to use an id if the value is a valid uuid. + */ + trDataset: string; +}; + +export type GetMessageByIdResponse = (Message); + export type GetAllTopicMessagesData = { /** * The ID of the topic to get messages for. @@ -5533,6 +5552,21 @@ export type $OpenApiTs = { }; }; }; + '/api/message/{message_id}': { + get: { + req: GetMessageByIdData; + res: { + /** + * Message with the given ID + */ + 200: Message; + /** + * Service error relating to getting the message + */ + 400: ErrorResponseBody; + }; + }; + }; '/api/messages/{messages_topic_id}': { get: { req: GetAllTopicMessagesData; diff --git a/server/src/handlers/message_handler.rs b/server/src/handlers/message_handler.rs index 86f08182c2..2f0f0fa5a0 100644 --- a/server/src/handlers/message_handler.rs +++ b/server/src/handlers/message_handler.rs @@ -13,8 +13,9 @@ use crate::{ chunk_operator::{get_chunk_metadatas_from_point_ids, get_random_chunk_metadatas_query}, clickhouse_operator::EventQueue, message_operator::{ - create_topic_message_query, delete_message_query, get_message_by_sort_for_topic_query, - get_messages_for_topic_query, get_topic_messages, stream_response, + create_topic_message_query, delete_message_query, get_message_by_id_query, + get_message_by_sort_for_topic_query, get_messages_for_topic_query, + get_topic_messages_query, stream_response, }, organization_operator::get_message_org_count, parse_operator::convert_html_to_text, @@ -132,7 +133,7 @@ pub struct CreateMessageReqPayload { ("TR-Dataset" = uuid::Uuid, Header, description = "The dataset id or tracking_id to use for the request. We assume you intend to use an id if the value is a valid uuid."), ), security( - ("ApiKey" = ["readonly"]), + ("ApiKey" = ["admin"]), ) )] #[tracing::instrument(skip(pool, event_queue))] @@ -187,7 +188,7 @@ pub async fn create_message( ); // get the previous messages - let mut previous_messages = get_topic_messages( + let mut previous_messages = get_topic_messages_query( topic_id, dataset_org_plan_sub.dataset.id, &get_messages_pool, @@ -247,7 +248,7 @@ pub async fn create_message( /// Get all messages for a given topic /// -/// Get all messages for a given topic. If the topic is a RAG topic then the response will include Chunks first on each message. The structure will look like `[chunks]||mesage`. See docs.trieve.ai for more information. +/// If the topic is a RAG topic then the response will include Chunks first on each message. The structure will look like `[chunks]||mesage`. See docs.trieve.ai for more information. #[utoipa::path( get, path = "/messages/{messages_topic_id}", @@ -259,7 +260,7 @@ pub async fn create_message( ), params( ("TR-Dataset" = uuid::Uuid, Header, description = "The dataset id or tracking_id to use for the request. We assume you intend to use an id if the value is a valid uuid."), - ("messages_topic_id" = uuid, description = "The ID of the topic to get messages for."), + ("messages_topic_id" = uuid, Path, description = "The ID of the topic to get messages for."), ), security( ("ApiKey" = ["readonly"]), @@ -295,6 +296,43 @@ pub async fn get_all_topic_messages( Ok(HttpResponse::Ok().json(messages)) } +/// Get a message by its ID +/// +/// Quickly get the full object for a given message. From the message, you can get the topic and all messages which exist on that topic. +#[utoipa::path( + get, + path = "/message/{message_id}", + context_path = "/api", + tag = "Message", + responses( + (status = 200, description = "Message with the given ID", body = Message), + (status = 400, description = "Service error relating to getting the message", body = ErrorResponseBody), + ), + params( + ("TR-Dataset" = uuid::Uuid, Header, description = "The dataset id or tracking_id to use for the request. We assume you intend to use an id if the value is a valid uuid."), + ("message_id" = uuid, Path, description = "The ID of the message to get."), + ), + security( + ("ApiKey" = ["admin"]), + ) +)] +#[tracing::instrument(skip(pool))] +pub async fn get_message_by_id( + user: AdminOnly, + message_id: web::Path, + dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan, + pool: web::Data, +) -> Result { + let message: models::Message = get_message_by_id_query( + message_id.into_inner(), + dataset_org_plan_sub.dataset.id, + &pool, + ) + .await?; + + Ok(HttpResponse::Ok().json(message)) +} + #[derive(Serialize, Debug, ToSchema)] pub struct RegenerateMessageReqPayload { /// The id of the topic to regenerate the last message for. @@ -416,7 +454,7 @@ impl From for CreateMessageReqPayload { ("TR-Dataset" = uuid::Uuid, Header, description = "The dataset id or tracking_id to use for the request. We assume you intend to use an id if the value is a valid uuid."), ), security( - ("ApiKey" = ["readonly"]), + ("ApiKey" = ["admin"]), ) )] #[tracing::instrument(skip(pool, event_queue, redis_pool))] @@ -490,7 +528,7 @@ pub async fn edit_message( ("TR-Dataset" = uuid::Uuid, Header, description = "The dataset id or tracking_id to use for the request. We assume you intend to use an id if the value is a valid uuid."), ), security( - ("ApiKey" = ["readonly"]), + ("ApiKey" = ["admin"]), ) )] #[tracing::instrument(skip(pool, event_queue, redis_pool))] @@ -513,7 +551,7 @@ pub async fn regenerate_message_patch( let dataset_id = dataset_org_plan_sub.dataset.id; let mut previous_messages = - get_topic_messages(topic_id, dataset_id, &get_messages_pool).await?; + get_topic_messages_query(topic_id, dataset_id, &get_messages_pool).await?; if previous_messages.len() < 2 { return Err( @@ -634,7 +672,7 @@ pub async fn regenerate_message_patch( ("TR-Dataset" = uuid::Uuid, Header, description = "The dataset id or tracking_id to use for the request. We assume you intend to use an id if the value is a valid uuid."), ), security( - ("ApiKey" = ["readonly"]), + ("ApiKey" = ["admin"]), ) )] #[deprecated] diff --git a/server/src/handlers/topic_handler.rs b/server/src/handlers/topic_handler.rs index 54245d743f..65e4ba83be 100644 --- a/server/src/handlers/topic_handler.rs +++ b/server/src/handlers/topic_handler.rs @@ -3,7 +3,7 @@ use crate::{ errors::ServiceError, handlers::auth_handler::AdminOnly, operators::{ - message_operator::{create_messages_query, get_topic_messages, get_topic_string}, + message_operator::{create_messages_query, get_topic_messages_query, get_topic_string}, topic_operator::{ create_topic_query, delete_topic_query, get_all_topics_for_owner_id_query, get_topic_query, update_topic_query, @@ -139,7 +139,7 @@ pub async fn clone_topic( create_topic_query(new_topic.clone(), &pool).await?; let mut old_messages = - get_topic_messages(original_topic.id, dataset_org_plan_sub.dataset.id, &pool).await?; + get_topic_messages_query(original_topic.id, dataset_org_plan_sub.dataset.id, &pool).await?; old_messages.iter_mut().for_each(|message| { message.topic_id = new_topic.id; diff --git a/server/src/lib.rs b/server/src/lib.rs index a176d294e1..b0692c0880 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -143,7 +143,7 @@ impl Modify for SecurityAddon { name = "BSL", url = "https://github.com/devflowinc/trieve/blob/main/LICENSE.txt", ), - version = "0.12.0", + version = "0.12.1", ), servers( (url = "https://api.trieve.ai", @@ -167,6 +167,7 @@ impl Modify for SecurityAddon { handlers::topic_handler::clone_topic, handlers::topic_handler::get_all_topics_for_owner_id, handlers::message_handler::create_message, + handlers::message_handler::get_message_by_id, handlers::message_handler::get_all_topic_messages, handlers::message_handler::edit_message, handlers::message_handler::regenerate_message, @@ -921,6 +922,10 @@ pub fn main() -> std::io::Result<()> { .to(handlers::message_handler::regenerate_message), ), ) + .service( + web::resource("/message/{message_id}") + .route(web::get().to(handlers::message_handler::get_message_by_id)) + ) .service( web::resource("/messages/{messages_topic_id}").route( web::get().to(handlers::message_handler::get_all_topic_messages), diff --git a/server/src/operators/message_operator.rs b/server/src/operators/message_operator.rs index be3c883b3d..788abed014 100644 --- a/server/src/operators/message_operator.rs +++ b/server/src/operators/message_operator.rs @@ -44,7 +44,7 @@ pub struct ChatCompletionDTO { } #[tracing::instrument(skip(pool))] -pub async fn get_topic_messages( +pub async fn get_topic_messages_query( messages_topic_id: uuid::Uuid, given_dataset_id: uuid::Uuid, pool: &web::Data, @@ -69,6 +69,32 @@ pub async fn get_topic_messages( Ok(topic_messages) } +#[tracing::instrument(skip(pool))] +pub async fn get_message_by_id_query( + message_id: uuid::Uuid, + given_dataset_id: uuid::Uuid, + pool: &web::Data, +) -> Result { + use crate::data::schema::messages::dsl as messages_columns; + + let mut conn = pool.get().await.map_err(|_e| { + ServiceError::InternalServerError("Failed to get postgres connection".to_string()) + })?; + + let message = messages_columns::messages + .filter(messages_columns::id.eq(message_id)) + .filter(messages_columns::dataset_id.eq(given_dataset_id)) + .filter(messages_columns::deleted.eq(false)) + .first::(&mut conn) + .await + .map_err(|db_error| { + log::error!("Error getting message by id {:?}", db_error); + ServiceError::BadRequest("Error getting message by id".to_string()) + })?; + + Ok(message) +} + #[tracing::instrument(skip(pool))] pub async fn create_messages_query( new_messages: Vec,