Skip to content

Commit

Permalink
feat: custom completion prompt (#906)
Browse files Browse the repository at this point in the history
* feat: custom completion prompt

* chore: custom prompt
  • Loading branch information
appflowy authored Oct 20, 2024
1 parent 57c4481 commit 2f715c3
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 33 deletions.
35 changes: 24 additions & 11 deletions libs/appflowy-ai-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::dto::{
AIModel, ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, CreateTextChatContext,
Document, EmbeddingRequest, EmbeddingResponse, LocalAIConfig, MessageData,
CustomPrompt, Document, EmbeddingRequest, EmbeddingResponse, LocalAIConfig, MessageData,
RepeatedLocalAIPackage, RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse,
TranslateRowData, TranslateRowResponse,
};
Expand Down Expand Up @@ -49,19 +49,29 @@ impl AppFlowyAIClient {
Ok(())
}

pub async fn completion_text(
pub async fn completion_text<T: Into<Option<CompletionType>>>(
&self,
text: &str,
completion_type: CompletionType,
completion_type: T,
custom_prompt: Option<CustomPrompt>,
model: AIModel,
) -> Result<CompleteTextResponse, AIError> {
let completion_type = completion_type.into();

if completion_type.is_some() && custom_prompt.is_some() {
return Err(AIError::InvalidRequest(
"Cannot specify both completion_type and custom_prompt".to_string(),
));
}

if text.is_empty() {
return Err(AIError::InvalidRequest("Empty text".to_string()));
}

let params = json!({
"text": text,
"type": completion_type as u8,
"type": completion_type.map(|t| t as u8),
"custom_prompt": custom_prompt,
});

let url = format!("{}/completion", self.url);
Expand All @@ -76,19 +86,22 @@ impl AppFlowyAIClient {
.into_data()
}

pub async fn stream_completion_text(
pub async fn stream_completion_text<T: Into<Option<CompletionType>>>(
&self,
text: &str,
completion_type: CompletionType,
completion_type: T,
custom_prompt: Option<CustomPrompt>,
model: AIModel,
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
let completion_type = completion_type.into();
if text.is_empty() {
return Err(AIError::InvalidRequest("Empty text".to_string()));
}

let params = json!({
"text": text,
"type": completion_type as u8,
"type": completion_type.map(|t| t as u8),
"custom_prompt": custom_prompt,
});

let url = format!("{}/completion/stream", self.url);
Expand Down Expand Up @@ -201,7 +214,7 @@ impl AppFlowyAIClient {
chat_id: &str,
content: &str,
model: &AIModel,
metadata: Option<serde_json::Value>,
metadata: Option<Value>,
) -> Result<ChatAnswer, AIError> {
let json = ChatQuestion {
chat_id: chat_id.to_string(),
Expand All @@ -226,7 +239,7 @@ impl AppFlowyAIClient {
&self,
chat_id: &str,
content: &str,
metadata: Option<serde_json::Value>,
metadata: Option<Value>,
model: &AIModel,
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
let json = ChatQuestion {
Expand All @@ -251,7 +264,7 @@ impl AppFlowyAIClient {
&self,
chat_id: &str,
content: &str,
metadata: Option<serde_json::Value>,
metadata: Option<Value>,
model: &AIModel,
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
let json = ChatQuestion {
Expand Down Expand Up @@ -337,7 +350,7 @@ pub struct AIResponse<T> {

impl<T> AIResponse<T>
where
T: serde::de::DeserializeOwned + 'static,
T: DeserializeOwned + 'static,
{
pub async fn from_response(resp: reqwest::Response) -> Result<Self, anyhow::Error> {
let status_code = resp.status();
Expand Down
13 changes: 10 additions & 3 deletions libs/appflowy-ai-client/src/dto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ impl Display for EmbeddingsModel {
pub enum AIModel {
#[default]
DefaultModel = 0,
GPT35 = 1,
GPT4oMini = 1,
GPT4o = 2,
Claude3Sonnet = 3,
Claude3Opus = 4,
Expand All @@ -215,7 +215,7 @@ impl AIModel {
pub fn to_str(&self) -> &str {
match self {
AIModel::DefaultModel => "default-model",
AIModel::GPT35 => "gpt-3.5-turbo",
AIModel::GPT4oMini => "gpt-4o-mini",
AIModel::GPT4o => "gpt-4o",
AIModel::Claude3Sonnet => "claude-3-sonnet",
AIModel::Claude3Opus => "claude-3-opus",
Expand All @@ -228,7 +228,8 @@ impl FromStr for AIModel {

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"gpt-3.5-turbo" => Ok(AIModel::GPT35),
"gpt-3.5-turbo" => Ok(AIModel::GPT4oMini),
"gpt-4o-mini" => Ok(AIModel::GPT4oMini),
"gpt-4o" => Ok(AIModel::GPT4o),
"claude-3-sonnet" => Ok(AIModel::Claude3Sonnet),
"claude-3-opus" => Ok(AIModel::Claude3Opus),
Expand Down Expand Up @@ -364,3 +365,9 @@ impl Display for CreateTextChatContext {
))
}
}

/// A caller-supplied prompt pair that can be sent with a completion request
/// instead of a predefined `CompletionType` (the two are mutually exclusive —
/// see `AppFlowyAIClient::completion_text`).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CustomPrompt {
/// System prompt text; required when a custom prompt is used.
pub system: String,
/// Optional user prompt text to accompany the system prompt.
pub user: Option<String>,
}
7 changes: 5 additions & 2 deletions libs/appflowy-ai-client/tests/chat_test/completion_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ async fn continue_writing_test() {
.completion_text(
"I feel hungry",
CompletionType::ContinueWriting,
None,
AIModel::Claude3Sonnet,
)
.await
Expand All @@ -23,7 +24,8 @@ async fn improve_writing_test() {
.completion_text(
"I fell tired because i sleep not very well last night",
CompletionType::ImproveWriting,
AIModel::GPT35,
None,
AIModel::GPT4oMini,
)
.await
.unwrap();
Expand All @@ -39,7 +41,8 @@ async fn make_text_shorter_text() {
.stream_completion_text(
"I have an immense passion and deep-seated affection for Rust, a modern, multi-paradigm, high-performance programming language that I find incredibly satisfying to use due to its focus on safety, speed, and concurrency",
CompletionType::MakeShorter,
AIModel::GPT35
None,
AIModel::GPT4oMini
)
.await
.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion libs/appflowy-ai-client/tests/chat_test/context_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async fn create_chat_context_test() {
};
client.create_chat_text_context(context).await.unwrap();
let resp = client
.send_question(&chat_id, "Where I live?", &AIModel::GPT35, None)
.send_question(&chat_id, "Where I live?", &AIModel::GPT4oMini, None)
.await
.unwrap();
// response will be something like:
Expand Down
8 changes: 4 additions & 4 deletions libs/appflowy-ai-client/tests/chat_test/qa_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ async fn qa_test() {
client.health_check().await.unwrap();
let chat_id = uuid::Uuid::new_v4().to_string();
let resp = client
.send_question(&chat_id, "I feel hungry", &AIModel::GPT35, None)
.send_question(&chat_id, "I feel hungry", &AIModel::GPT4o, None)
.await
.unwrap();
assert!(!resp.content.is_empty());

let questions = client
.get_related_question(&chat_id, &1, &AIModel::GPT35)
.get_related_question(&chat_id, &1, &AIModel::GPT4oMini)
.await
.unwrap()
.items;
Expand All @@ -29,7 +29,7 @@ async fn stop_stream_test() {
client.health_check().await.unwrap();
let chat_id = uuid::Uuid::new_v4().to_string();
let mut stream = client
.stream_question(&chat_id, "I feel hungry", None, &AIModel::GPT35)
.stream_question(&chat_id, "I feel hungry", None, &AIModel::GPT4oMini)
.await
.unwrap();

Expand All @@ -51,7 +51,7 @@ async fn stream_test() {
client.health_check().await.unwrap();
let chat_id = uuid::Uuid::new_v4().to_string();
let stream = client
.stream_question_v2(&chat_id, "I feel hungry", None, &AIModel::GPT35)
.stream_question_v2(&chat_id, "I feel hungry", None, &AIModel::GPT4oMini)
.await
.unwrap();
let json_stream = JsonStream::<serde_json::Value>::new(stream);
Expand Down
2 changes: 1 addition & 1 deletion libs/appflowy-ai-client/tests/row_test/summarize_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ async fn summarize_row_test() {
let json = json!({"name": "Jack", "age": 25, "city": "New York"});

let result = client
.summarize_row(json.as_object().unwrap(), AIModel::GPT35)
.summarize_row(json.as_object().unwrap(), AIModel::GPT4oMini)
.await
.unwrap();
result.text.contains("Jack");
Expand Down
5 changes: 4 additions & 1 deletion libs/appflowy-ai-client/tests/row_test/translate_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ async fn translate_row_test() {
include_header: false,
};

let result = client.translate_row(data, AIModel::GPT35).await.unwrap();
let result = client
.translate_row(data, AIModel::GPT4oMini)
.await
.unwrap();
assert_eq!(result.items.len(), 2);
}
2 changes: 1 addition & 1 deletion libs/client-api/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl Client {
);
}

let ai_model = Arc::new(RwLock::new(AIModel::GPT35));
let ai_model = Arc::new(RwLock::new(AIModel::GPT4oMini));

Self {
base_url: base_url.to_string(),
Expand Down
13 changes: 12 additions & 1 deletion libs/shared-entity/src/dto/ai_dto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,18 @@ pub struct SummarizeRowResponse {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompleteTextParams {
pub text: String,
pub completion_type: CompletionType,
pub completion_type: Option<CompletionType>,
pub custom_prompt: Option<CustomPrompt>,
}

impl CompleteTextParams {
    /// Convenience constructor for a completion request driven by a
    /// predefined [`CompletionType`].
    ///
    /// `custom_prompt` is left unset, since a request may carry either a
    /// completion type or a custom prompt, not both.
    pub fn new_with_completion_type(text: String, completion_type: CompletionType) -> Self {
        Self {
            custom_prompt: None,
            completion_type: Some(completion_type),
            text,
        }
    }
}

#[derive(Debug)]
Expand Down
9 changes: 7 additions & 2 deletions src/api/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async fn complete_text_handler(
let params = payload.into_inner();
let resp = state
.ai_client
.completion_text(&params.text, params.completion_type, ai_model)
.completion_text(&params.text, params.completion_type, None, ai_model)
.await
.map_err(|err| AppError::Internal(err.into()))?;
Ok(AppResponse::Ok().with_data(resp).into())
Expand All @@ -51,7 +51,12 @@ async fn stream_complete_text_handler(
let params = payload.into_inner();
match state
.ai_client
.stream_completion_text(&params.text, params.completion_type, ai_model)
.stream_completion_text(
&params.text,
params.completion_type,
params.custom_prompt,
ai_model,
)
.await
{
Ok(stream) => Ok(
Expand Down
2 changes: 1 addition & 1 deletion src/api/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,5 +179,5 @@ pub(crate) fn ai_model_from_header(req: &HttpRequest) -> AIModel {
let header = header.to_str().ok()?;
AIModel::from_str(header).ok()
})
.unwrap_or(AIModel::GPT35)
.unwrap_or(AIModel::GPT4oMini)
}
10 changes: 5 additions & 5 deletions tests/ai_test/complete_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ async fn improve_writing_test() {
return;
}
let test_client = TestClient::new_user().await;
test_client.api_client.set_ai_model(AIModel::GPT4o);
test_client.api_client.set_ai_model(AIModel::GPT4oMini);

let workspace_id = test_client.workspace_id().await;
let params = CompleteTextParams {
text: "I feel hungry".to_string(),
completion_type: CompletionType::ImproveWriting,
};
let params = CompleteTextParams::new_with_completion_type(
"I feel hungry".to_string(),
CompletionType::ImproveWriting,
);

let resp = test_client
.api_client
Expand Down

0 comments on commit 2f715c3

Please sign in to comment.