From 5f18f3854835099f04fc505cdd38ef1cffe24a8a Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Tue, 10 Dec 2024 10:37:31 +0100 Subject: [PATCH] feat!: clean up naming of methods to setup clients --- src/http.rs | 2 +- src/lib.rs | 16 +++++----------- tests/integration.rs | 38 +++++++++++++++++++------------------- tests/unit.rs | 12 ++++++------ 4 files changed, 31 insertions(+), 37 deletions(-) diff --git a/src/http.rs b/src/http.rs index cdd153a..7f9022f 100644 --- a/src/http.rs +++ b/src/http.rs @@ -92,7 +92,7 @@ pub struct HttpClient { impl HttpClient { /// In production you typically would want set this to . /// Yet you may want to use a different instance for testing. - pub fn with_base_url(host: String, api_token: Option) -> Result { + pub fn new(host: String, api_token: Option) -> Result { let http = ClientBuilder::new().build()?; Ok(Self { diff --git a/src/lib.rs b/src/lib.rs index 611a78c..21ed9b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,25 +71,19 @@ pub struct Client { impl Client { /// A new instance of an Aleph Alpha client helping you interact with the Aleph Alpha API. - /// For "normal" client applications you may likely rather use [`Self::with_base_url`]. /// + /// Setting the token to None allows specifying it on a per request basis. /// You may want to only use request based authentication and skip default authentication. This /// is useful if writing an application which invokes the client on behalf of many different /// users. Having neither request, nor default authentication is considered a bug and will cause /// a panic. pub fn new(host: impl Into, api_token: Option) -> Result { - let http_client = HttpClient::with_base_url(host.into(), api_token)?; + let http_client = HttpClient::new(host.into(), api_token)?; Ok(Self { http_client }) } - /// Use your on-premise inference with your API token for all requests. - /// - /// In production you typically would want set this to . - /// Yet you may want to use a different instance for testing. - pub fn with_base_url( - host: impl Into, - api_token: impl Into, - ) -> Result { + /// A client instance that always uses the same token for all requests. + pub fn with_auth(host: impl Into, api_token: impl Into) -> Result { Self::new(host, Some(api_token.into())) } @@ -97,7 +91,7 @@ impl Client { let _ = dotenv(); let api_token = env::var("PHARIA_AI_TOKEN").unwrap(); let inference_url = env::var("INFERENCE_URL").unwrap(); - Self::with_base_url(inference_url, api_token) + Self::with_auth(inference_url, api_token) } /// Execute a task with the aleph alpha API and fetch its result. diff --git a/tests/integration.rs b/tests/integration.rs index 80d5b62..3349242 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -36,7 +36,7 @@ async fn chat_with_pharia_1_7b_base() { let task = TaskChat::with_message(message); let model = "pharia-1-llm-7b-control"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client.chat(&task, model, &How::default()).await.unwrap(); // Then @@ -49,7 +49,7 @@ async fn completion_with_luminous_base() { let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -67,7 +67,7 @@ async fn request_authentication_has_priority() { let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), bad_pharia_ai_token).unwrap(); + let client = Client::with_auth(inference_url(), bad_pharia_ai_token).unwrap(); let response = client .output_of( &task.with_model(model), @@ -140,7 +140,7 @@ async fn semanitc_search_with_luminous_base() { temperature, traditionally in a wood-fired oven.", ); let query = Prompt::from_text("What is Pizza?"); - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let robot_embedding_task = TaskSemanticEmbedding { @@ -203,7 +203,7 @@ async fn complete_structured_prompt() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -232,7 +232,7 @@ async fn maximum_tokens_none_request() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -253,7 +253,7 @@ async fn explain_request() { target: " How is it going?", granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Sentence), }; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let response = client @@ -283,7 +283,7 @@ async fn explain_request_with_auto_granularity() { target: " How is it going?", granularity: Granularity::default(), }; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let response = client @@ -315,7 +315,7 @@ async fn explain_request_with_image_modality() { target: " a cat.", granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Paragraph), }; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let response = client @@ -365,7 +365,7 @@ async fn describe_image_starting_from_a_path() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -394,7 +394,7 @@ async fn describe_image_starting_from_a_dyn_image() { sampling: Sampling::MOST_LIKELY, }; let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -420,7 +420,7 @@ async fn only_answer_with_specific_animal() { }, }; let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -447,7 +447,7 @@ async fn answer_should_continue() { }, }; let model = "luminous-base"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -474,7 +474,7 @@ async fn batch_semanitc_embed_with_luminous_base() { temperature, traditionally in a wood-fired oven.", ); - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let embedding_task = TaskBatchSemanticEmbedding { @@ -499,7 +499,7 @@ async fn tokenization_with_luminous_base() { // Given let input = "Hello, World!"; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let task1 = TaskTokenization::new(input, false, true); @@ -536,7 +536,7 @@ async fn detokenization_with_luminous_base() { // Given let input = vec![49222, 15, 5390, 4]; - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let task = TaskDetokenization { token_ids: &input }; @@ -553,7 +553,7 @@ async fn detokenization_with_luminous_base() { #[tokio::test] async fn fetch_tokenizer_for_pharia_1_llm_7b() { // Given - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); // When let tokenizer = client @@ -568,7 +568,7 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() { #[tokio::test] async fn stream_completion() { // Given a streaming completion task - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let task = TaskCompletion::from_text("").with_maximum_tokens(7); // When the events are streamed and collected @@ -601,7 +601,7 @@ async fn stream_completion() { #[tokio::test] async fn stream_chat_with_pharia_1_llm_7b() { // Given a streaming completion task - let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap(); + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); let message = Message::user("Hello,"); let task = TaskChat::with_messages(vec![message]).with_maximum_tokens(7); diff --git a/tests/unit.rs b/tests/unit.rs index e29c6a7..c687803 100644 --- a/tests/unit.rs +++ b/tests/unit.rs @@ -34,7 +34,7 @@ async fn completion_with_luminous_base() { // When let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); + let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap(); let response = client .output_of(&task.with_model(model), &How::default()) .await @@ -74,7 +74,7 @@ async fn detect_rate_limiting() { // When let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); + let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap(); let error = client .output_of(&task.with_model(model), &How::default()) .await @@ -118,7 +118,7 @@ async fn detect_queue_full() { // When let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); + let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap(); let error = client .output_of(&task.with_model(model), &How::default()) .await @@ -155,7 +155,7 @@ async fn detect_service_unavailable() { // When let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); + let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap(); let error = client .output_of(&task.with_model(model), &How::default()) .await @@ -177,7 +177,7 @@ async fn be_nice() { // When let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; - let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); + let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap(); // Drop result, answer is meaningless anyway let _ = client .output_of( @@ -206,7 +206,7 @@ async fn client_timeout() { .respond_with(ResponseTemplate::new(StatusCode::OK).set_delay(response_time)) .mount(&mock_server) .await; - let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); + let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap(); // When let result = client