Skip to content

Commit

Permalink
fix(ai-proxy): remove redundant `or nil` fallbacks on model and tuning parameters (in Lua, `x and y or nil` is equivalent to `x and y`)
Browse files Browse the repository at this point in the history
  • Loading branch information
tysoekong committed Jun 18, 2024
1 parent 1814a34 commit 3b9a95b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 37 deletions.
25 changes: 10 additions & 15 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ local transformers_to = {
return nil, nil, err
end

messages.temperature = (model.options and model.options.temperature) or request_table.temperature or nil
messages.max_tokens = (model.options and model.options.max_tokens) or request_table.max_tokens or nil
messages.temperature = (model.options and model.options.temperature) or request_table.temperature
messages.max_tokens = (model.options and model.options.max_tokens) or request_table.max_tokens
messages.model = model.name or request_table.model
messages.stream = request_table.stream or false -- explicitly set this if nil

Expand All @@ -110,8 +110,8 @@ local transformers_to = {
return nil, nil, err
end

prompt.temperature = (model.options and model.options.temperature) or request_table.temperature or nil
prompt.max_tokens_to_sample = (model.options and model.options.max_tokens) or request_table.max_tokens or nil
prompt.temperature = (model.options and model.options.temperature) or request_table.temperature
prompt.max_tokens_to_sample = (model.options and model.options.max_tokens) or request_table.max_tokens
prompt.model = model.name or request_table.model
prompt.stream = request_table.stream or false -- explicitly set this if nil

Expand Down Expand Up @@ -151,11 +151,9 @@ local function start_to_event(event_data, model_info)

local metadata = {
prompt_tokens = meta.usage
and meta.usage.input_tokens
or nil,
and meta.usage.input_tokens,
completion_tokens = meta.usage
and meta.usage.output_tokens
or nil,
and meta.usage.output_tokens,
model = meta.model,
stop_reason = meta.stop_reason,
stop_sequence = meta.stop_sequence,
Expand Down Expand Up @@ -208,14 +206,11 @@ local function handle_stream_event(event_t, model_info, route_type)
and event_data.usage then
return nil, nil, {
prompt_tokens = nil,
completion_tokens = event_data.usage.output_tokens
or nil,
completion_tokens = event_data.usage.output_tokens,
stop_reason = event_data.delta
and event_data.delta.stop_reason
or nil,
and event_data.delta.stop_reason,
stop_sequence = event_data.delta
and event_data.delta.stop_sequence
or nil,
and event_data.delta.stop_sequence,
}
else
return nil, "message_delta is missing the metadata block", nil
Expand Down Expand Up @@ -266,7 +261,7 @@ local transformers_from = {
prompt_tokens = usage.input_tokens,
completion_tokens = usage.output_tokens,
total_tokens = usage.input_tokens and usage.output_tokens and
usage.input_tokens + usage.output_tokens or nil,
usage.input_tokens + usage.output_tokens,
}

else
Expand Down
33 changes: 13 additions & 20 deletions kong/llm/drivers/cohere.lua
Original file line number Diff line number Diff line change
Expand Up @@ -219,18 +219,15 @@ local transformers_from = {
local stats = {
completion_tokens = response_table.meta
and response_table.meta.billed_units
and response_table.meta.billed_units.output_tokens
or nil,
and response_table.meta.billed_units.output_tokens,

prompt_tokens = response_table.meta
and response_table.meta.billed_units
and response_table.meta.billed_units.input_tokens
or nil,
and response_table.meta.billed_units.input_tokens,

total_tokens = response_table.meta
and response_table.meta.billed_units
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens)
or nil,
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens),
}
messages.usage = stats

Expand All @@ -252,26 +249,23 @@ local transformers_from = {
local stats = {
completion_tokens = response_table.meta
and response_table.meta.billed_units
and response_table.meta.billed_units.output_tokens
or nil,
and response_table.meta.billed_units.output_tokens,

prompt_tokens = response_table.meta
and response_table.meta.billed_units
and response_table.meta.billed_units.input_tokens
or nil,
and response_table.meta.billed_units.input_tokens,

total_tokens = response_table.meta
and response_table.meta.billed_units
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens)
or nil,
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens),
}
messages.usage = stats

else -- probably a fault
return nil, "'text' or 'generations' missing from cohere response body"

end

return cjson.encode(messages)
end,

Expand Down Expand Up @@ -299,11 +293,10 @@ local transformers_from = {
prompt.id = response_table.id

local stats = {
completion_tokens = response_table.meta and response_table.meta.billed_units.output_tokens or nil,
prompt_tokens = response_table.meta and response_table.meta.billed_units.input_tokens or nil,
completion_tokens = response_table.meta and response_table.meta.billed_units.output_tokens,
prompt_tokens = response_table.meta and response_table.meta.billed_units.input_tokens,
total_tokens = response_table.meta
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens)
or nil,
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens),
}
prompt.usage = stats

Expand All @@ -323,9 +316,9 @@ local transformers_from = {
prompt.id = response_table.generation_id

local stats = {
completion_tokens = response_table.token_count and response_table.token_count.response_tokens or nil,
prompt_tokens = response_table.token_count and response_table.token_count.prompt_tokens or nil,
total_tokens = response_table.token_count and response_table.token_count.total_tokens or nil,
completion_tokens = response_table.token_count and response_table.token_count.response_tokens,
prompt_tokens = response_table.token_count and response_table.token_count.prompt_tokens,
total_tokens = response_table.token_count and response_table.token_count.total_tokens,
}
prompt.usage = stats

Expand Down
4 changes: 2 additions & 2 deletions kong/plugins/ai-proxy/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,15 @@ function _M:access(conf)

-- copy from the user request if present
if (not multipart) and (not conf_m.model.name) and (request_table.model) then
if request_table.model ~= cjson.null then
if type(request_table.model) == "string" then
conf_m.model.name = request_table.model
end
elseif multipart then
conf_m.model.name = "NOT_SPECIFIED"
end

-- check that the user isn't trying to override the plugin conf model in the request body
if request_table and request_table.model and type(request_table.model) == "string" then
if request_table and request_table.model and type(request_table.model) == "string" and request_table.model ~= "" then
if request_table.model ~= conf_m.model.name then
return bad_request("cannot use own model - must be: " .. conf_m.model.name)
end
Expand Down

0 comments on commit 3b9a95b

Please sign in to comment.