From a51e17d49c6794dfd2918b9688f6fe1e18ae76df Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Mon, 9 Dec 2024 18:50:49 +0100 Subject: [PATCH 01/15] refactor: client takes impl into string on construction --- src/lib.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 6d85f8f..9ff31ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,21 +77,24 @@ impl Client { /// is useful if writing an application which invokes the client on behalf of many different /// users. Having neither request, nor default authentication is considered a bug and will cause /// a panic. - pub fn new(host: String, api_token: Option) -> Result { - let http_client = HttpClient::with_base_url(host, api_token)?; + pub fn new(host: impl Into, api_token: Option) -> Result { + let http_client = HttpClient::with_base_url(host.into(), api_token)?; Ok(Self { http_client }) } /// Use the Aleph Alpha SaaS offering with your API token for all requests. pub fn with_authentication(api_token: impl Into) -> Result { - Self::with_base_url("https://api.aleph-alpha.com".to_owned(), api_token) + Self::with_base_url("https://api.aleph-alpha.com", api_token) } /// Use your on-premise inference with your API token for all requests. /// /// In production you typically would want set this to . Yet /// you may want to use a different instances for testing. - pub fn with_base_url(host: String, api_token: impl Into) -> Result { + pub fn with_base_url( + host: impl Into, + api_token: impl Into, + ) -> Result { Self::new(host, Some(api_token.into())) } From 1ad5b836e9422e7bec2cf8b23de36eff1a5cad2a Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Mon, 9 Dec 2024 18:51:28 +0100 Subject: [PATCH 02/15] test: run against base_url from env --- tests/integration.rs | 51 ++++++++++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/tests/integration.rs b/tests/integration.rs index 7be5332..1c5759b 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -19,6 +19,15 @@ fn api_token() -> &'static str { }) } +fn base_url() -> &'static str { + static AA_BASE_URL: OnceLock = OnceLock::new(); + AA_BASE_URL.get_or_init(|| { + drop(dotenv()); + std::env::var("AA_BASE_URL") + .expect("AA_BASE_URL environment variable must be specified to run tests.") + }) +} + #[tokio::test] async fn chat_with_pharia_1_7b_base() { // When @@ -26,7 +35,7 @@ async fn chat_with_pharia_1_7b_base() { let task = TaskChat::with_message(message); let model = "pharia-1-llm-7b-control"; - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); let response = client.chat(&task, model, &How::default()).await.unwrap(); // Then @@ -39,7 +48,7 @@ async fn completion_with_luminous_base() { let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -57,7 +66,7 @@ async fn request_authentication_has_priority() { let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_authentication(bad_aa_api_token).unwrap(); + let client = Client::with_base_url(base_url(), bad_aa_api_token).unwrap(); let response = client .output_of( &task.with_model(model), @@ -82,7 +91,7 @@ async fn authentication_only_per_request() { let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); // When - let client = Client::new("https://api.aleph-alpha.com".to_owned(), None).unwrap(); + let client = Client::new(base_url().to_owned(), None).unwrap(); let response = client .output_of( &task.with_model(model), @@ -106,7 +115,7 @@ async fn must_panic_if_authentication_is_missing() { let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); // When - let client = Client::new("https://api.aleph-alpha.com".to_owned(), None).unwrap(); + let client = Client::new(base_url().to_owned(), None).unwrap(); client .output_of(&task.with_model(model), &How::default()) .await @@ -130,7 +139,7 @@ async fn semanitc_search_with_luminous_base() { temperature, traditionally in a wood-fired oven.", ); let query = Prompt::from_text("What is Pizza?"); - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); // When let robot_embedding_task = TaskSemanticEmbedding { @@ -193,7 +202,7 @@ async fn complete_structured_prompt() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -222,7 +231,7 @@ async fn maximum_tokens_none_request() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -243,7 +252,7 @@ async fn explain_request() { target: " How is it going?", granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Sentence), }; - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); // When let response = client @@ -273,7 +282,7 @@ async fn explain_request_with_auto_granularity() { target: " How is it going?", granularity: Granularity::default(), }; - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); // When let response = client @@ -305,7 +314,7 @@ async fn explain_request_with_image_modality() { target: " a cat.", granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Paragraph), }; - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); // When let response = client @@ -355,7 +364,7 @@ async fn describe_image_starting_from_a_path() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -384,7 +393,7 @@ async fn describe_image_starting_from_a_dyn_image() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -410,7 +419,7 @@ async fn only_answer_with_specific_animal() { }, }; let model = "luminous-base"; - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -437,7 +446,7 @@ async fn answer_should_continue() { }, }; let model = "luminous-base"; - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -464,7 +473,7 @@ async fn batch_semanitc_embed_with_luminous_base() { temperature, traditionally in a wood-fired oven.", ); - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); // When let embedding_task = TaskBatchSemanticEmbedding { @@ -489,7 +498,7 @@ async fn tokenization_with_luminous_base() { // Given let input = "Hello, World!"; - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); // When let task1 = TaskTokenization::new(input, false, true); @@ -526,7 +535,7 @@ async fn detokenization_with_luminous_base() { // Given let input = vec![49222, 15, 5390, 4]; - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); // When let task = TaskDetokenization { token_ids: &input }; @@ -543,7 +552,7 @@ async fn detokenization_with_luminous_base() { #[tokio::test] async fn fetch_tokenizer_for_pharia_1_llm_7b() { // Given - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); // When let tokenizer = client @@ -558,7 +567,7 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() { #[tokio::test] async fn stream_completion() { // Given a streaming completion task - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); let task = TaskCompletion::from_text("").with_maximum_tokens(7); // When the events are streamed and collected @@ -591,7 +600,7 @@ async fn stream_completion() { #[tokio::test] async fn stream_chat_with_pharia_1_llm_7b() { // Given a streaming completion task - let client = Client::with_authentication(api_token()).unwrap(); + let client = Client::with_base_url(base_url(), api_token()).unwrap(); let message = Message::user("Hello,"); let task = TaskChat::with_messages(vec![message]).with_maximum_tokens(7); From d4fc618cda4a50ceb6695a4a6ac1fbc4838f727f Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Mon, 9 Dec 2024 18:52:19 +0100 Subject: [PATCH 03/15] test: update name of pharia-1-llm model to lowercase --- tests/integration.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration.rs b/tests/integration.rs index 1c5759b..ad1fbbc 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -556,7 +556,7 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() { // When let tokenizer = client - .tokenizer_by_model("Pharia-1-LLM-7B-control", None) + .tokenizer_by_model("pharia-1-llm-7b-control", None) .await .unwrap(); From ef70cba890717a3f7ad13100f2caaa6b701adaed Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Mon, 9 Dec 2024 18:59:22 +0100 Subject: [PATCH 04/15] feat!: replace from_authentication by new from_env method as saas api offering is gone, we can not assume a default base url anymore --- .env.example | 2 ++ Cargo.toml | 2 +- README.md | 2 +- src/http.rs | 2 +- src/lib.rs | 38 ++++++++++++++++++++------------------ src/prompt.rs | 5 +---- 6 files changed, 26 insertions(+), 25 deletions(-) create mode 100644 .env.example diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..05263fc --- /dev/null +++ b/.env.example @@ -0,0 +1,2 @@ +AA_API_TOKEN= +AA_BASE_URL= \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index f6ff605..da40e93 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ categories = ["api-bindings"] [dependencies] async-stream = "0.3.6" base64 = "0.22.0" +dotenv = "0.15.0" futures-util = "0.3.31" image = "0.25.1" itertools = "0.13.0" @@ -26,6 +27,5 @@ tokenizers = { version = "0.21.0", default-features = false, features = [ ] } [dev-dependencies] -dotenv = "0.15.0" tokio = { version = "1.37.0", features = ["rt", "macros"] } wiremock = "0.6.0" diff --git a/README.md b/README.md index ff87da7..370f13b 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ use aleph_alpha_client::{Client, TaskCompletion, How, Task}; #[tokio::main] fn main() { // Authenticate against API. Fetches token. - let client = Client::with_authentication("AA_API_TOKEN").unwrap(); + let client = Client::from_env().unwrap(); // Name of the model we we want to use. Large models give usually better answer, but are also // more costly. diff --git a/src/http.rs b/src/http.rs index bccbb76..9f7698d 100644 --- a/src/http.rs +++ b/src/http.rs @@ -139,7 +139,7 @@ impl HttpClient { /// /// async fn print_completion() -> Result<(), Error> { /// // Authenticate against API. Fetches token. - /// let client = Client::with_authentication("AA_API_TOKEN")?; + /// let client = Client::from_env()?; /// /// // Name of the model we we want to use. Large models give usually better answer, but are /// // also slower and more costly. diff --git a/src/lib.rs b/src/lib.rs index 9ff31ae..e9d28de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,7 @@ //! #[tokio::main(flavor = "current_thread")] //! async fn main() { //! // Authenticate against API. Fetches token. -//! let client = Client::with_authentication("AA_API_TOKEN").unwrap(); +//! let client = Client::from_env().unwrap(); //! //! // Name of the model we we want to use. Large models give usually better answer, but are also //! // more costly. @@ -33,11 +33,12 @@ mod prompt; mod semantic_embedding; mod stream; mod tokenization; -use std::{pin::Pin, time::Duration}; - +use dotenv::dotenv; use futures_util::Stream; use http::HttpClient; use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput}; +use std::env; +use std::{pin::Pin, time::Duration}; use tokenizers::Tokenizer; pub use self::{ @@ -70,8 +71,7 @@ pub struct Client { impl Client { /// A new instance of an Aleph Alpha client helping you interact with the Aleph Alpha API. - /// For "normal" client applications you may likely rather use [`Self::with_authentication`] or - /// [`Self::with_base_url`]. + /// For "normal" client applications you may likely rather use [`Self::with_base_url`]. /// /// You may want to only use request based authentication and skip default authentication. This /// is useful if writing an application which invokes the client on behalf of many different @@ -82,11 +82,6 @@ impl Client { Ok(Self { http_client }) } - /// Use the Aleph Alpha SaaS offering with your API token for all requests. - pub fn with_authentication(api_token: impl Into) -> Result { - Self::with_base_url("https://api.aleph-alpha.com", api_token) - } - /// Use your on-premise inference with your API token for all requests. /// /// In production you typically would want set this to . Yet @@ -98,6 +93,13 @@ impl Client { Self::new(host, Some(api_token.into())) } + pub fn from_env() -> Result { + let _ = dotenv(); + let api_token = env::var("AA_API_TOKEN").unwrap(); + let base_url = env::var("AA_BASE_URL").unwrap(); + Self::with_base_url(base_url, api_token) + } + /// Execute a task with the aleph alpha API and fetch its result. /// /// ```no_run @@ -105,7 +107,7 @@ impl Client { /// /// async fn print_completion() -> Result<(), Error> { /// // Authenticate against API. Fetches token. - /// let client = Client::with_authentication("AA_API_TOKEN")?; + /// let client = Client::from_env()?; /// /// // Name of the model we we want to use. Large models give usually better answer, but are /// // also slower and more costly. @@ -169,7 +171,7 @@ impl Client { /// /// async fn print_completion() -> Result<(), Error> { /// // Authenticate against API. Fetches token. - /// let client = Client::with_authentication("AA_API_TOKEN")?; + /// let client = Client::from_env()?; /// /// // Name of the model we we want to use. Large models give usually better answer, but are /// // also slower and more costly. @@ -207,7 +209,7 @@ impl Client { /// /// async fn print_stream_completion() -> Result<(), Error> { /// // Authenticate against API. Fetches token. - /// let client = Client::with_authentication("AA_API_TOKEN")?; + /// let client = Client::from_env()?; /// /// // Name of the model we we want to use. Large models give usually better answer, but are /// // also slower and more costly. @@ -244,7 +246,7 @@ impl Client { /// /// async fn print_chat() -> Result<(), Error> { /// // Authenticate against API. Fetches token. - /// let client = Client::with_authentication("AA_API_TOKEN")?; + /// let client = Client::from_env()?; /// /// // Name of a model that supports chat. /// let model = "pharia-1-llm-7b-control"; @@ -279,7 +281,7 @@ impl Client { /// /// async fn print_stream_chat() -> Result<(), Error> { /// // Authenticate against API. Fetches token. - /// let client = Client::with_authentication("AA_API_TOKEN")?; + /// let client = Client::from_env()?; /// /// // Name of a model that supports chat. /// let model = "pharia-1-llm-7b-control"; @@ -315,7 +317,7 @@ impl Client { /// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error, Granularity, TaskExplanation, Stopping, Prompt, Sampling}; /// /// async fn print_explanation() -> Result<(), Error> { - /// let client = Client::with_authentication("AA_API_TOKEN")?; + /// let client = Client::from_env()?; /// /// // Name of the model we we want to use. Large models give usually better answer, but are /// // also slower and more costly. @@ -359,7 +361,7 @@ impl Client { /// use aleph_alpha_client::{Client, Error, How, TaskTokenization}; /// /// async fn tokenize() -> Result<(), Error> { - /// let client = Client::with_authentication("AA_API_TOKEN")?; + /// let client = Client::from_env()?; /// /// // Name of the model for which we want to tokenize text. /// let model = "luminous-base"; @@ -395,7 +397,7 @@ impl Client { /// use aleph_alpha_client::{Client, Error, How, TaskDetokenization}; /// /// async fn detokenize() -> Result<(), Error> { - /// let client = Client::with_authentication("AA_API_TOKEN")?; + /// let client = Client::from_env()?; /// /// // Specify the name of the model whose tokenizer was used to generate the input token ids. /// let model = "luminous-base"; diff --git a/src/prompt.rs b/src/prompt.rs index c34e17b..55dd125 100644 --- a/src/prompt.rs +++ b/src/prompt.rs @@ -87,10 +87,7 @@ impl<'a> Modality<'a> { /// #[tokio::main(flavor = "current_thread")] /// async fn main() { /// // Create client - /// let _ = dotenv(); - /// let aa_api_token = std::env::var("AA_API_TOKEN") - /// .expect("AA_API_TOKEN environment variable must be specified to run demo."); - /// let client = Client::with_authentication(aa_api_token).unwrap(); + /// let client = Client::from_env().unwrap(); /// // Define task /// let task = TaskCompletion { /// prompt: Prompt::from_vec(vec![ From 9d93b733b9a9987388cc446494c190d77d0b1c85 Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Mon, 9 Dec 2024 19:01:09 +0100 Subject: [PATCH 05/15] style: remove unused reference to fix clippy warning --- src/chat.rs | 2 +- src/completion.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/chat.rs b/src/chat.rs index 35db048..98152c6 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -207,7 +207,7 @@ impl<'a> StreamTask for TaskChat<'a> { base: &str, model: &str, ) -> reqwest::RequestBuilder { - let body = ChatBody::new(model, &self).with_streaming(); + let body = ChatBody::new(model, self).with_streaming(); client.post(format!("{base}/chat/completions")).json(&body) } diff --git a/src/completion.rs b/src/completion.rs index 70301ba..bdbe2f3 100644 --- a/src/completion.rs +++ b/src/completion.rs @@ -270,7 +270,7 @@ impl StreamTask for TaskCompletion<'_> { base: &str, model: &str, ) -> reqwest::RequestBuilder { - let body = BodyCompletion::new(model, &self).with_streaming(); + let body = BodyCompletion::new(model, self).with_streaming(); client.post(format!("{base}/complete")).json(&body) } From 59ce735f8078e0db5d8d2b89739f23d3c5e47b0a Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Mon, 9 Dec 2024 19:02:21 +0100 Subject: [PATCH 06/15] ci: read base url from github secrets --- .github/workflows/test.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6cb6696..6afa17e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,9 +8,7 @@ on: jobs: build: - runs-on: ubuntu-latest - steps: - uses: actions/checkout@v2 - name: Build @@ -18,4 +16,5 @@ jobs: - name: Run tests env: AA_API_TOKEN: ${{ secrets.AA_API_TOKEN }} + AA_BASE_URL: ${{ secrets.AA_BASE_URL }} run: cargo test From be7dc9fa7a99bd0807bea19253ce7158f2d81166 Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Mon, 9 Dec 2024 19:09:56 +0100 Subject: [PATCH 07/15] docs: remove last mentions of api.aleph-alpha.com --- src/http.rs | 4 ++-- src/lib.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/http.rs b/src/http.rs index 9f7698d..df0f367 100644 --- a/src/http.rs +++ b/src/http.rs @@ -90,8 +90,8 @@ pub struct HttpClient { } impl HttpClient { - /// In production you typically would want set this to . Yet you - /// may want to use a different instances for testing. + /// In production you typically would want set this to . + /// Yet you may want to use a different instance for testing. pub fn with_base_url(host: String, api_token: Option) -> Result { let http = ClientBuilder::new().build()?; diff --git a/src/lib.rs b/src/lib.rs index e9d28de..139c9ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -84,8 +84,8 @@ impl Client { /// Use your on-premise inference with your API token for all requests. /// - /// In production you typically would want set this to . Yet - /// you may want to use a different instances for testing. + /// In production you typically would want set this to . + /// Yet you may want to use a different instance for testing. pub fn with_base_url( host: impl Into, api_token: impl Into, From 74e4442ada16567ea2bb1199d3e40b51fed69643 Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Mon, 9 Dec 2024 22:45:33 +0100 Subject: [PATCH 08/15] style: remove lifetime specifiers which can be elided --- src/chat.rs | 4 ++-- src/completion.rs | 2 +- src/detokenization.rs | 2 +- src/http.rs | 2 +- src/stream.rs | 2 +- src/tokenization.rs | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/chat.rs b/src/chat.rs index 98152c6..ff413c1 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -149,7 +149,7 @@ impl<'a> ChatBody<'a> { } } -impl<'a> Task for TaskChat<'a> { +impl Task for TaskChat<'_> { type Output = ChatOutput; type ResponseBody = ResponseChat; @@ -196,7 +196,7 @@ pub struct ChatEvent { pub choices: Vec, } -impl<'a> StreamTask for TaskChat<'a> { +impl StreamTask for TaskChat<'_> { type Output = ChatStreamChunk; type ResponseBody = ChatEvent; diff --git a/src/completion.rs b/src/completion.rs index bdbe2f3..b1aedb5 100644 --- a/src/completion.rs +++ b/src/completion.rs @@ -94,7 +94,7 @@ pub struct Stopping<'a> { pub stop_sequences: &'a [&'a str], } -impl<'a> Stopping<'a> { +impl Stopping<'_> { /// Only stop once the model reaches its technical limit, usually the context window. pub const NO_TOKEN_LIMIT: Self = Stopping { maximum_tokens: None, diff --git a/src/detokenization.rs b/src/detokenization.rs index 644fc10..2a728e8 100644 --- a/src/detokenization.rs +++ b/src/detokenization.rs @@ -34,7 +34,7 @@ impl From for DetokenizationOutput { } } -impl<'a> Task for TaskDetokenization<'a> { +impl Task for TaskDetokenization<'_> { type Output = DetokenizationOutput; type ResponseBody = ResponseDetokenization; diff --git a/src/http.rs b/src/http.rs index df0f367..cdd153a 100644 --- a/src/http.rs +++ b/src/http.rs @@ -65,7 +65,7 @@ pub struct MethodJob<'a, T> { pub task: &'a T, } -impl<'a, T> Job for MethodJob<'a, T> +impl Job for MethodJob<'_, T> where T: Task, { diff --git a/src/stream.rs b/src/stream.rs index 0fe5e52..de9b360 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -50,7 +50,7 @@ pub trait StreamTask { } } -impl<'a, T> StreamJob for MethodJob<'a, T> +impl StreamJob for MethodJob<'_, T> where T: StreamTask, { diff --git a/src/tokenization.rs b/src/tokenization.rs index e552622..4e5bb4d 100644 --- a/src/tokenization.rs +++ b/src/tokenization.rs @@ -14,7 +14,7 @@ pub struct TaskTokenization<'a> { } impl<'a> From<&'a str> for TaskTokenization<'a> { - fn from(prompt: &'a str) -> TaskTokenization { + fn from(prompt: &str) -> TaskTokenization { TaskTokenization { prompt, tokens: true, From b850f78a0b497effa79488c34bd2fc9b2c6bb8e5 Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Mon, 9 Dec 2024 22:53:57 +0100 Subject: [PATCH 09/15] ci: run clippy and fmt --- .github/workflows/test.yml | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6afa17e..00010a6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -7,14 +7,26 @@ on: branches: [ main ] jobs: - build: + lints: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Build - run: cargo build - - name: Run tests - env: - AA_API_TOKEN: ${{ secrets.AA_API_TOKEN }} - AA_BASE_URL: ${{ secrets.AA_BASE_URL }} - run: cargo test + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt, clippy + - uses: Swatinem/rust-cache@v2 + - run: cargo fmt -- --check + - run: cargo clippy -- -D warnings + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: Swatinem/rust-cache@v2 + - name: Build + run: cargo build + - name: Run tests + env: + AA_API_TOKEN: ${{ secrets.AA_API_TOKEN }} + AA_BASE_URL: ${{ secrets.AA_BASE_URL }} + run: cargo test From 44aa48d4f2973efdd8b5a7c4c2a33ed22da22804 Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Tue, 10 Dec 2024 10:05:28 +0100 Subject: [PATCH 10/15] ci: specify aa base url in code --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 00010a6..44e02f6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,5 +28,5 @@ jobs: - name: Run tests env: AA_API_TOKEN: ${{ secrets.AA_API_TOKEN }} - AA_BASE_URL: ${{ secrets.AA_BASE_URL }} + AA_BASE_URL: https://inference-api.product.pharia.com run: cargo test From 33f545eede25b52b53860d20aac5e6fbbcd4a85c Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Tue, 10 Dec 2024 10:08:00 +0100 Subject: [PATCH 11/15] feat!: token env is called PHARIA_AI_TOKEN --- .env.example | 2 +- .github/workflows/test.yml | 2 +- .gitignore | 2 +- src/lib.rs | 2 +- tests/integration.rs | 54 +++++++++++++++++++------------------- 5 files changed, 31 insertions(+), 31 deletions(-) diff --git a/.env.example b/.env.example index 05263fc..4259563 100644 --- a/.env.example +++ b/.env.example @@ -1,2 +1,2 @@ -AA_API_TOKEN= +PHARIA_AI_TOKEN= AA_BASE_URL= \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 44e02f6..31b1d45 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,6 +27,6 @@ jobs: run: cargo build - name: Run tests env: - AA_API_TOKEN: ${{ secrets.AA_API_TOKEN }} + PHARIA_AI_TOKEN: ${{ secrets.PHARIA_AI_TOKEN }} AA_BASE_URL: https://inference-api.product.pharia.com run: cargo test diff --git a/.gitignore b/.gitignore index a38be1f..7d72550 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -# Avoid commiting AA_API_TOKEN +# Avoid commiting PHARIA_AI_TOKEN .env /target diff --git a/src/lib.rs b/src/lib.rs index 139c9ff..161d87c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -95,7 +95,7 @@ impl Client { pub fn from_env() -> Result { let _ = dotenv(); - let api_token = env::var("AA_API_TOKEN").unwrap(); + let api_token = env::var("PHARIA_AI_TOKEN").unwrap(); let base_url = env::var("AA_BASE_URL").unwrap(); Self::with_base_url(base_url, api_token) } diff --git a/tests/integration.rs b/tests/integration.rs index ad1fbbc..024f9e7 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -10,12 +10,12 @@ use dotenv::dotenv; use futures_util::StreamExt; use image::ImageFormat; -fn api_token() -> &'static str { - static AA_API_TOKEN: OnceLock = OnceLock::new(); - AA_API_TOKEN.get_or_init(|| { +fn pharia_ai_token() -> &'static str { + static PHARIA_AI_TOKEN: OnceLock = OnceLock::new(); + PHARIA_AI_TOKEN.get_or_init(|| { drop(dotenv()); - std::env::var("AA_API_TOKEN") - .expect("AA_API_TOKEN environment variable must be specified to run tests.") + std::env::var("PHARIA_AI_TOKEN") + .expect("PHARIA_AI_TOKEN environment variable must be specified to run tests.") }) } @@ -35,7 +35,7 @@ async fn chat_with_pharia_1_7b_base() { let task = TaskChat::with_message(message); let model = "pharia-1-llm-7b-control"; - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); let response = client.chat(&task, model, &How::default()).await.unwrap(); // Then @@ -48,7 +48,7 @@ async fn completion_with_luminous_base() { let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -62,16 +62,16 @@ async fn completion_with_luminous_base() { #[tokio::test] async fn request_authentication_has_priority() { - let bad_aa_api_token = "DUMMY"; + let bad_pharia_ai_token = "DUMMY"; let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(base_url(), bad_aa_api_token).unwrap(); + let client = Client::with_base_url(base_url(), bad_pharia_ai_token).unwrap(); let response = client .output_of( &task.with_model(model), &How { - api_token: Some(api_token().to_owned()), + api_token: Some(pharia_ai_token().to_owned()), ..Default::default() }, ) @@ -96,7 +96,7 @@ async fn authentication_only_per_request() { .output_of( &task.with_model(model), &How { - api_token: Some(api_token().to_owned()), + api_token: Some(pharia_ai_token().to_owned()), ..Default::default() }, ) @@ -139,7 +139,7 @@ async fn semanitc_search_with_luminous_base() { temperature, traditionally in a wood-fired oven.", ); let query = Prompt::from_text("What is Pizza?"); - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); // When let robot_embedding_task = TaskSemanticEmbedding { @@ -202,7 +202,7 @@ async fn complete_structured_prompt() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -231,7 +231,7 @@ async fn maximum_tokens_none_request() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -252,7 +252,7 @@ async fn explain_request() { target: " How is it going?", granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Sentence), }; - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); // When let response = client @@ -282,7 +282,7 @@ async fn explain_request_with_auto_granularity() { target: " How is it going?", granularity: Granularity::default(), }; - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); // When let response = client @@ -314,7 +314,7 @@ async fn explain_request_with_image_modality() { target: " a cat.", granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Paragraph), }; - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); // When let response = client @@ -364,7 +364,7 @@ async fn describe_image_starting_from_a_path() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -393,7 +393,7 @@ async fn describe_image_starting_from_a_dyn_image() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -419,7 +419,7 @@ async fn only_answer_with_specific_animal() { }, }; let model = "luminous-base"; - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -446,7 +446,7 @@ async fn answer_should_continue() { }, }; let model = "luminous-base"; - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -473,7 +473,7 @@ async fn batch_semanitc_embed_with_luminous_base() { temperature, traditionally in a wood-fired oven.", ); - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); // When let embedding_task = TaskBatchSemanticEmbedding { @@ -498,7 +498,7 @@ async fn tokenization_with_luminous_base() { // Given let input = "Hello, World!"; - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); // When let task1 = TaskTokenization::new(input, false, true); @@ -535,7 +535,7 @@ async fn detokenization_with_luminous_base() { // Given let input = vec![49222, 15, 5390, 4]; - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); // When let task = TaskDetokenization { token_ids: &input }; @@ -552,7 +552,7 @@ async fn detokenization_with_luminous_base() { #[tokio::test] async fn fetch_tokenizer_for_pharia_1_llm_7b() { // Given - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); // When let tokenizer = client @@ -567,7 +567,7 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() { #[tokio::test] async fn stream_completion() { // Given a streaming completion task - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); let task = TaskCompletion::from_text("").with_maximum_tokens(7); // When the events are streamed and collected @@ -600,7 +600,7 @@ async fn stream_completion() { #[tokio::test] async fn stream_chat_with_pharia_1_llm_7b() { // Given a streaming completion task - let client = Client::with_base_url(base_url(), api_token()).unwrap(); + let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); let message = Message::user("Hello,"); let task = TaskChat::with_messages(vec![message]).with_maximum_tokens(7); From 836b2db4275c11da73c9c223f5167d89c391f1d0 Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Tue, 10 Dec 2024 10:12:55 +0100 Subject: [PATCH 12/15] test: replace once lock with lazy lock --- tests/integration.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/integration.rs b/tests/integration.rs index 024f9e7..9c016cb 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1,4 +1,4 @@ -use std::{fs::File, io::BufReader, sync::OnceLock}; +use std::{fs::File, io::BufReader}; use aleph_alpha_client::{ cosine_similarity, Client, CompletionEvent, Granularity, How, ImageScore, ItemExplanation, @@ -9,23 +9,24 @@ use aleph_alpha_client::{ use dotenv::dotenv; use futures_util::StreamExt; use image::ImageFormat; +use std::sync::LazyLock; fn pharia_ai_token() -> &'static str { - static PHARIA_AI_TOKEN: OnceLock = OnceLock::new(); - PHARIA_AI_TOKEN.get_or_init(|| { + static PHARIA_AI_TOKEN: LazyLock = LazyLock::new(|| { drop(dotenv()); std::env::var("PHARIA_AI_TOKEN") .expect("PHARIA_AI_TOKEN environment variable must be specified to run tests.") - }) + }); + &PHARIA_AI_TOKEN } fn base_url() -> &'static str { - static AA_BASE_URL: OnceLock = OnceLock::new(); - AA_BASE_URL.get_or_init(|| { + static AA_BASE_URL: LazyLock = LazyLock::new(|| { drop(dotenv()); std::env::var("AA_BASE_URL") .expect("AA_BASE_URL environment variable must be specified to run tests.") - }) + }); + &AA_BASE_URL } #[tokio::test] From 1945625e7c16e2123af4d2e54328ec3522f60ca8 Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Tue, 10 Dec 2024 10:14:10 +0100 Subject: [PATCH 13/15] build(deps): replace unmaintained dotenv by dotenvy --- Cargo.toml | 2 +- src/lib.rs | 2 +- src/prompt.rs | 2 +- tests/integration.rs | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index da40e93..18edf08 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ categories = ["api-bindings"] [dependencies] async-stream = "0.3.6" base64 = "0.22.0" -dotenv = "0.15.0" +dotenvy = "0.15.7" futures-util = "0.3.31" image = "0.25.1" itertools = "0.13.0" diff --git a/src/lib.rs b/src/lib.rs index 161d87c..8282698 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,7 +33,7 @@ mod prompt; mod semantic_embedding; mod stream; mod tokenization; -use dotenv::dotenv; +use dotenvy::dotenv; use futures_util::Stream; use http::HttpClient; use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput}; diff --git a/src/prompt.rs b/src/prompt.rs index 55dd125..21c1188 100644 --- a/src/prompt.rs +++ b/src/prompt.rs @@ -81,7 +81,7 @@ impl<'a> Modality<'a> { /// /// ```no_run /// use aleph_alpha_client::{Client, How, Modality, Prompt, Sampling, Stopping, TaskCompletion, Task}; - /// use dotenv::dotenv; + /// use dotenvy::dotenv; /// use std::path::PathBuf; /// /// #[tokio::main(flavor = "current_thread")] diff --git a/tests/integration.rs b/tests/integration.rs index 9c016cb..4d1d930 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -6,7 +6,7 @@ use aleph_alpha_client::{ TaskBatchSemanticEmbedding, TaskChat, TaskCompletion, TaskDetokenization, TaskExplanation, TaskSemanticEmbedding, TaskTokenization, TextScore, }; -use dotenv::dotenv; +use dotenvy::dotenv; use futures_util::StreamExt; use image::ImageFormat; use std::sync::LazyLock; From c9c48b20277a89935a703c85feb51b503a9bfdf3 Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Tue, 10 Dec 2024 10:18:39 +0100 Subject: [PATCH 14/15] feat: rename base_url env variable to inference_url --- .env.example | 2 +- .github/workflows/test.yml | 2 +- src/lib.rs | 4 +-- tests/integration.rs | 52 +++++++++++++++++++------------------- 4 files changed, 30 insertions(+), 30 deletions(-) diff --git a/.env.example b/.env.example index 4259563..4b5a2d3 100644 --- a/.env.example +++ b/.env.example @@ -1,2 +1,2 @@ PHARIA_AI_TOKEN= -AA_BASE_URL= \ No newline at end of file +INFERENCE_URL= \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 31b1d45..ffbce3c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,5 +28,5 @@ jobs: - name: Run tests env: PHARIA_AI_TOKEN: ${{ secrets.PHARIA_AI_TOKEN }} - AA_BASE_URL: https://inference-api.product.pharia.com + INFERENCE_URL: https://inference-api.product.pharia.com run: cargo test diff --git a/src/lib.rs b/src/lib.rs index 8282698..611a78c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -96,8 +96,8 @@ impl Client { pub fn from_env() -> Result { let _ = dotenv(); let api_token = env::var("PHARIA_AI_TOKEN").unwrap(); - let base_url = env::var("AA_BASE_URL").unwrap(); - Self::with_base_url(base_url, api_token) + let inference_url = env::var("INFERENCE_URL").unwrap(); + Self::with_base_url(inference_url, api_token) } /// Execute a task with the aleph alpha API and fetch its result. diff --git a/tests/integration.rs b/tests/integration.rs index 4d1d930..80d5b62 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -20,13 +20,13 @@ fn pharia_ai_token() -> &'static str { &PHARIA_AI_TOKEN } -fn base_url() -> &'static str { - static AA_BASE_URL: LazyLock = LazyLock::new(|| { +fn inference_url() -> &'static str { + static INFERENCE_URL: LazyLock = LazyLock::new(|| { drop(dotenv()); - std::env::var("AA_BASE_URL") - .expect("AA_BASE_URL environment variable must be specified to run tests.") + std::env::var("INFERENCE_URL") + .expect("INFERENCE_URL environment variable must be specified to run tests.") }); - &AA_BASE_URL + &INFERENCE_URL } #[tokio::test] @@ -36,7 +36,7 @@ async fn chat_with_pharia_1_7b_base() { let task = TaskChat::with_message(message); let model = "pharia-1-llm-7b-control"; - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); let response = client.chat(&task, model, &How::default()).await.unwrap(); // Then @@ -49,7 +49,7 @@ async fn completion_with_luminous_base() { let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -67,7 +67,7 @@ async fn request_authentication_has_priority() { let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(base_url(), bad_pharia_ai_token).unwrap(); + let client = Client::with_base_url(inference_url(), bad_pharia_ai_token).unwrap(); let response = client .output_of( &task.with_model(model), @@ -92,7 +92,7 @@ async fn authentication_only_per_request() { let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); // When - let client = Client::new(base_url().to_owned(), None).unwrap(); + let client = Client::new(inference_url().to_owned(), None).unwrap(); let response = client .output_of( &task.with_model(model), @@ -116,7 +116,7 @@ async fn must_panic_if_authentication_is_missing() { let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); // When - let client = Client::new(base_url().to_owned(), None).unwrap(); + let client = Client::new(inference_url().to_owned(), None).unwrap(); client .output_of(&task.with_model(model), &How::default()) .await @@ -140,7 +140,7 @@ async fn semanitc_search_with_luminous_base() { temperature, traditionally in a wood-fired oven.", ); let query = Prompt::from_text("What is Pizza?"); - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); // When let robot_embedding_task = TaskSemanticEmbedding { @@ -203,7 +203,7 @@ async fn complete_structured_prompt() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -232,7 +232,7 @@ async fn maximum_tokens_none_request() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -253,7 +253,7 @@ async fn explain_request() { target: " How is it going?", granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Sentence), }; - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); // When let response = client @@ -283,7 +283,7 @@ async fn explain_request_with_auto_granularity() { target: " How is it going?", granularity: Granularity::default(), }; - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); // When let response = client @@ -315,7 +315,7 @@ async fn explain_request_with_image_modality() { target: " a cat.", granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Paragraph), }; - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); // When let response = client @@ -365,7 +365,7 @@ async fn describe_image_starting_from_a_path() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -394,7 +394,7 @@ async fn describe_image_starting_from_a_dyn_image() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -420,7 +420,7 @@ async fn only_answer_with_specific_animal() { }, }; let model = "luminous-base"; - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -447,7 +447,7 @@ async fn answer_should_continue() { }, }; let model = "luminous-base"; - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -474,7 +474,7 @@ async fn batch_semanitc_embed_with_luminous_base() { temperature, traditionally in a wood-fired oven.", ); - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); // When let embedding_task = TaskBatchSemanticEmbedding { @@ -499,7 +499,7 @@ async fn tokenization_with_luminous_base() { // Given let input = "Hello, World!"; - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); // When let task1 = TaskTokenization::new(input, false, true); @@ -536,7 +536,7 @@ async fn detokenization_with_luminous_base() { // Given let input = vec![49222, 15, 5390, 4]; - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); // When let task = TaskDetokenization { token_ids: &input }; @@ -553,7 +553,7 @@ async fn detokenization_with_luminous_base() { #[tokio::test] async fn fetch_tokenizer_for_pharia_1_llm_7b() { // Given - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); // When let tokenizer = client @@ -568,7 +568,7 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() { #[tokio::test] async fn stream_completion() { // Given a streaming completion task - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); let task = TaskCompletion::from_text("").with_maximum_tokens(7); // When the events are streamed and collected @@ -601,7 +601,7 @@ async fn stream_completion() { #[tokio::test] async fn stream_chat_with_pharia_1_llm_7b() { // Given a streaming completion task - let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap(); + let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); let message = Message::user("Hello,"); let task = TaskChat::with_messages(vec![message]).with_maximum_tokens(7); From 380fbe056b0c4052738102e4e2100361fcfae820 Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Tue, 10 Dec 2024 10:37:31 +0100 Subject: [PATCH 15/15] feat!: clean up naming of methods to setup clients --- src/http.rs | 2 +- src/lib.rs | 16 +++++----------- tests/integration.rs | 38 +++++++++++++++++++------------------- tests/unit.rs | 12 ++++++------ 4 files changed, 31 insertions(+), 37 deletions(-) diff --git a/src/http.rs b/src/http.rs index cdd153a..7f9022f 100644 --- a/src/http.rs +++ b/src/http.rs @@ -92,7 +92,7 @@ pub struct HttpClient { impl HttpClient { /// In production you typically would want set this to . /// Yet you may want to use a different instance for testing. - pub fn with_base_url(host: String, api_token: Option) -> Result { + pub fn new(host: String, api_token: Option) -> Result { let http = ClientBuilder::new().build()?; Ok(Self { diff --git a/src/lib.rs b/src/lib.rs index 611a78c..21ed9b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,25 +71,19 @@ pub struct Client { impl Client { /// A new instance of an Aleph Alpha client helping you interact with the Aleph Alpha API. - /// For "normal" client applications you may likely rather use [`Self::with_base_url`]. /// + /// Setting the token to None allows specifying it on a per request basis. /// You may want to only use request based authentication and skip default authentication. This /// is useful if writing an application which invokes the client on behalf of many different /// users. Having neither request, nor default authentication is considered a bug and will cause /// a panic. pub fn new(host: impl Into, api_token: Option) -> Result { - let http_client = HttpClient::with_base_url(host.into(), api_token)?; + let http_client = HttpClient::new(host.into(), api_token)?; Ok(Self { http_client }) } - /// Use your on-premise inference with your API token for all requests. - /// - /// In production you typically would want set this to . - /// Yet you may want to use a different instance for testing. - pub fn with_base_url( - host: impl Into, - api_token: impl Into, - ) -> Result { + /// A client instance that always uses the same token for all requests. + pub fn with_auth(host: impl Into, api_token: impl Into) -> Result { Self::new(host, Some(api_token.into())) } @@ -97,7 +91,7 @@ impl Client { let _ = dotenv(); let api_token = env::var("PHARIA_AI_TOKEN").unwrap(); let inference_url = env::var("INFERENCE_URL").unwrap(); - Self::with_base_url(inference_url, api_token) + Self::with_auth(inference_url, api_token) } /// Execute a task with the aleph alpha API and fetch its result. diff --git a/tests/integration.rs b/tests/integration.rs index 80d5b62..3349242 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -36,7 +36,7 @@ async fn chat_with_pharia_1_7b_base() { let task = TaskChat::with_message(message); let model = "pharia-1-llm-7b-control"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client.chat(&task, model, &How::default()).await.unwrap(); // Then @@ -49,7 +49,7 @@ async fn completion_with_luminous_base() { let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -67,7 +67,7 @@ async fn request_authentication_has_priority() { let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), bad_pharia_ai_token).unwrap(); + let client = Client::with_auth(inference_url(), bad_pharia_ai_token).unwrap(); let response = client .output_of( &task.with_model(model), @@ -140,7 +140,7 @@ async fn semanitc_search_with_luminous_base() { temperature, traditionally in a wood-fired oven.", ); let query = Prompt::from_text("What is Pizza?"); - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let robot_embedding_task = TaskSemanticEmbedding { @@ -203,7 +203,7 @@ async fn complete_structured_prompt() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -232,7 +232,7 @@ async fn maximum_tokens_none_request() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -253,7 +253,7 @@ async fn explain_request() { target: " How is it going?", granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Sentence), }; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let response = client @@ -283,7 +283,7 @@ async fn explain_request_with_auto_granularity() { target: " How is it going?", granularity: Granularity::default(), }; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let response = client @@ -315,7 +315,7 @@ async fn explain_request_with_image_modality() { target: " a cat.", granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Paragraph), }; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let response = client @@ -365,7 +365,7 @@ async fn describe_image_starting_from_a_path() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -394,7 +394,7 @@ async fn describe_image_starting_from_a_dyn_image() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -420,7 +420,7 @@ async fn only_answer_with_specific_animal() { }, }; let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -447,7 +447,7 @@ async fn answer_should_continue() { }, }; let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -474,7 +474,7 @@ async fn batch_semanitc_embed_with_luminous_base() { temperature, traditionally in a wood-fired oven.", ); - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let embedding_task = TaskBatchSemanticEmbedding { @@ -499,7 +499,7 @@ async fn tokenization_with_luminous_base() { // Given let input = "Hello, World!"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let task1 = TaskTokenization::new(input, false, true); @@ -536,7 +536,7 @@ async fn detokenization_with_luminous_base() { // Given let input = vec![49222, 15, 5390, 4]; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let task = TaskDetokenization { token_ids: &input }; @@ -553,7 +553,7 @@ async fn detokenization_with_luminous_base() { #[tokio::test] async fn fetch_tokenizer_for_pharia_1_llm_7b() { // Given - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let tokenizer = client @@ -568,7 +568,7 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() { #[tokio::test] async fn stream_completion() { // Given a streaming completion task - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let task = TaskCompletion::from_text("").with_maximum_tokens(7); // When the events are streamed and collected @@ -601,7 +601,7 @@ async fn stream_completion() { #[tokio::test] async fn stream_chat_with_pharia_1_llm_7b() { // Given a streaming completion task - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let message = Message::user("Hello,"); let task = TaskChat::with_messages(vec![message]).with_maximum_tokens(7); diff --git a/tests/unit.rs b/tests/unit.rs index e29c6a7..c687803 100644 --- a/tests/unit.rs +++ b/tests/unit.rs @@ -34,7 +34,7 @@ async fn completion_with_luminous_base() { // When let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); + let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -74,7 +74,7 @@ async fn detect_rate_limiting() { // When let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); + let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap(); let error = client .output_of(&task.with_model(model), &How::default()) .await @@ -118,7 +118,7 @@ async fn detect_queue_full() { // When let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); + let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap(); let error = client .output_of(&task.with_model(model), &How::default()) .await @@ -155,7 +155,7 @@ async fn detect_service_unavailable() { // When let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); + let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap(); let error = client .output_of(&task.with_model(model), &How::default()) .await @@ -177,7 +177,7 @@ async fn be_nice() { // When let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); + let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap(); // Drop result, answer is meaningless anyway let _ = client .output_of( @@ -206,7 +206,7 @@ async fn client_timeout() { .respond_with(ResponseTemplate::new(StatusCode::OK).set_delay(response_time)) .mount(&mock_server) .await; - let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); + let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap(); // When let result = client