Commit bfb49d5

camel case inferencing options

Signed-off-by: karthik2804 <[email protected]>
karthik2804 committed Sep 4, 2023
1 parent b6bafbf commit bfb49d5
Showing 3 changed files with 22 additions and 14 deletions.
crates/spin-js-engine/src/lib.rs (13 changes: 6 additions & 7 deletions)

@@ -1045,7 +1045,7 @@ fn postgres_query(context: &Context, _this: &Value, args: &[Value]) -> Result<Va
 }
 
 fn map_inferencing_model_name(name: &str) -> llm::InferencingModel {
-    match name{
+    match name {
         "llama2-chat" => llm::InferencingModel::Llama2Chat,
         "codellama-instruct" => llm::InferencingModel::CodellamaInstruct,
         _ => llm::InferencingModel::Other(name),
@@ -1057,7 +1057,7 @@ fn llm_run_with_defaults(context: &Context, _this: &Value, args: &[Value]) -> Re
         [model, prompt] => {
             let model = deserialize_helper(model)?;
             let prompt = deserialize_helper(prompt)?;
-            let llm_model= map_inferencing_model_name(model.as_str());
+            let llm_model = map_inferencing_model_name(model.as_str());
             let inference_result = llm::infer(llm_model, &prompt);
             match inference_result {
                 Ok(val) => {
@@ -1074,7 +1074,7 @@ fn llm_run_with_defaults(context: &Context, _this: &Value, args: &[Value]) -> Re
                     )?;
                     ret.set_property("usage", usage)?;
                     Ok(ret)
-                },
+                }
                 Err(err) => Err(anyhow!(err)),
             }
         }
@@ -1097,7 +1097,7 @@ fn llm_inference_with_options(context: &Context, _this: &Value, args: &[Value])
         [model, prompt, options] => {
             let model = deserialize_helper(model)?;
             let prompt = deserialize_helper(prompt)?;
-            let llm_model= map_inferencing_model_name(model.as_str());
+            let llm_model = map_inferencing_model_name(model.as_str());
             let options_deserializer = &mut Deserializer::from(options.clone());
             let options = InferencingOption::deserialize(options_deserializer)?;
             let llm_options = llm::InferencingParams {
@@ -1108,8 +1108,7 @@ fn llm_inference_with_options(context: &Context, _this: &Value, args: &[Value])
                 top_k: options.top_k,
                 top_p: options.top_p,
             };
-            let inference_result =
-                llm::infer_with_options(llm_model, &prompt, llm_options);
+            let inference_result = llm::infer_with_options(llm_model, &prompt, llm_options);
             match inference_result {
                 Ok(val) => {
                     let ret = context.object_value()?;
@@ -1134,7 +1133,7 @@ fn llm_inference_with_options(context: &Context, _this: &Value, args: &[Value])
 }
 
 fn map_embedding_model_name(name: &str) -> llm::EmbeddingModel {
-    match name{
+    match name {
         "all-minilm-l6-v2" => llm::EmbeddingModel::AllMiniLmL6V2,
         _ => llm::EmbeddingModel::Other(name),
     }
Binary file added spin-sdk.tar.gz
Binary file not shown.
spin-sdk/src/modules/spinSdk.ts (23 changes: 16 additions & 7 deletions)

@@ -39,6 +39,15 @@ interface RdbmsReturn {
 }
 
 interface InferencingOptions {
+    maxTokens?: number,
+    repeatPenalty?: number,
+    repeatPenaltyLastNTokenCount?: number,
+    temperature?: number,
+    topK?: number,
+    topP?: number
+}
+
+interface InternalInferencingOptions {
     max_tokens?: number,
     repeat_penalty?: number,
     repeat_penalty_last_n_token_count?: number,
@@ -112,7 +121,7 @@ interface SpinSdk {
     }
     llm: {
         infer: (model: InferencingModels | string, prompt: string) => InferenceResult
-        inferWithOptions: (model: InferencingModels | string, prompt: string, options: InferencingOptions) => InferenceResult
+        inferWithOptions: (model: InferencingModels | string, prompt: string, options: InternalInferencingOptions) => InferenceResult
         generateEmbeddings: (model: EmbeddingModels | string, sentences: Array<string>) => EmbeddingResult
     }
 }
@@ -216,13 +225,13 @@ const Llm = {
         if (!options) {
             return __internal__.spin_sdk.llm.infer(model, prompt)
         }
-        let inference_options: InferencingOptions = {
-            max_tokens: options.max_tokens || 100,
-            repeat_penalty: options.repeat_penalty || 1.1,
-            repeat_penalty_last_n_token_count: options.repeat_penalty_last_n_token_count || 64,
+        let inference_options: InternalInferencingOptions = {
+            max_tokens: options.maxTokens || 100,
+            repeat_penalty: options.repeatPenalty || 1.1,
+            repeat_penalty_last_n_token_count: options.repeatPenaltyLastNTokenCount || 64,
             temperature: options.temperature || 0.8,
-            top_k: options.top_k || 40,
-            top_p: options.top_p || 0.9
+            top_k: options.topK || 40,
+            top_p: options.topP || 0.9
         }
         return __internal__.spin_sdk.llm.inferWithOptions(model, prompt, inference_options)
     },
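For context, a minimal usage sketch of the surface this commit produces: callers now pass camelCase option names, and the Llm wrapper above translates them into the snake_case InternalInferencingOptions fields that the engine-facing inferWithOptions still expects. The import path and the wrapper method name are assumptions for illustration; the diff confirms only the Llm object, the camelCase option names, and their defaults.

// Hypothetical caller code. The package name "@fermyon/spin-sdk" and the
// signature Llm.infer(model, prompt, options?) are assumptions; the
// "llama2-chat" model string and the camelCase option names below come
// from the diff itself.
import { Llm } from "@fermyon/spin-sdk"

const result = Llm.infer("llama2-chat", "Tell me a joke", {
    maxTokens: 128,   // was max_tokens before this commit
    topK: 40,         // was top_k
    topP: 0.9         // was top_p
})

Any option left out falls back to the defaults visible in the last hunk (100 tokens, repeat penalty 1.1, temperature 0.8, and so on), because each field is read with a || fallback before being forwarded to the engine.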
