feat!: add option to ask for special tokens in completion response
moldhouse authored Dec 10, 2024
1 parent edd5590 commit 1dbcb77
Showing 4 changed files with 64 additions and 15 deletions.
src/completion.rs: 49 changes (34 additions & 15 deletions)
@@ -11,6 +11,8 @@ pub struct TaskCompletion<'a> {
pub stopping: Stopping<'a>,
/// Sampling controls how the tokens ("words") are selected for the completion.
pub sampling: Sampling<'a>,
/// Whether to include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
pub special_tokens: bool,
}

impl<'a> TaskCompletion<'a> {
@@ -20,6 +22,7 @@ impl<'a> TaskCompletion<'a> {
prompt: Prompt::from_text(text),
stopping: Stopping::NO_TOKEN_LIMIT,
sampling: Sampling::MOST_LIKELY,
special_tokens: false,
}
}

@@ -32,6 +35,12 @@ impl<'a> TaskCompletion<'a> {
self.stopping.stop_sequences = stop_sequences;
self
}

/// Include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
pub fn with_special_tokens(mut self) -> Self {
self.special_tokens = true;
self
}
}
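For reference, the new flag composes with the existing builder methods. A minimal usage sketch (the crate name `aleph_alpha_client` is an assumption; `from_text` and `with_maximum_tokens` appear elsewhere in this diff):

```rust
// Sketch only: crate name assumed to be `aleph_alpha_client`.
use aleph_alpha_client::TaskCompletion;

fn build_task() -> TaskCompletion<'static> {
    TaskCompletion::from_text("An apple a day")
        .with_maximum_tokens(30) // helper used in the tests below
        .with_special_tokens() // new: sets `special_tokens` to true
}
```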

/// Sampling controls how the tokens ("words") are selected for the completion.
@@ -144,6 +153,13 @@ struct BodyCompletion<'a> {
/// If true, the response will be streamed.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
/// Forces the raw completion of the model to be returned.
/// For some models, the completion generated by the model is optimized before being
/// returned in the `completion` field of the `CompletionResponse`.
/// The raw completion, if requested, contains the un-optimized text.
/// Setting `tokens` to true or `log_probs` to any value also triggers the raw
/// completion to be returned.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub raw_completion: bool,
}
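The `skip_serializing_if = "std::ops::Not::not"` attribute keeps the wire format backwards compatible: when the flag is `false`, the field is omitted from the request JSON entirely. A self-contained sketch of that serde idiom (struct and field set reduced for illustration):

```rust
use serde::Serialize;

#[derive(Serialize)]
struct Body {
    model: &'static str,
    // `Not::not` accepts `&bool` (via `impl Not for &bool`) and returns `true`
    // for `false`, so serde skips the field unless it is set to `true`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    raw_completion: bool,
}

fn main() {
    let off = Body { model: "llama-3.1-8b-instruct", raw_completion: false };
    // Prints {"model":"llama-3.1-8b-instruct"} (no `raw_completion` key).
    println!("{}", serde_json::to_string(&off).unwrap());

    let on = Body { model: "llama-3.1-8b-instruct", raw_completion: true };
    // Prints {"model":"llama-3.1-8b-instruct","raw_completion":true}
    println!("{}", serde_json::to_string(&on).unwrap());
}
```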

impl<'a> BodyCompletion<'a> {
@@ -158,6 +174,7 @@ impl<'a> BodyCompletion<'a> {
top_p: task.sampling.top_p,
completion_bias_inclusion: task.sampling.complete_with_one_of,
stream: false,
raw_completion: task.special_tokens,
}
}
pub fn with_streaming(mut self) -> Self {
@@ -168,22 +185,15 @@ impl<'a> BodyCompletion<'a> {

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct ResponseCompletion {
pub model_version: String,
pub completions: Vec<CompletionOutput>,
model_version: String,
completions: Vec<DeserializedCompletion>,
}

impl ResponseCompletion {
/// The best completion in the answer.
pub fn best(&self) -> &CompletionOutput {
self.completions
.first()
.expect("Response is assumed to always have at least one completion")
}

/// Text of the best completion.
pub fn best_text(&self) -> &str {
&self.best().completion
}
#[derive(Deserialize, Debug, PartialEq, Eq)]
struct DeserializedCompletion {
completion: String,
finish_reason: String,
raw_completion: Option<String>,
}
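`raw_completion` is an `Option` because the API only includes the field when it has been requested. A quick sketch of how the two response shapes deserialize (the struct is re-declared locally and the payload values are invented for illustration):

```rust
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct DeserializedCompletion {
    completion: String,
    finish_reason: String,
    raw_completion: Option<String>,
}

fn main() {
    // Without `raw_completion` in the payload, the field deserializes to `None`.
    let plain = r#"{"completion":" prime","finish_reason":"maximum_tokens"}"#;
    let c: DeserializedCompletion = serde_json::from_str(plain).unwrap();
    assert!(c.raw_completion.is_none());

    // With it, `Some(..)` carries the un-optimized text, special tokens included.
    let raw =
        r#"{"completion":"done","finish_reason":"stop","raw_completion":"<|python_tag|>done"}"#;
    let c: DeserializedCompletion = serde_json::from_str(raw).unwrap();
    assert!(c.raw_completion.unwrap().starts_with("<|python_tag|>"));
}
```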

/// Completion and metainformation returned by a completion task
@@ -209,7 +219,16 @@ impl Task for TaskCompletion<'_> {
}

fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
response.completions.pop().unwrap()
let deserialized = response.completions.pop().unwrap();
let completion = if self.special_tokens {
deserialized.raw_completion.unwrap()
} else {
deserialized.completion
};
CompletionOutput {
completion,
finish_reason: deserialized.finish_reason,
}
}
}
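Note that the `unwrap` on `deserialized.raw_completion` assumes the API always returns a raw completion when `raw_completion: true` was sent in the body; since `special_tokens` is the only thing that sets that flag, the branch cannot be reached without the raw completion having been requested.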

src/lib.rs: 1 change (1 addition & 0 deletions)
@@ -324,6 +324,7 @@ impl Client {
/// prompt: prompt.clone(),
/// stopping: Stopping::from_maximum_tokens(10),
/// sampling: Sampling::MOST_LIKELY,
/// special_tokens: false,
/// };
/// let response = client.completion(&task, model, &How::default()).await?;
///
src/prompt.rs: 1 change (1 addition & 0 deletions)
@@ -96,6 +96,7 @@ impl<'a> Modality<'a> {
/// ]),
/// stopping: Stopping::from_maximum_tokens(10),
/// sampling: Sampling::MOST_LIKELY,
/// special_tokens: false,
/// };
/// // Execute
/// let model = "luminous-base";
tests/integration.rs: 28 changes (28 additions & 0 deletions)
@@ -61,6 +61,28 @@ async fn completion_with_luminous_base() {
assert!(!response.completion.is_empty())
}

#[tokio::test]
async fn raw_completion_includes_python_tag() {
// When
let task = TaskCompletion::from_text(
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython<|eot_id|><|start_header_id|>user<|end_header_id|>
Write code to check if number is prime, use that to see if the number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
)
.with_maximum_tokens(30)
.with_special_tokens();

let model = "llama-3.1-8b-instruct";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
.unwrap();
assert!(response.completion.trim().starts_with("<|python_tag|>"));
}

#[tokio::test]
async fn request_authentication_has_priority() {
let bad_pharia_ai_token = "DUMMY";
@@ -201,6 +223,7 @@ async fn complete_structured_prompt() {
stop_sequences: &stop_sequences[..],
},
sampling: Sampling::MOST_LIKELY,
special_tokens: false,
};
let model = "luminous-base";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -230,6 +253,7 @@ async fn maximum_tokens_none_request() {
prompt: Prompt::from_text(prompt),
stopping,
sampling: Sampling::MOST_LIKELY,
special_tokens: false,
};
let model = "luminous-base";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -363,6 +387,7 @@ async fn describe_image_starting_from_a_path() {
]),
stopping: Stopping::from_maximum_tokens(10),
sampling: Sampling::MOST_LIKELY,
special_tokens: false,
};
let model = "luminous-base";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -392,6 +417,7 @@ async fn describe_image_starting_from_a_dyn_image() {
]),
stopping: Stopping::from_maximum_tokens(10),
sampling: Sampling::MOST_LIKELY,
special_tokens: false,
};
let model = "luminous-base";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -418,6 +444,7 @@ async fn only_answer_with_specific_animal() {
complete_with_one_of: &[" dog"],
..Default::default()
},
special_tokens: false,
};
let model = "luminous-base";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -444,6 +471,7 @@ async fn answer_should_continue() {
complete_with_one_of: &[" Says.", " Art.", " Weekend."],
..Default::default()
},
special_tokens: false,
};
let model = "luminous-base";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
