From b4d3597d8980cba2d3f665459b2aedc7a7b1372b Mon Sep 17 00:00:00 2001 From: xiaoshuyui <528490652@qq.com> Date: Thu, 2 May 2024 09:24:21 +0800 Subject: [PATCH 1/3] feat : seq chain --- lib/src/rust/api/llm_api.dart | 5 + lib/src/rust/frb_generated.dart | 56 +++-- rust/Cargo.lock | 13 +- rust/Cargo.toml | 3 +- rust/src/api/llm_api.rs | 5 + rust/src/frb_generated.rs | 63 ++++-- rust/src/llm/mod.rs | 110 ++++------ rust/src/llm/models.rs | 40 ++++ rust/src/llm/sequential_chain_builder.rs | 24 +++ rust/src/llm/tests.rs | 259 +++++++++++++++++++++++ 10 files changed, 473 insertions(+), 105 deletions(-) create mode 100644 rust/src/llm/models.rs create mode 100644 rust/src/llm/sequential_chain_builder.rs create mode 100644 rust/src/llm/tests.rs diff --git a/lib/src/rust/api/llm_api.dart b/lib/src/rust/api/llm_api.dart index e4b9dea..9df133a 100644 --- a/lib/src/rust/api/llm_api.dart +++ b/lib/src/rust/api/llm_api.dart @@ -24,3 +24,8 @@ Future chat( dynamic hint}) => RustLib.instance.api.chat( uuid: uuid, history: history, stream: stream, query: query, hint: hint); + +Future sequentialChainChat( + {required String jsonStr, required String query, dynamic hint}) => + RustLib.instance.api + .sequentialChainChat(jsonStr: jsonStr, query: query, hint: hint); diff --git a/lib/src/rust/frb_generated.dart b/lib/src/rust/frb_generated.dart index 590d233..e280a09 100644 --- a/lib/src/rust/frb_generated.dart +++ b/lib/src/rust/frb_generated.dart @@ -85,6 +85,9 @@ abstract class RustLibApi extends BaseApi { Stream llmMessageStream({dynamic hint}); + Future sequentialChainChat( + {required String jsonStr, required String query, dynamic hint}); + Future> getProcessPortMappers({dynamic hint}); String greet({required String name, dynamic hint}); @@ -239,13 +242,40 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { ); @override - Future> getProcessPortMappers({dynamic hint}) { + Future sequentialChainChat( + {required String jsonStr, required String query, dynamic hint}) { return handler.executeNormal(NormalTask( callFfi: (port_) { final serializer = SseSerializer(generalizedFrbRustBinding); + sse_encode_String(jsonStr, serializer); + sse_encode_String(query, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 5, port: port_); }, + codec: SseCodec( + decodeSuccessData: sse_decode_unit, + decodeErrorData: null, + ), + constMeta: kSequentialChainChatConstMeta, + argValues: [jsonStr, query], + apiImpl: this, + hint: hint, + )); + } + + TaskConstMeta get kSequentialChainChatConstMeta => const TaskConstMeta( + debugName: "sequential_chain_chat", + argNames: ["jsonStr", "query"], + ); + + @override + Future> getProcessPortMappers({dynamic hint}) { + return handler.executeNormal(NormalTask( + callFfi: (port_) { + final serializer = SseSerializer(generalizedFrbRustBinding); + pdeCallFfi(generalizedFrbRustBinding, serializer, + funcId: 6, port: port_); + }, codec: SseCodec( decodeSuccessData: sse_decode_list_process_port_mapper, decodeErrorData: null, @@ -268,7 +298,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { callFfi: () { final serializer = SseSerializer(generalizedFrbRustBinding); sse_encode_String(name, serializer); - return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 6)!; + return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 7)!; }, codec: SseCodec( decodeSuccessData: sse_decode_String, @@ -292,7 +322,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { callFfi: (port_) { final serializer = SseSerializer(generalizedFrbRustBinding); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 7, port: port_); + funcId: 8, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -319,7 +349,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_i_64(id, serializer); sse_encode_String(name, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 12, port: port_); + funcId: 13, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -343,7 +373,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { callFfi: (port_) { final serializer = SseSerializer(generalizedFrbRustBinding); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 9, port: port_); + funcId: 10, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_list_software, @@ -368,7 +398,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { final serializer = SseSerializer(generalizedFrbRustBinding); sse_encode_list_record_i_64_string(items, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 8, port: port_); + funcId: 9, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -393,7 +423,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { final serializer = SseSerializer(generalizedFrbRustBinding); sse_encode_i_64(id, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 13, port: port_); + funcId: 14, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -418,7 +448,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { callFfi: () { final serializer = SseSerializer(generalizedFrbRustBinding); sse_encode_StreamSink_list_prim_i_64_strict_Sse(s, serializer); - return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 10)!; + return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 11)!; }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -447,7 +477,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { final serializer = SseSerializer(generalizedFrbRustBinding); sse_encode_StreamSink_record_list_prim_i_64_strict_string_Sse( s, serializer); - return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 11)!; + return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 12)!; }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -473,7 +503,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { callFfi: (port_) { final serializer = SseSerializer(generalizedFrbRustBinding); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 14, port: port_); + funcId: 15, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -500,7 +530,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { sse_encode_list_Auto_Owned_RustOpaque_flutter_rust_bridgefor_generatedrust_asyncRwLockrust_simple_notify_libPinWindowItem( data, serializer); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 15, port: port_); + funcId: 16, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -524,7 +554,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { callFfi: (port_) { final serializer = SseSerializer(generalizedFrbRustBinding); pdeCallFfi(generalizedFrbRustBinding, serializer, - funcId: 17, port: port_); + funcId: 18, port: port_); }, codec: SseCodec( decodeSuccessData: sse_decode_unit, @@ -549,7 +579,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { callFfi: () { final serializer = SseSerializer(generalizedFrbRustBinding); sse_encode_StreamSink_monitor_info_Sse(s, serializer); - return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 16)!; + return pdeCallFfi(generalizedFrbRustBinding, serializer, funcId: 17)!; }, codec: SseCodec( decodeSuccessData: sse_decode_unit, diff --git a/rust/Cargo.lock b/rust/Cargo.lock index c014e2a..efd23c2 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1960,9 +1960,9 @@ checksum = "3a68a4904193147e0a8dec3314640e6db742afd5f6e634f428a6af230d9b3591" [[package]] name = "either" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" +checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" [[package]] name = "encode_unicode" @@ -3861,9 +3861,9 @@ dependencies = [ [[package]] name = "langchain-rust" -version = "4.0.3" +version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db5d3146854c3b9df2913ee66631ebaffea94ece4b6e3cb2c1fa473b1e89a3cb" +checksum = "44372931fbc7d664f3338bcabd28916e20aba52c2918b0aec5a65168741ee42a" dependencies = [ "async-openai", "async-recursion", @@ -5755,6 +5755,7 @@ dependencies = [ "langchain-rust", "once_cell", "rust_simple_notify_lib", + "serde", "serde_json", "sysinfo", "systemicons", @@ -6743,9 +6744,9 @@ checksum = "f18aa187839b2bdb1ad2fa35ead8c4c2976b64e4363c386d45ac0f7ee85c9233" [[package]] name = "text-splitter" -version = "0.8.1" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0133ee1a78dc257c52bc118583e57a0ce7646d54cfd28d6c6d039c333c197445" +checksum = "691e4c33fe08c9637366b4f6ba6217157703f784f1c0a670aa76ac2f8a15733f" dependencies = [ "ahash", "auto_enums", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 35a9499..8def54e 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -12,9 +12,10 @@ cron-job = "=0.1.4" dotenv = "0.15.0" flutter_rust_bridge = "=2.0.0-dev.31" futures = "0.3.30" -langchain-rust = "4.0.3" +langchain-rust = "4.1.0" once_cell = "1.19.0" rust_simple_notify_lib = { path = "../rust_simple_notify_lib" } +serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.116" sysinfo = "0.30.8" tokio = { version = "1.37.0", features = ["full"] } diff --git a/rust/src/api/llm_api.rs b/rust/src/api/llm_api.rs index aa0aaad..c703e98 100644 --- a/rust/src/api/llm_api.rs +++ b/rust/src/api/llm_api.rs @@ -27,3 +27,8 @@ pub fn chat(_uuid: Option, _history: Option>, stream: bo let rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(async { crate::llm::chat(_uuid, _history, stream, query).await }); } + +pub fn sequential_chain_chat(json_str: String, query: String) { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { crate::llm::sequential_chain_chat(json_str, query).await }); +} diff --git a/rust/src/frb_generated.rs b/rust/src/frb_generated.rs index fa686d5..2dec015 100644 --- a/rust/src/frb_generated.rs +++ b/rust/src/frb_generated.rs @@ -166,6 +166,42 @@ fn wire_llm_message_stream_impl( }, ) } +fn wire_sequential_chain_chat_impl( + port_: flutter_rust_bridge::for_generated::MessagePort, + ptr_: flutter_rust_bridge::for_generated::PlatformGeneralizedUint8ListPtr, + rust_vec_len_: i32, + data_len_: i32, +) { + FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::( + flutter_rust_bridge::for_generated::TaskInfo { + debug_name: "sequential_chain_chat", + port: Some(port_), + mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal, + }, + move || { + let message = unsafe { + flutter_rust_bridge::for_generated::Dart2RustMessageSse::from_wire( + ptr_, + rust_vec_len_, + data_len_, + ) + }; + let mut deserializer = + flutter_rust_bridge::for_generated::SseDeserializer::new(message); + let api_json_str = ::sse_decode(&mut deserializer); + let api_query = ::sse_decode(&mut deserializer); + deserializer.end(); + move |context| { + transform_result_sse((move || { + Result::<_, ()>::Ok(crate::api::llm_api::sequential_chain_chat( + api_json_str, + api_query, + )) + })()) + } + }, + ) +} fn wire_get_process_port_mappers_impl( port_: flutter_rust_bridge::for_generated::MessagePort, ptr_: flutter_rust_bridge::for_generated::PlatformGeneralizedUint8ListPtr, @@ -1143,15 +1179,16 @@ fn pde_ffi_dispatcher_primary_impl( // Codec=Pde (Serialization + dispatch), see doc to use other codecs match func_id { 4 => wire_chat_impl(port, ptr, rust_vec_len, data_len), - 5 => wire_get_process_port_mappers_impl(port, ptr, rust_vec_len, data_len), - 7 => wire_init_app_impl(port, ptr, rust_vec_len, data_len), - 12 => wire_add_to_watching_list_impl(port, ptr, rust_vec_len, data_len), - 9 => wire_get_installed_softwares_impl(port, ptr, rust_vec_len, data_len), - 8 => wire_init_monitor_impl(port, ptr, rust_vec_len, data_len), - 13 => wire_remove_from_watching_list_impl(port, ptr, rust_vec_len, data_len), - 14 => wire_create_event_loop_impl(port, ptr, rust_vec_len, data_len), - 15 => wire_show_todos_impl(port, ptr, rust_vec_len, data_len), - 17 => wire_start_system_monitor_impl(port, ptr, rust_vec_len, data_len), + 5 => wire_sequential_chain_chat_impl(port, ptr, rust_vec_len, data_len), + 6 => wire_get_process_port_mappers_impl(port, ptr, rust_vec_len, data_len), + 8 => wire_init_app_impl(port, ptr, rust_vec_len, data_len), + 13 => wire_add_to_watching_list_impl(port, ptr, rust_vec_len, data_len), + 10 => wire_get_installed_softwares_impl(port, ptr, rust_vec_len, data_len), + 9 => wire_init_monitor_impl(port, ptr, rust_vec_len, data_len), + 14 => wire_remove_from_watching_list_impl(port, ptr, rust_vec_len, data_len), + 15 => wire_create_event_loop_impl(port, ptr, rust_vec_len, data_len), + 16 => wire_show_todos_impl(port, ptr, rust_vec_len, data_len), + 18 => wire_start_system_monitor_impl(port, ptr, rust_vec_len, data_len), _ => unreachable!(), } } @@ -1167,12 +1204,12 @@ fn pde_ffi_dispatcher_sync_impl( 2 => wire_get_llm_config_impl(ptr, rust_vec_len, data_len), 1 => wire_init_llm_impl(ptr, rust_vec_len, data_len), 3 => wire_llm_message_stream_impl(ptr, rust_vec_len, data_len), - 6 => wire_greet_impl(ptr, rust_vec_len, data_len), - 10 => wire_software_watching_message_stream_impl(ptr, rust_vec_len, data_len), - 11 => { + 7 => wire_greet_impl(ptr, rust_vec_len, data_len), + 11 => wire_software_watching_message_stream_impl(ptr, rust_vec_len, data_len), + 12 => { wire_software_watching_with_foreground_message_stream_impl(ptr, rust_vec_len, data_len) } - 16 => wire_system_monitor_message_stream_impl(ptr, rust_vec_len, data_len), + 17 => wire_system_monitor_message_stream_impl(ptr, rust_vec_len, data_len), _ => unreachable!(), } } diff --git a/rust/src/llm/mod.rs b/rust/src/llm/mod.rs index 1571aea..322a127 100644 --- a/rust/src/llm/mod.rs +++ b/rust/src/llm/mod.rs @@ -1,6 +1,12 @@ +pub mod models; +pub mod sequential_chain_builder; +mod tests; + use dotenv::dotenv; use futures::StreamExt; +use langchain_rust::chain::Chain; use langchain_rust::language_models::llm::LLM; +use langchain_rust::prompt_args; use langchain_rust::{ llm::{OpenAI, OpenAIConfig}, schemas::Message, @@ -10,78 +16,8 @@ use std::sync::RwLock; use crate::frb_generated::StreamSink; -#[allow(unused_imports)] -mod tests { - use dotenv::dotenv; - use futures::StreamExt; - use langchain_rust::{ - chain::{Chain, LLMChainBuilder}, - fmt_message, fmt_template, - language_models::llm::LLM, - llm::{OpenAI, OpenAIConfig}, - message_formatter, - prompt::HumanMessagePromptTemplate, - prompt_args, - schemas::Message, - template_fstring, - }; - - #[test] - fn uuid_test() { - let u = uuid::Uuid::new_v4().to_string(); - println!("uuid : {:?}", u); - } - - #[tokio::test] - async fn test() { - dotenv().ok(); - - let base = std::env::var("LLM_BASE").unwrap(); - println!("base {:?}", base); - let name = std::env::var("LLM_MODEL_NAME").unwrap(); - println!("name {:?}", name); - let sk = std::env::var("LLM_SK").unwrap_or("".to_owned()); - println!("sk {:?}", sk); - - let open_ai = OpenAI::default() - .with_config(OpenAIConfig::new().with_api_base(base).with_api_key(sk)) - .with_model(name); - - let response = open_ai.invoke("how can langsmith help with testing?").await; - println!("{:?}", response); - } - - #[tokio::test] - async fn test_stream() { - dotenv().ok(); - - let base = std::env::var("LLM_BASE").unwrap(); - println!("base {:?}", base); - let name = std::env::var("LLM_MODEL_NAME").unwrap(); - println!("name {:?}", name); - let sk = std::env::var("LLM_SK").unwrap_or("".to_owned()); - println!("sk {:?}", sk); - - let open_ai = OpenAI::default() - .with_config(OpenAIConfig::new().with_api_base(base).with_api_key(sk)) - .with_model(name); - - let mut stream = open_ai - .stream(&[ - Message::new_human_message("你是一个私有化AI助理。"), - Message::new_human_message("请问如何使用rust实现链表。"), - ]) - .await - .unwrap(); - - while let Some(result) = stream.next().await { - match result { - Ok(value) => value.to_stdout().unwrap(), - Err(e) => panic!("Error invoking LLMChain: {:?}", e), - } - } - } -} +use self::models::ChainIOes; +use self::sequential_chain_builder::CustomSequentialChain; pub static ENV_PARAMS: Lazy>> = Lazy::new(|| RwLock::new(None)); @@ -281,3 +217,33 @@ pub async fn chat( } } } + +pub async fn sequential_chain_chat(json_str: String, query: String) -> anyhow::Result<()> { + let open_ai; + { + open_ai = OPENAI.read().unwrap(); + } + let items: ChainIOes = serde_json::from_str(&json_str)?; + let first_input = items.items.first().unwrap().input_key.clone(); + let seq = CustomSequentialChain { + chains: items.items, + }; + + let seq_chain = seq.build(open_ai.clone()); + match seq_chain { + Some(_seq) => { + let output = _seq + .execute(prompt_args! { + first_input => query + }) + .await + .unwrap(); + println!("output: {:?}", output); + } + None => { + println!("none"); + } + } + + anyhow::Ok(()) +} diff --git a/rust/src/llm/models.rs b/rust/src/llm/models.rs new file mode 100644 index 0000000..19f9514 --- /dev/null +++ b/rust/src/llm/models.rs @@ -0,0 +1,40 @@ +use langchain_rust::{ + chain::{LLMChain, LLMChainBuilder}, + llm::{Config, OpenAI}, + prompt::HumanMessagePromptTemplate, + template_jinja2, +}; +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct ChainIO { + pub input_key: String, + pub output_key: String, + pub prompt: String, +} + +#[derive(Deserialize)] +pub struct ChainIOes { + pub items: Vec, +} + +impl ChainIO { + fn to_prompt(&self) -> HumanMessagePromptTemplate { + let f = self.prompt.replace("placeholder", &self.input_key); + + HumanMessagePromptTemplate::new(template_jinja2!(f, self.input_key)) + } + + pub fn to_chain( + &self, + llm: OpenAI, + ) -> anyhow::Result { + anyhow::Ok( + LLMChainBuilder::new() + .prompt(self.to_prompt()) + .llm(llm.clone()) + .output_key(self.output_key.clone()) + .build()?, + ) + } +} diff --git a/rust/src/llm/sequential_chain_builder.rs b/rust/src/llm/sequential_chain_builder.rs new file mode 100644 index 0000000..da53cd2 --- /dev/null +++ b/rust/src/llm/sequential_chain_builder.rs @@ -0,0 +1,24 @@ +use langchain_rust::{chain::{SequentialChain, SequentialChainBuilder}, llm::{Config, OpenAI}}; + +use super::models::ChainIO; + +pub struct CustomSequentialChain { + pub chains: Vec, +} + +impl CustomSequentialChain { + pub fn build(&self, llm: OpenAI) -> Option { + let mut seq_chain = SequentialChainBuilder::new(); + + for i in &self.chains{ + let _c = i.to_chain(llm.clone()); + if let Ok(_c) = _c { + seq_chain = seq_chain.add_chain(_c); + }else{ + return None; + } + } + + Some(seq_chain.build()) + } +} diff --git a/rust/src/llm/tests.rs b/rust/src/llm/tests.rs new file mode 100644 index 0000000..77baa60 --- /dev/null +++ b/rust/src/llm/tests.rs @@ -0,0 +1,259 @@ +#[allow(unused_imports)] +mod tests { + use std::io::Write; + + use dotenv::dotenv; + use futures::StreamExt; + use langchain_rust::{ + chain::{builder::ConversationalChainBuilder, Chain, LLMChainBuilder}, + fmt_message, fmt_template, + language_models::llm::LLM, + llm::{OpenAI, OpenAIConfig}, + memory::SimpleMemory, + message_formatter, + prompt::HumanMessagePromptTemplate, + prompt_args, + schemas::Message, + sequential_chain, template_fstring, template_jinja2, + }; + + use crate::llm::{models::ChainIOes, sequential_chain_builder::CustomSequentialChain}; + + #[test] + fn uuid_test() { + let u = uuid::Uuid::new_v4().to_string(); + println!("uuid : {:?}", u); + } + + #[tokio::test] + async fn test() { + dotenv().ok(); + + let base = std::env::var("LLM_BASE").unwrap(); + println!("base {:?}", base); + let name = std::env::var("LLM_MODEL_NAME").unwrap(); + println!("name {:?}", name); + let sk = std::env::var("LLM_SK").unwrap_or("".to_owned()); + println!("sk {:?}", sk); + + let open_ai = OpenAI::default() + .with_config(OpenAIConfig::new().with_api_base(base).with_api_key(sk)) + .with_model(name); + + let response = open_ai.invoke("how can langsmith help with testing?").await; + println!("{:?}", response); + } + + #[tokio::test] + async fn test_stream() { + dotenv().ok(); + + let base = std::env::var("LLM_BASE").unwrap(); + println!("base {:?}", base); + let name = std::env::var("LLM_MODEL_NAME").unwrap(); + println!("name {:?}", name); + let sk = std::env::var("LLM_SK").unwrap_or("".to_owned()); + println!("sk {:?}", sk); + + let open_ai = OpenAI::default() + .with_config(OpenAIConfig::new().with_api_base(base).with_api_key(sk)) + .with_model(name); + + let mut stream = open_ai + .stream(&[ + Message::new_human_message("你是一个私有化AI助理。"), + Message::new_human_message("请问如何使用rust实现链表。"), + ]) + .await + .unwrap(); + + while let Some(result) = stream.next().await { + match result { + Ok(value) => value.to_stdout().unwrap(), + Err(e) => panic!("Error invoking LLMChain: {:?}", e), + } + } + } + + #[tokio::test] + async fn test_chain() { + dotenv().ok(); + + let base = std::env::var("LLM_BASE").unwrap(); + println!("base {:?}", base); + let name = std::env::var("LLM_MODEL_NAME").unwrap(); + println!("name {:?}", name); + let sk = std::env::var("LLM_SK").unwrap_or("".to_owned()); + println!("sk {:?}", sk); + + let open_ai = OpenAI::default() + .with_config(OpenAIConfig::new().with_api_base(base).with_api_key(sk)) + .with_model(name); + + let memory = SimpleMemory::new(); + let chain = ConversationalChainBuilder::new() + .llm(open_ai) + //IF YOU WANT TO ADD A CUSTOM PROMPT YOU CAN UN COMMENT THIS: + // .prompt(message_formatter![ + // fmt_message!(Message::new_system_message("You are a helpful assistant")), + // fmt_template!(HumanMessagePromptTemplate::new( + // template_fstring!(" + // The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. + // + // Current conversation: + // {history} + // Human: {input} + // AI: + // ", + // "input","history"))) + // + // ]) + .memory(memory.into()) + .build() + .expect("Error building ConversationalChain"); + + let input_variables = prompt_args! { + "input" => "我来自江苏常州。", + }; + + let mut stream = chain.stream(input_variables).await.unwrap(); + while let Some(result) = stream.next().await { + match result { + Ok(data) => { + //If you junt want to print to stdout, you can use data.to_stdout().unwrap(); + print!("{}", data.content); + std::io::stdout().flush().unwrap(); + } + Err(e) => { + println!("Error: {:?}", e); + } + } + } + + let input_variables = prompt_args! { + "input" => "常州有什么特产或者好玩的地方?", + }; + match chain.invoke(input_variables).await { + Ok(result) => { + println!("\n"); + println!("Result: {:?}", result); + } + Err(e) => panic!("Error invoking LLMChain: {:?}", e), + } + } + + #[tokio::test] + async fn sequential_chain_test() { + dotenv().ok(); + + let base = std::env::var("LLM_BASE").unwrap(); + println!("base {:?}", base); + let name = std::env::var("LLM_MODEL_NAME").unwrap(); + println!("name {:?}", name); + let sk = std::env::var("LLM_SK").unwrap_or("".to_owned()); + println!("sk {:?}", sk); + + let llm = OpenAI::default() + .with_config(OpenAIConfig::new().with_api_base(base).with_api_key(sk)) + .with_model(name); + let prompt = HumanMessagePromptTemplate::new(template_jinja2!( + "给我一个卖东西的商店起个有创意的名字: {{producto}}", + "producto" + )); + + let get_name_chain = LLMChainBuilder::new() + .prompt(prompt) + .llm(llm.clone()) + .output_key("name") + .build() + .unwrap(); + + let prompt = HumanMessagePromptTemplate::new(template_jinja2!( + "给我一个下一个名字的口号: {{name}}", + "name" + )); + let get_slogan_chain = LLMChainBuilder::new() + .prompt(prompt) + .llm(llm.clone()) + .output_key("slogan") + .build() + .unwrap(); + + let sequential_chain = sequential_chain!(get_name_chain, get_slogan_chain); + + print!("Please enter a product: "); + std::io::stdout().flush().unwrap(); // Display prompt to terminal + + let mut product = String::new(); + std::io::stdin().read_line(&mut product).unwrap(); // Get product from terminal input + + let product = product.trim(); + let output = sequential_chain + .execute(prompt_args! { + "producto" => product + }) + .await + .unwrap(); + + println!("Name: {}", output["name"]); + println!("Slogan: {}", output["slogan"]); + } + + #[tokio::test] + async fn json_test() -> anyhow::Result<()> { + let json_data = r#" + { + "items":[ + { + "input_key":"producto", + "output_key":"name", + "prompt":"给我一个卖东西的商店起个有创意的名字: {{placeholder}}" + }, + { + "input_key":"name", + "output_key":"slogan", + "prompt":"给我一个下一个名字的口号: {{placeholder}}" + } + ] + } + "#; + + dotenv().ok(); + + let base = std::env::var("LLM_BASE").unwrap(); + println!("base {:?}", base); + let name = std::env::var("LLM_MODEL_NAME").unwrap(); + println!("name {:?}", name); + let sk = std::env::var("LLM_SK").unwrap_or("".to_owned()); + println!("sk {:?}", sk); + + let llm = OpenAI::default() + .with_config(OpenAIConfig::new().with_api_base(base).with_api_key(sk)) + .with_model(name); + + let items: ChainIOes = serde_json::from_str(json_data)?; + let seq = CustomSequentialChain { + chains: items.items, + }; + + let seq_chain = seq.build(llm); + + match seq_chain { + Some(_seq) => { + let output = _seq + .execute(prompt_args! { + "producto" => "shoe" + }) + .await + .unwrap(); + println!("Name: {}", output["name"]); + println!("Slogan: {}", output["slogan"]); + } + None => { + println!("none"); + } + } + + anyhow::Ok(()) + } +} From e4e76428dbb83e4664e5c22463ce901b81d67f6b Mon Sep 17 00:00:00 2001 From: xiaoshuyui <528490652@qq.com> Date: Thu, 2 May 2024 17:45:15 +0800 Subject: [PATCH 2/3] feat: seq chain UI --- lib/llm/langchain/components/buttons.dart | 1 + .../components/modify_chain_dialog.dart | 176 ++++++++++++ .../langchain/components/seq_chain_flow.dart | 263 ++++++++++++++++++ lib/llm/langchain/components/tools_item.dart | 1 + .../langchain/components/tools_screen.dart | 128 +++++---- lib/llm/langchain/extension.dart | 29 ++ lib/llm/langchain/langchain_chat_screen.dart | 4 +- lib/llm/langchain/models/chains.dart | 44 +++ .../langchain/notifiers/chain_notifier.dart | 50 ++++ .../langchain/notifiers/tool_notifier.dart | 14 +- pubspec.lock | 24 ++ pubspec.yaml | 3 + rust/src/api/llm_api.rs | 2 +- 13 files changed, 673 insertions(+), 66 deletions(-) create mode 100644 lib/llm/langchain/components/modify_chain_dialog.dart create mode 100644 lib/llm/langchain/components/seq_chain_flow.dart create mode 100644 lib/llm/langchain/extension.dart create mode 100644 lib/llm/langchain/models/chains.dart create mode 100644 lib/llm/langchain/notifiers/chain_notifier.dart diff --git a/lib/llm/langchain/components/buttons.dart b/lib/llm/langchain/components/buttons.dart index 89674f0..f6d9e88 100644 --- a/lib/llm/langchain/components/buttons.dart +++ b/lib/llm/langchain/components/buttons.dart @@ -28,6 +28,7 @@ class Buttons extends ConsumerWidget { GestureDetector( onTap: () { ref.read(toolProvider.notifier).changeState(null); + ref.read(toolProvider.notifier).jumpTo(0); }, child: _wrapper(const Text( "返回", diff --git a/lib/llm/langchain/components/modify_chain_dialog.dart b/lib/llm/langchain/components/modify_chain_dialog.dart new file mode 100644 index 0000000..8ec81e9 --- /dev/null +++ b/lib/llm/langchain/components/modify_chain_dialog.dart @@ -0,0 +1,176 @@ +import 'package:all_in_one/llm/langchain/models/chains.dart'; +import 'package:all_in_one/styles/app_style.dart'; +import 'package:flutter/material.dart'; +import 'package:flutter_flow_chart/flutter_flow_chart.dart'; + +class ModifyChainDialog extends StatefulWidget { + const ModifyChainDialog({super.key, required this.item}); + final FlowElement item; + + @override + State createState() => _ModifyChainDialogState(); +} + +class _ModifyChainDialogState extends State { + late final TextEditingController itemTextController = TextEditingController() + ..text = widget.item.text; + final itemTextFocusNode = FocusNode(); + final TextEditingController inputKeyController = TextEditingController(); + final inputKeyFocusNode = FocusNode(); + final TextEditingController outputKeyController = TextEditingController(); + final outputKeyFocusNode = FocusNode(); + final TextEditingController promptController = TextEditingController() + ..text = "{{placeholder}}"; + final promptFocusNode = FocusNode(); + + final _formKey = GlobalKey(); + + @override + Widget build(BuildContext context) { + return Material( + borderRadius: BorderRadius.circular(4), + elevation: 10, + child: Container( + padding: const EdgeInsets.all(20), + width: 400, + height: 320, + child: Form( + key: _formKey, + child: Column( + children: [ + _wrapper( + "name", + SizedBox( + height: 30, + child: TextFormField( + validator: (value) { + if (value == null || value == "") { + return ""; + } + return null; + }, + controller: itemTextController, + style: + const TextStyle(color: Colors.black, fontSize: 12), + decoration: AppStyle.inputDecoration, + autofocus: true, + onFieldSubmitted: (value) { + inputKeyFocusNode.requestFocus(); + }, + ), + )), + const SizedBox( + height: 20, + ), + _wrapper( + "input", + SizedBox( + height: 30, + child: TextFormField( + validator: (value) { + if (value == null || value == "") { + return ""; + } + return null; + }, + controller: inputKeyController, + style: const TextStyle( + color: Colors.black, fontSize: 12), + decoration: AppStyle.inputDecoration, + autofocus: false, + onFieldSubmitted: (value) { + outputKeyFocusNode.requestFocus(); + }, + ))), + const SizedBox( + height: 20, + ), + _wrapper( + "output", + SizedBox( + height: 30, + child: TextFormField( + validator: (value) { + if (value == null || value == "") { + return ""; + } + return null; + }, + controller: outputKeyController, + style: const TextStyle( + color: Colors.black, fontSize: 12), + decoration: AppStyle.inputDecoration, + autofocus: false, + onFieldSubmitted: (value) { + promptFocusNode.requestFocus(); + }, + ))), + const SizedBox( + height: 20, + ), + _wrapper( + "prompt", + TextFormField( + validator: (value) { + if (value == null || value == "") { + return ""; + } + if (value.contains("{{placeholder}}")) { + return null; + } + + return ""; + }, + maxLines: 3, + controller: promptController, + style: const TextStyle(color: Colors.black, fontSize: 12), + decoration: AppStyle.inputDecoration, + autofocus: false, + onFieldSubmitted: (value) { + _formKey.currentState!.validate(); + }, + )), + const SizedBox( + height: 20, + ), + Row( + mainAxisAlignment: MainAxisAlignment.end, + children: [ + ElevatedButton( + onPressed: () { + if (_formKey.currentState!.validate()) { + Navigator.of(context).pop(( + itemTextController.text, + ChainItem( + inputKey: inputKeyController.text, + outputKey: outputKeyController.text, + prompt: promptController.text) + )); + } + }, + child: const Text("确定")) + ], + ) + ], + )), + ), + ); + } + + _wrapper(String title, Widget child) { + return Row( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + SizedBox( + width: 100, + child: Text(title), + ), + Expanded( + child: Align( + alignment: Alignment.centerLeft, + child: child, + )) + ], + ); + } +} diff --git a/lib/llm/langchain/components/seq_chain_flow.dart b/lib/llm/langchain/components/seq_chain_flow.dart new file mode 100644 index 0000000..1f0bffc --- /dev/null +++ b/lib/llm/langchain/components/seq_chain_flow.dart @@ -0,0 +1,263 @@ +import 'dart:convert'; + +import 'package:all_in_one/llm/langchain/extension.dart'; +import 'package:all_in_one/llm/langchain/models/chains.dart'; +import 'package:all_in_one/llm/langchain/notifiers/chain_notifier.dart'; +import 'package:all_in_one/llm/langchain/notifiers/tool_notifier.dart'; +import 'package:all_in_one/src/rust/api/llm_api.dart'; +import 'package:flutter/material.dart'; +import 'package:flutter_expandable_fab/flutter_expandable_fab.dart'; +import 'package:flutter_flow_chart/flutter_flow_chart.dart'; +import 'package:flutter_riverpod/flutter_riverpod.dart'; +import 'package:star_menu/star_menu.dart'; + +import 'modify_chain_dialog.dart'; + +class FlowScreen extends ConsumerStatefulWidget { + const FlowScreen({super.key}); + + @override + ConsumerState createState() => _FlowScreenState(); +} + +class _FlowScreenState extends ConsumerState { + Dashboard dashboard = Dashboard(); + + @override + Widget build(BuildContext context) { + WidgetsBinding.instance.addPostFrameCallback((timeStamp) { + if (dashboard.elements.isEmpty) { + dashboard.addElement(FlowElement( + position: const Offset(100, 100), + size: const Size(100, 50), + text: '开始', + kind: ElementKind.oval, + handlers: [])); + } + }); + + return Scaffold( + body: Padding( + padding: const EdgeInsets.all(5.0), + + // child: Container(color: Colors.amber), + + child: Stack( + children: [ + Container( + constraints: const BoxConstraints.expand(), + child: FlowChart( + dashboard: dashboard, + onDashboardTapped: ((context, position) { + debugPrint('Dashboard tapped $position'); + _displayDashboardMenu(context, position); + }), + onDashboardSecondaryTapped: (context, position) { + debugPrint('Dashboard right clicked $position'); + _displayDashboardMenu(context, position); + }, + onElementPressed: (context, position, element) { + debugPrint('Element with "${element.text}" text pressed'); + _displayElementMenu(context, position, element); + }, + ), + ), + ], + ), + ), + floatingActionButtonLocation: ExpandableFab.location, + // floatingActionButton: FloatingActionButton( + // onPressed: () => dashboard.format(), + // child: const Icon(Icons.center_focus_strong)), + floatingActionButton: ExpandableFab( + distance: 50, + type: ExpandableFabType.side, + children: [ + FloatingActionButton.small( + tooltip: "format", + heroTag: "center", + onPressed: () { + dashboard.format(); + }, + child: const Icon(Icons.center_focus_strong), + ), + FloatingActionButton.small( + tooltip: "save", + heroTag: null, + child: const Icon(Icons.save), + onPressed: () { + final r = ref.read(chainProvider.notifier).validate(); + if (r) { + Chains chains = + Chains(items: ref.read(chainProvider.notifier).items); + final jsonStr = jsonEncode(chains.toJson()); + sequentialChainChat(jsonStr: jsonStr, query: "shoe"); + } + }, + ), + FloatingActionButton.small( + tooltip: "clear all", + heroTag: null, + child: const Icon(Icons.clear), + onPressed: () { + dashboard.removeAllElements(notify: true); + }, + ), + FloatingActionButton.small( + tooltip: "back", + heroTag: null, + child: const Icon(Icons.arrow_back), + onPressed: () { + ref.read(toolProvider.notifier).jumpTo(0); + }, + ), + ], + ), + ); + } + + /// Display a drop down menu when tapping on an element + _displayElementMenu( + BuildContext context, + Offset position, + FlowElement element, + ) { + StarMenuOverlay.displayStarMenu( + context, + StarMenu( + params: StarMenuParameters( + shape: MenuShape.linear, + openDurationMs: 60, + linearShapeParams: const LinearShapeParams( + angle: 270, + alignment: LinearAlignment.left, + space: 10, + ), + onHoverScale: 1.1, + centerOffset: position - const Offset(50, 50), + backgroundParams: const BackgroundParams( + backgroundColor: Colors.transparent, + ), + boundaryBackground: BoundaryBackground( + padding: const EdgeInsets.all(16), + decoration: BoxDecoration( + borderRadius: BorderRadius.circular(8), + color: Theme.of(context).cardColor, + boxShadow: kElevationToShadow[6], + ), + ), + ), + onItemTapped: (index, controller) { + if (!(index == 5 || index == 2)) { + controller.closeMenu!(); + } + }, + items: [ + Text( + element.text, + style: const TextStyle(fontWeight: FontWeight.w900), + ), + InkWell( + onTap: () async { + final (String, ChainItem)? item = await showGeneralDialog( + barrierColor: Colors.white.withOpacity(0.1), + barrierDismissible: true, + barrierLabel: "modify", + context: context, + pageBuilder: (c, _, __) { + return Center( + child: ModifyChainDialog( + item: element, + ), + ); + }); + + if (item != null) { + element.setText(item.$1); + ref.read(chainProvider.notifier).updateItem(element, item.$2); + } + }, + child: const Text('Modify'), + ), + InkWell( + onTap: () { + final prev = element.findPrevious(dashboard.elements); + assert(prev != null, "prev cannot be null"); + + final next = dashboard.elements + .where((e) => e.id == element.next.firstOrNull?.destElementId) + .firstOrNull; + + dashboard.removeElement(element, notify: next == null); + + if (next != null && prev != null) { + dashboard.addNextById(prev, next.id, ArrowParams()); + } + + ref.read(chainProvider.notifier).removeItem(element.id); + }, + child: const Text('Delete'), + ), + ], + parentContext: context, + ), + ); + } + + /// Display a linear menu for the dashboard + /// with menu entries built with [menuEntries] + _displayDashboardMenu(BuildContext context, Offset position) { + StarMenuOverlay.displayStarMenu( + context, + StarMenu( + params: StarMenuParameters( + shape: MenuShape.linear, + openDurationMs: 60, + linearShapeParams: const LinearShapeParams( + angle: 270, + alignment: LinearAlignment.left, + space: 10, + ), + // calculate the offset from the dashboard center + centerOffset: position - + Offset( + dashboard.dashboardSize.width / 2, + dashboard.dashboardSize.height / 2, + ), + ), + onItemTapped: (index, controller) => controller.closeMenu!(), + parentContext: context, + items: [ + ActionChip( + label: const Text('添加item'), + onPressed: () { + final flowElement = FlowElement( + position: position + const Offset(100, 25), + size: const Size(200, 50), + text: 'chain item', + kind: ElementKind.rectangle, + handlers: []); + dashboard.addElement(flowElement, notify: false); + // print(xor.id); + dashboard.addNextById( + dashboard.elements[dashboard.elements.length - 2], + flowElement.id, + ArrowParams(), + notify: true); + + ref + .read(chainProvider.notifier) + .addItem(flowElement, ChainItem()); + }), + ActionChip( + label: const Text('清空'), + onPressed: () { + dashboard.removeAllElements(); + + setState(() {}); + }), + ], + ), + ); + } +} diff --git a/lib/llm/langchain/components/tools_item.dart b/lib/llm/langchain/components/tools_item.dart index 0751d85..3f24ac0 100644 --- a/lib/llm/langchain/components/tools_item.dart +++ b/lib/llm/langchain/components/tools_item.dart @@ -20,6 +20,7 @@ class _ToolsItemState extends ConsumerState { ref .read(toolProvider.notifier) .changeState(widget.toolModel.toMessage()); + ref.read(toolProvider.notifier).jumpTo(1); }, child: FittedBox( child: Stack( diff --git a/lib/llm/langchain/components/tools_screen.dart b/lib/llm/langchain/components/tools_screen.dart index 1f64274..9c5f0c9 100644 --- a/lib/llm/langchain/components/tools_screen.dart +++ b/lib/llm/langchain/components/tools_screen.dart @@ -1,20 +1,23 @@ import 'dart:convert'; import 'package:all_in_one/llm/langchain/models/tool_model.dart'; +import 'package:all_in_one/llm/langchain/notifiers/tool_notifier.dart'; import 'package:all_in_one/styles/app_style.dart'; import 'package:flutter/material.dart'; import 'package:flutter/services.dart'; +import 'package:flutter_riverpod/flutter_riverpod.dart'; +import 'package:icons_plus/icons_plus.dart'; import 'tools_item.dart'; -class ToolsScreen extends StatefulWidget { +class ToolsScreen extends ConsumerStatefulWidget { const ToolsScreen({super.key}); @override - State createState() => _ToolsScreenState(); + ConsumerState createState() => _ToolsScreenState(); } -class _ToolsScreenState extends State { +class _ToolsScreenState extends ConsumerState { @override void initState() { super.initState(); @@ -31,65 +34,72 @@ class _ToolsScreenState extends State { @override Widget build(BuildContext context) { - return Padding( - padding: const EdgeInsets.all(10), - child: FutureBuilder( - future: future, - builder: (c, s) { - if (s.connectionState == ConnectionState.done) { - final Map jsonObj = json.decode(textContent); - return Column( - children: [ - Stack( - children: [ - Container( - width: double.infinity, - constraints: const BoxConstraints(maxHeight: 400), - child: Image.asset( - "assets/llm/banner.jpg", - fit: BoxFit.cover, + return Scaffold( + floatingActionButton: FloatingActionButton( + child: const Icon(Bootstrap.chat), + onPressed: () { + ref.read(toolProvider.notifier).jumpTo(2); + }), + body: Padding( + padding: const EdgeInsets.all(10), + child: FutureBuilder( + future: future, + builder: (c, s) { + if (s.connectionState == ConnectionState.done) { + final Map jsonObj = json.decode(textContent); + return Column( + children: [ + Stack( + children: [ + Container( + width: double.infinity, + constraints: const BoxConstraints(maxHeight: 400), + child: Image.asset( + "assets/llm/banner.jpg", + fit: BoxFit.cover, + ), ), - ), - Positioned( - bottom: 40, - right: 20, - child: Transform.rotate( - angle: -3.14 / 10, - child: const Text( - "Let AI help you", - style: TextStyle( - fontFamily: "xing", - fontSize: 40, - color: AppStyle.orange), - ), - )) - ], - ), - const SizedBox( - height: 10, - ), - Expanded( - child: Align( - alignment: Alignment.topLeft, - child: SingleChildScrollView( - child: Wrap( - runSpacing: 15, - spacing: 15, - children: jsonObj.entries - .map((e) => ToolsItem( - toolModel: ToolModel.fromJson(e.value), - )) - .toList(), - ), + Positioned( + bottom: 40, + right: 20, + child: Transform.rotate( + angle: -3.14 / 10, + child: const Text( + "Let AI help you", + style: TextStyle( + fontFamily: "xing", + fontSize: 40, + color: AppStyle.orange), + ), + )) + ], + ), + const SizedBox( + height: 10, ), - )) - ], + Expanded( + child: Align( + alignment: Alignment.topLeft, + child: SingleChildScrollView( + child: Wrap( + runSpacing: 15, + spacing: 15, + children: jsonObj.entries + .map((e) => ToolsItem( + toolModel: ToolModel.fromJson(e.value), + )) + .toList(), + ), + ), + )) + ], + ); + } + return const Center( + child: CircularProgressIndicator(), ); - } - return const Center( - child: CircularProgressIndicator(), - ); - }), + }), + ), ); } } diff --git a/lib/llm/langchain/extension.dart b/lib/llm/langchain/extension.dart new file mode 100644 index 0000000..f18c0bf --- /dev/null +++ b/lib/llm/langchain/extension.dart @@ -0,0 +1,29 @@ +import 'package:flutter/material.dart'; +import 'package:flutter_flow_chart/flutter_flow_chart.dart'; + +extension FlowElementExtension on FlowElement { + FlowElement? findPrevious(List elements) { + for (final i in elements) { + if (i.next.map((e) => e.destElementId).contains(id)) { + return i; + } + } + + return null; + } +} + +extension FormatExtension on Dashboard { + format() { + if (elements.length < 2) { + return; + } + + for (int i = 1; i < elements.length; i++) { + elements[i].position = + Offset(elements[0].position.dx + i * 200, elements[0].position.dy); + } + + recenter(); + } +} diff --git a/lib/llm/langchain/langchain_chat_screen.dart b/lib/llm/langchain/langchain_chat_screen.dart index c372274..0faadef 100644 --- a/lib/llm/langchain/langchain_chat_screen.dart +++ b/lib/llm/langchain/langchain_chat_screen.dart @@ -7,6 +7,8 @@ import 'package:all_in_one/llm/langchain/notifiers/tool_notifier.dart'; import 'package:flutter/material.dart'; import 'package:flutter_riverpod/flutter_riverpod.dart'; +import 'components/seq_chain_flow.dart'; + class LangchainChatScreen extends ConsumerStatefulWidget { const LangchainChatScreen({super.key}); @@ -22,7 +24,7 @@ class _LangchainChatScreenState extends ConsumerState { body: PageView( physics: const NeverScrollableScrollPhysics(), controller: ref.read(toolProvider.notifier).controller, - children: const [ToolsScreen(), _UI()], + children: const [ToolsScreen(), _UI(), FlowScreen()], ), ); } diff --git a/lib/llm/langchain/models/chains.dart b/lib/llm/langchain/models/chains.dart new file mode 100644 index 0000000..e1ce590 --- /dev/null +++ b/lib/llm/langchain/models/chains.dart @@ -0,0 +1,44 @@ +class Chains { + List? items; + + Chains({this.items}); + + Chains.fromJson(Map json) { + if (json['items'] != null) { + items = []; + json['items'].forEach((v) { + items!.add(ChainItem.fromJson(v)); + }); + } + } + + Map toJson() { + final Map data = {}; + if (items != null) { + data['items'] = items!.map((v) => v.toJson()).toList(); + } + return data; + } +} + +class ChainItem { + String? inputKey; + String? outputKey; + String? prompt; + + ChainItem({this.inputKey, this.outputKey, this.prompt}); + + ChainItem.fromJson(Map json) { + inputKey = json['input_key']; + outputKey = json['output_key']; + prompt = json['prompt']; + } + + Map toJson() { + final Map data = {}; + data['input_key'] = inputKey; + data['output_key'] = outputKey; + data['prompt'] = prompt; + return data; + } +} diff --git a/lib/llm/langchain/notifiers/chain_notifier.dart b/lib/llm/langchain/notifiers/chain_notifier.dart new file mode 100644 index 0000000..aaa0f14 --- /dev/null +++ b/lib/llm/langchain/notifiers/chain_notifier.dart @@ -0,0 +1,50 @@ +import 'package:all_in_one/llm/langchain/models/chains.dart'; +import 'package:flutter_flow_chart/flutter_flow_chart.dart'; +import 'package:flutter_riverpod/flutter_riverpod.dart'; + +class ChainNotifier extends Notifier> { + @override + Map build() { + return {}; + } + + addItem(FlowElement element, ChainItem chainItem) { + state[element] = chainItem; + } + + removeItem(String id) { + state.removeWhere((key, value) => key.id == id); + } + + updateItem(FlowElement element, ChainItem chainItem) { + state[element] = chainItem; + } + + bool validate() { + final values = state.values; + if (values.isEmpty) { + return false; + } + if (values.elementAt(0).inputKey == null || + values.elementAt(0).outputKey == null || + values.elementAt(0).prompt == null) { + return false; + } + for (int i = 0; i < values.length - 1; i++) { + if (values.elementAt(i).outputKey == null || + values.elementAt(i + 1).inputKey != values.elementAt(i).outputKey || + values.elementAt(i).prompt == null || + values.elementAt(i + 1).prompt == null) { + return false; + } + } + return true; + } + + List get items => state.values.toList(); +} + +final chainProvider = + NotifierProvider>( + () => ChainNotifier(), +); diff --git a/lib/llm/langchain/notifiers/tool_notifier.dart b/lib/llm/langchain/notifiers/tool_notifier.dart index b3076cf..bd21b4d 100644 --- a/lib/llm/langchain/notifiers/tool_notifier.dart +++ b/lib/llm/langchain/notifiers/tool_notifier.dart @@ -13,11 +13,15 @@ class ToolNotifier extends Notifier { changeState(LLMMessage? message) { state = message; - if (message == null) { - controller.jumpToPage(0); - } else { - controller.jumpToPage(1); - } + // if (message == null) { + // controller.jumpToPage(0); + // } else { + // controller.jumpToPage(1); + // } + } + + jumpTo(int index) { + controller.jumpToPage(index); } } diff --git a/pubspec.lock b/pubspec.lock index eb945be..59c3a5d 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -428,6 +428,22 @@ packages: description: flutter source: sdk version: "0.0.0" + flutter_expandable_fab: + dependency: "direct main" + description: + name: flutter_expandable_fab + sha256: "2aa5735bebcdbc49f43bcb32a29f9f03a9b7029212b8cd9837ae332ab2edf647" + url: "https://pub.flutter-io.cn" + source: hosted + version: "2.0.0" + flutter_flow_chart: + dependency: "direct main" + description: + name: flutter_flow_chart + sha256: "862e7535ddf9a8ac2e9de9bc07a79ba06540baa6a4beddf878f56744f036a747" + url: "https://pub.flutter-io.cn" + source: hosted + version: "2.2.1" flutter_highlight: dependency: transitive description: @@ -1036,6 +1052,14 @@ packages: url: "https://pub.flutter-io.cn" source: hosted version: "1.11.1" + star_menu: + dependency: "direct main" + description: + name: star_menu + sha256: b3147a753f2db3f4830a4d6f63bf56c96cde8bd0ff54d3464934b7726e4efa3a + url: "https://pub.flutter-io.cn" + source: hosted + version: "3.1.9" state_notifier: dependency: transitive description: diff --git a/pubspec.yaml b/pubspec.yaml index 57dff67..3bd7eee 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -26,6 +26,8 @@ dependencies: flutter_context_menu: git: url: https://github.com/guchengxi1994/flutter_context_menu + flutter_expandable_fab: ^2.0.0 + flutter_flow_chart: ^2.2.1 flutter_layout_grid: ^2.0.6 flutter_math_fork: ^0.7.2 flutter_riverpod: ^2.5.1 @@ -44,6 +46,7 @@ dependencies: riverpod: ^2.5.1 rust_lib_all_in_one: path: rust_builder + star_menu: ^3.1.4 syncfusion_flutter_calendar: ^25.1.38 # needs syncfusion license time_duration_picker: git: diff --git a/rust/src/api/llm_api.rs b/rust/src/api/llm_api.rs index c703e98..54b21a4 100644 --- a/rust/src/api/llm_api.rs +++ b/rust/src/api/llm_api.rs @@ -30,5 +30,5 @@ pub fn chat(_uuid: Option, _history: Option>, stream: bo pub fn sequential_chain_chat(json_str: String, query: String) { let rt = tokio::runtime::Runtime::new().unwrap(); - rt.block_on(async { crate::llm::sequential_chain_chat(json_str, query).await }); + let _ = rt.block_on(async { crate::llm::sequential_chain_chat(json_str, query).await }); } From e57a8ef464a256e8bb7fea9f232e58cd579ebbbf Mon Sep 17 00:00:00 2001 From: xiaoshuyui <528490652@qq.com> Date: Fri, 3 May 2024 19:36:04 +0800 Subject: [PATCH 3/3] test linux cmakelist --- assets/llm.json | 4 ++-- linux/CMakeLists.txt | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/assets/llm.json b/assets/llm.json index 562098a..549a093 100644 --- a/assets/llm.json +++ b/assets/llm.json @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0a4d67a3a300aa2cfda52e61619405ab600e7cf09df9efdbde776da5ae0b33ad -size 636 +oid sha256:cab4366b205321f5d7b3a1507c8ec4432dee99155855030ff2a16667ff683b35 +size 647 diff --git a/linux/CMakeLists.txt b/linux/CMakeLists.txt index 5c9875c..f8c22c0 100644 --- a/linux/CMakeLists.txt +++ b/linux/CMakeLists.txt @@ -143,3 +143,30 @@ if(NOT CMAKE_BUILD_TYPE MATCHES "Debug") install(FILES "${AOT_LIBRARY}" DESTINATION "${INSTALL_BUNDLE_LIB_DIR}" COMPONENT Runtime) endif() + +get_filename_component(P1 ${CMAKE_CURRENT_BINARY_DIR} DIRECTORY) +message(STATUS "P1 ==> " ${P1}) +string(REPLACE "/build/linux/x64/" "" ROOT_DIR ${P1}) +message(STATUS "ROOT_DIR ==> " ${ROOT_DIR}) + +set(ENV_PATH ${ROOT_DIR}/.env) + +if(NOT EXISTS ${ENV_PATH}) + set(MULTILINE_TEXT + "LLM_BASE =\" your api here \" +LLM_MODEL_NAME =\" your model name here \" +LLM_SK =\" your api key here \" +CHAT_CHAT_BASE = \" your chatchat api here \" + " + ) + file(WRITE ${ENV_PATH} ${MULTILINE_TEXT}) + message("File did not exist, so it was created.") +else() + message("File already exists.") +endif() + +add_custom_target(copy-runtime-files ALL + COMMAND ${CMAKE_COMMAND} -E copy ${ENV_PATH} ${P1}/Debug/bundle/.env + COMMAND ${CMAKE_COMMAND} -E copy ${ENV_PATH} ${P1}/Release/bundle/.env + COMMAND ${CMAKE_COMMAND} -E copy ${ENV_PATH} ${P1}/Profile/bundle/.env +) \ No newline at end of file