From 79df6d341d22e4f67ad8be2e8fafc196c8590309 Mon Sep 17 00:00:00 2001 From: Moritz Althaus Date: Mon, 9 Dec 2024 23:22:08 +0100 Subject: [PATCH] feat: add RawCompletionTask to retrieve unoptimized completion --- src/completion.rs | 55 ++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 +- tests/integration.rs | 25 ++++++++++++++++++++ 3 files changed, 81 insertions(+), 1 deletion(-) diff --git a/src/completion.rs b/src/completion.rs index b1aedb5..c3b4566 100644 --- a/src/completion.rs +++ b/src/completion.rs @@ -32,8 +32,17 @@ impl<'a> TaskCompletion<'a> { self.stopping.stop_sequences = stop_sequences; self } + pub fn with_raw_completion(self) -> TaskRawCompletion<'a> { + TaskRawCompletion(self) + } } +/// Completes a prompt and returns the raw (non-optimized) completion of the model. +/// +/// Closely related to [`TaskCompletion`], but returns the raw completion of the model. +/// You can build a [`TaskRawCompletion`] by calling [`with_raw_completion`] on a [`TaskCompletion`]. +pub struct TaskRawCompletion<'a>(TaskCompletion<'a>); + /// Sampling controls how the tokens ("words") are selected for the completion. pub struct Sampling<'a> { /// A temperature encourages the model to produce less probable outputs ("be more creative"). @@ -145,6 +154,13 @@ struct BodyCompletion<'a> { /// If true, the response will be streamed. #[serde(skip_serializing_if = "std::ops::Not::not")] pub stream: bool, + /// Forces the raw completion of the model to be returned. + /// For some models, the completion that was generated by the model may be optimized and + /// returned in the completion field of the CompletionResponse. + /// The raw completion, if returned, will contain the un-optimized completion. + /// Setting tokens to true or log_probs to any value will also trigger the raw completion to be returned. + #[serde(skip_serializing_if = "std::ops::Not::not")] + pub raw_completion: bool, } impl<'a> BodyCompletion<'a> { @@ -159,12 +175,17 @@ impl<'a> BodyCompletion<'a> { top_p: task.sampling.top_p, completion_bias_inclusion: task.sampling.start_with_one_of, stream: false, + raw_completion: false, } } pub fn with_streaming(mut self) -> Self { self.stream = true; self } + pub fn with_raw_completion(mut self) -> Self { + self.raw_completion = true; + self + } } #[derive(Deserialize, Debug, PartialEq, Eq)] @@ -187,6 +208,12 @@ impl ResponseCompletion { } } +#[derive(Deserialize)] +pub struct RawCompletionResponse { + pub model_version: String, + pub completions: Vec, +} + /// Completion and metainformation returned by a completion task #[derive(Deserialize, Debug, PartialEq, Eq)] pub struct CompletionOutput { @@ -214,6 +241,34 @@ impl Task for TaskCompletion<'_> { } } +/// Completion and metainformation returned by a completion task +#[derive(Deserialize, Debug, PartialEq, Eq)] +pub struct RawCompletionOutput { + pub completion: String, + pub raw_completion: String, + pub finish_reason: String, +} + +impl Task for TaskRawCompletion<'_> { + type Output = RawCompletionOutput; + + type ResponseBody = RawCompletionResponse; + + fn build_request( + &self, + client: &reqwest::Client, + base: &str, + model: &str, + ) -> reqwest::RequestBuilder { + let body = BodyCompletion::new(model, &self.0).with_raw_completion(); + client.post(format!("{base}/complete")).json(&body) + } + + fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output { + response.completions.pop().unwrap() + } +} + /// Describes a chunk of a completion stream #[derive(Deserialize, Debug)] pub struct StreamChunk { diff --git a/src/lib.rs b/src/lib.rs index 21ed9b5..8beb9e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,7 +45,7 @@ pub use self::{ chat::{ChatEvent, ChatStreamChunk}, chat::{ChatOutput, Message, TaskChat}, completion::{CompletionEvent, CompletionSummary, StreamChunk, StreamSummary}, - completion::{CompletionOutput, Sampling, Stopping, TaskCompletion}, + completion::{CompletionOutput, Sampling, Stopping, TaskCompletion, TaskRawCompletion}, detokenization::{DetokenizationOutput, TaskDetokenization}, explanation::{ Explanation, ExplanationOutput, Granularity, ImageScore, ItemExplanation, diff --git a/tests/integration.rs b/tests/integration.rs index 3349242..7774c13 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -61,6 +61,31 @@ async fn completion_with_luminous_base() { assert!(!response.completion.is_empty()) } +#[tokio::test] +async fn raw_completion_includes_python_tag() { + // When + let task = TaskCompletion::from_text( + "<|begin_of_text|><|start_header_id|>system<|end_header_id|> + +Environment: ipython<|eot_id|><|start_header_id|>user<|end_header_id|> + +Write code to check if number is prime, use that to see if the number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>", + ) + .with_maximum_tokens(30) + .with_raw_completion(); + + let model = "llama-3.1-8b-instruct"; + let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap(); + let response = client + .output_of(&task.with_model(model), &How::default()) + .await + .unwrap(); + dbg!(&response.completion); + dbg!(&response.raw_completion); + assert!(response.completion.trim().starts_with("def")); + assert!(response.raw_completion.trim().starts_with("<|python_tag|>")); +} + #[tokio::test] async fn request_authentication_has_priority() { let bad_pharia_ai_token = "DUMMY";