Skip to content

Commit

Permalink
feat: add RawCompletionTask to retrieve unoptimized completion
Browse files Browse the repository at this point in the history
  • Loading branch information
moldhouse committed Dec 9, 2024
1 parent b850f78 commit 7d82ece
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 1 deletion.
55 changes: 55 additions & 0 deletions src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,17 @@ impl<'a> TaskCompletion<'a> {
self.stopping.stop_sequences = stop_sequences;
self
}
pub fn with_raw_completion(self) -> TaskRawCompletion<'a> {
TaskRawCompletion(self)
}
}

/// Completes a prompt and returns the raw (non-optimized) completion of the model.
///
/// Closely related to [`TaskCompletion`], but returns the raw completion of the model.
/// You can build a [`TaskRawCompletion`] by calling [`with_raw_completion`] on a [`TaskCompletion`].
pub struct TaskRawCompletion<'a>(TaskCompletion<'a>);

/// Sampling controls how the tokens ("words") are selected for the completion.
pub struct Sampling<'a> {
/// A temperature encourages the model to produce less probable outputs ("be more creative").
Expand Down Expand Up @@ -145,6 +154,13 @@ struct BodyCompletion<'a> {
/// If true, the response will be streamed.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
/// Forces the raw completion of the model to be returned.
/// For some models, the completion that was generated by the model may be optimized and
/// returned in the completion field of the CompletionResponse.
/// The raw completion, if returned, will contain the un-optimized completion.
/// Setting tokens to true or log_probs to any value will also trigger the raw completion to be returned.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub raw_completion: bool,
}

impl<'a> BodyCompletion<'a> {
Expand All @@ -159,12 +175,17 @@ impl<'a> BodyCompletion<'a> {
top_p: task.sampling.top_p,
completion_bias_inclusion: task.sampling.start_with_one_of,
stream: false,
raw_completion: false,
}
}
pub fn with_streaming(mut self) -> Self {
self.stream = true;
self
}
pub fn with_raw_completion(mut self) -> Self {
self.raw_completion = true;
self
}
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
Expand All @@ -187,6 +208,12 @@ impl ResponseCompletion {
}
}

#[derive(Deserialize)]
pub struct RawCompletionResponse {
pub model_version: String,
pub completions: Vec<RawCompletionOutput>,
}

/// Completion and metainformation returned by a completion task
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct CompletionOutput {
Expand Down Expand Up @@ -214,6 +241,34 @@ impl Task for TaskCompletion<'_> {
}
}

/// Completion and metainformation returned by a completion task
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct RawCompletionOutput {
pub completion: String,
pub raw_completion: String,
pub finish_reason: String,
}

impl Task for TaskRawCompletion<'_> {
type Output = RawCompletionOutput;

type ResponseBody = RawCompletionResponse;

fn build_request(
&self,
client: &reqwest::Client,
base: &str,
model: &str,
) -> reqwest::RequestBuilder {
let body = BodyCompletion::new(model, &self.0).with_raw_completion();
client.post(format!("{base}/complete")).json(&body)
}

fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
response.completions.pop().unwrap()
}
}

/// Describes a chunk of a completion stream
#[derive(Deserialize, Debug)]
pub struct StreamChunk {
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub use self::{
chat::{ChatEvent, ChatStreamChunk},
chat::{ChatOutput, Message, TaskChat},
completion::{CompletionEvent, CompletionSummary, StreamChunk, StreamSummary},
completion::{CompletionOutput, Sampling, Stopping, TaskCompletion},
completion::{CompletionOutput, Sampling, Stopping, TaskCompletion, TaskRawCompletion},
detokenization::{DetokenizationOutput, TaskDetokenization},
explanation::{
Explanation, ExplanationOutput, Granularity, ImageScore, ItemExplanation,
Expand Down
25 changes: 25 additions & 0 deletions tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,31 @@ async fn completion_with_luminous_base() {
assert!(!response.completion.is_empty())
}

#[tokio::test]
async fn raw_completion_includes_python_tag() {
// When
let task = TaskCompletion::from_text(
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython<|eot_id|><|start_header_id|>user<|end_header_id|>
Write code to check if number is prime, use that to see if the number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
)
.with_maximum_tokens(30)
.with_raw_completion();

let model = "llama-3.1-8b-instruct";
let client = Client::with_base_url(base_url(), api_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
.unwrap();
dbg!(&response.completion);
dbg!(&response.raw_completion);
assert!(response.completion.trim().starts_with("def"));
assert!(response.raw_completion.trim().starts_with("<|python_tag|>"));
}

#[tokio::test]
async fn request_authentication_has_priority() {
let bad_aa_api_token = "DUMMY";
Expand Down

0 comments on commit 7d82ece

Please sign in to comment.