Raw completion #42

Merged 2 commits on Dec 10, 2024
49 changes: 34 additions & 15 deletions src/completion.rs
@@ -11,6 +11,8 @@ pub struct TaskCompletion<'a> {
pub stopping: Stopping<'a>,
/// Sampling controls how the tokens ("words") are selected for the completion.
pub sampling: Sampling<'a>,
/// Whether to include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
pub special_tokens: bool,
}

impl<'a> TaskCompletion<'a> {
@@ -20,6 +22,7 @@ impl<'a> TaskCompletion<'a> {
prompt: Prompt::from_text(text),
stopping: Stopping::NO_TOKEN_LIMIT,
sampling: Sampling::MOST_LIKELY,
special_tokens: false,
}
}

@@ -32,6 +35,12 @@ impl<'a> TaskCompletion<'a> {
self.stopping.stop_sequences = stop_sequences;
self
}

/// Include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
pub fn with_special_tokens(mut self) -> Self {
self.special_tokens = true;
self
}
}

/// Sampling controls how the tokens ("words") are selected for the completion.
@@ -144,6 +153,13 @@ struct BodyCompletion<'a> {
/// If true, the response will be streamed.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
/// Forces the raw completion of the model to be returned.
/// For some models, the completion that was generated by the model may be optimized and
/// returned in the completion field of the CompletionResponse.
/// The raw completion, if returned, will contain the un-optimized completion.
/// Setting tokens to true or log_probs to any value will also trigger the raw completion to be returned.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub raw_completion: bool,
}

impl<'a> BodyCompletion<'a> {
@@ -158,6 +174,7 @@
top_p: task.sampling.top_p,
completion_bias_inclusion: task.sampling.complete_with_one_of,
stream: false,
raw_completion: task.special_tokens,
}
}
pub fn with_streaming(mut self) -> Self {
@@ -168,22 +185,15 @@

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct ResponseCompletion {
pub model_version: String,
pub completions: Vec<CompletionOutput>,
model_version: String,
completions: Vec<DeserializedCompletion>,
}

impl ResponseCompletion {
/// The best completion in the answer.
pub fn best(&self) -> &CompletionOutput {
self.completions
.first()
.expect("Response is assumed to always have at least one completion")
}

/// Text of the best completion.
pub fn best_text(&self) -> &str {
&self.best().completion
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
struct DeserializedCompletion {
completion: String,
finish_reason: String,
raw_completion: Option<String>,
}

/// Completion and metainformation returned by a completion task
@@ -209,7 +219,16 @@ impl Task for TaskCompletion<'_> {
}

fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
response.completions.pop().unwrap()
let deserialized = response.completions.pop().unwrap();
let completion = if self.special_tokens {
deserialized.raw_completion.unwrap()
} else {
deserialized.completion
};
CompletionOutput {
completion,
finish_reason: deserialized.finish_reason,
}
}
}
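
Taken together, `special_tokens` and `with_special_tokens()` give callers an opt-in for the raw, un-optimized completion. Below is a minimal usage sketch, not part of this diff: it reuses the `inference_url()` and `pharia_ai_token()` helpers and the `llama-3.1-8b-instruct` model from the integration tests further down; everything else is assumed from this PR's own code.

```rust
#[tokio::test]
async fn raw_completion_usage_sketch() {
    // Opt in to the raw completion so that special tokens such as
    // <|python_tag|> are preserved in the returned text.
    let task = TaskCompletion::from_text("An apple a day")
        .with_maximum_tokens(10)
        .with_special_tokens();

    let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
    let output = client
        .output_of(&task.with_model("llama-3.1-8b-instruct"), &How::default())
        .await
        .unwrap();

    // With `with_special_tokens()`, `output.completion` carries the raw
    // completion instead of the optimized one.
    assert!(!output.completion.is_empty());
}
```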

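One note on the wire format: like `stream`, the new `raw_completion` flag is marked with `skip_serializing_if = "std::ops::Not::not"`, so it only appears in the request body when it is `true` and existing requests stay unchanged. Below is a standalone sketch of that serde pattern, using an illustrative stand-in struct rather than the crate's private `BodyCompletion`.

```rust
use serde::Serialize;

// Illustrative stand-in for a request body with an opt-in boolean flag.
#[derive(Serialize)]
struct Body {
    prompt: String,
    // `std::ops::Not::not` returns `true` for `false`, so serde skips the
    // field whenever it is `false` and emits it only when it is `true`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    raw_completion: bool,
}

fn main() {
    let off = Body { prompt: "Hi".to_owned(), raw_completion: false };
    let on = Body { prompt: "Hi".to_owned(), raw_completion: true };
    println!("{}", serde_json::to_string(&off).unwrap()); // {"prompt":"Hi"}
    println!("{}", serde_json::to_string(&on).unwrap());  // {"prompt":"Hi","raw_completion":true}
}
```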
1 change: 1 addition & 0 deletions src/lib.rs
@@ -324,6 +324,7 @@ impl Client {
/// prompt: prompt.clone(),
/// stopping: Stopping::from_maximum_tokens(10),
/// sampling: Sampling::MOST_LIKELY,
/// special_tokens: false,
/// };
/// let response = client.completion(&task, model, &How::default()).await?;
///
1 change: 1 addition & 0 deletions src/prompt.rs
@@ -96,6 +96,7 @@ impl<'a> Modality<'a> {
/// ]),
/// stopping: Stopping::from_maximum_tokens(10),
/// sampling: Sampling::MOST_LIKELY,
/// special_tokens: false,
/// };
/// // Execute
/// let model = "luminous-base";
28 changes: 28 additions & 0 deletions tests/integration.rs
@@ -61,6 +61,28 @@ async fn completion_with_luminous_base() {
assert!(!response.completion.is_empty())
}

#[tokio::test]
async fn raw_completion_includes_python_tag() {
// When
let task = TaskCompletion::from_text(
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython<|eot_id|><|start_header_id|>user<|end_header_id|>

Write code to check if number is prime, use that to see if the number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
)
.with_maximum_tokens(30)
.with_special_tokens();

let model = "llama-3.1-8b-instruct";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
.unwrap();
assert!(response.completion.trim().starts_with("<|python_tag|>"));
}

#[tokio::test]
async fn request_authentication_has_priority() {
let bad_pharia_ai_token = "DUMMY";
@@ -201,6 +223,7 @@ async fn complete_structured_prompt() {
stop_sequences: &stop_sequences[..],
},
sampling: Sampling::MOST_LIKELY,
special_tokens: false,
};
let model = "luminous-base";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -230,6 +253,7 @@ async fn maximum_tokens_none_request() {
prompt: Prompt::from_text(prompt),
stopping,
sampling: Sampling::MOST_LIKELY,
special_tokens: false,
};
let model = "luminous-base";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -363,6 +387,7 @@ async fn describe_image_starting_from_a_path() {
]),
stopping: Stopping::from_maximum_tokens(10),
sampling: Sampling::MOST_LIKELY,
special_tokens: false,
};
let model = "luminous-base";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -392,6 +417,7 @@ async fn describe_image_starting_from_a_dyn_image() {
]),
stopping: Stopping::from_maximum_tokens(10),
sampling: Sampling::MOST_LIKELY,
special_tokens: false,
};
let model = "luminous-base";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -418,6 +444,7 @@ async fn only_answer_with_specific_animal() {
complete_with_one_of: &[" dog"],
..Default::default()
},
special_tokens: false,
};
let model = "luminous-base";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -444,6 +471,7 @@ async fn answer_should_continue() {
complete_with_one_of: &[" Says.", " Art.", " Weekend."],
..Default::default()
},
special_tokens: false,
};
let model = "luminous-base";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();