# OpenAI connector: send default model for "other" OpenAI provider (elastic#204934)

## Summary

Part of elastic#204116

When `model` is not present in the request payload, fall back to the default model
specified in the connector configuration.

We were already doing this for the OpenAI-OpenAI provider, but not for the
"Other"-OpenAI provider.

### Screenshots — because I downloaded ollama just for this issue

<img width="950" alt="Screenshot 2024-12-19 at 13 53 48"
src="https://github.com/user-attachments/assets/4a6e4b35-a0c5-46e5-9372-677e99d070f8"
/>

<img width="769" alt="Screenshot 2024-12-19 at 13 54 54"
src="https://github.com/user-attachments/assets/a0a5a12a-ea1e-42b7-8fa1-6531bef5ae6c"
/>
pgayvallet authored Dec 23, 2024
1 parent 96264d2 commit d4bc9be
Showing 4 changed files with 58 additions and 5 deletions.
```diff
@@ -112,5 +112,42 @@ describe('Other (OpenAI Compatible Service) Utils', () => {
       const sanitizedBodyString = getRequestWithStreamOption(bodyString, false);
       expect(sanitizedBodyString).toEqual(bodyString);
     });
+
+    it('sets model parameter if specified and not present in the body', () => {
+      const body = {
+        messages: [
+          {
+            role: 'user',
+            content: 'This is a test',
+          },
+        ],
+      };
+
+      const sanitizedBodyString = getRequestWithStreamOption(JSON.stringify(body), true, 'llama-3');
+      expect(JSON.parse(sanitizedBodyString)).toEqual({
+        messages: [{ content: 'This is a test', role: 'user' }],
+        model: 'llama-3',
+        stream: true,
+      });
+    });
+
+    it('does not override model parameter if present in the body', () => {
+      const body = {
+        model: 'mistral',
+        messages: [
+          {
+            role: 'user',
+            content: 'This is a test',
+          },
+        ],
+      };
+
+      const sanitizedBodyString = getRequestWithStreamOption(JSON.stringify(body), true, 'llama-3');
+      expect(JSON.parse(sanitizedBodyString)).toEqual({
+        messages: [{ content: 'This is a test', role: 'user' }],
+        model: 'mistral',
+        stream: true,
+      });
+    });
   });
 });
```
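
The two new tests cover both branches of the `defaultModel` handling: the default is injected only when the payload omits `model`, and an explicit `model` in the payload always takes precedence.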
```diff
@@ -23,13 +23,19 @@ export const sanitizeRequest = (body: string): string => {
  * The stream parameter is accepted in the ChatCompletion
  * API and the Completion API only
  */
-export const getRequestWithStreamOption = (body: string, stream: boolean): string => {
+export const getRequestWithStreamOption = (
+  body: string,
+  stream: boolean,
+  defaultModel?: string
+): string => {
   try {
     const jsonBody = JSON.parse(body);
     if (jsonBody) {
       jsonBody.stream = stream;
     }
-
+    if (defaultModel && !jsonBody.model) {
+      jsonBody.model = defaultModel;
+    }
     return JSON.stringify(jsonBody);
   } catch (err) {
     // swallow the error
```
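
Since `defaultModel` is an optional parameter, existing call sites that pass only `body` and `stream` keep compiling and behaving as before.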
```diff
@@ -111,9 +111,19 @@ describe('Utils', () => {
   });

   it('calls other_openai_utils getRequestWithStreamOption when provider is Other OpenAi', () => {
-    getRequestWithStreamOption(OpenAiProviderType.Other, OPENAI_CHAT_URL, bodyString, true);
+    getRequestWithStreamOption(
+      OpenAiProviderType.Other,
+      OPENAI_CHAT_URL,
+      bodyString,
+      true,
+      'default-model'
+    );

-    expect(mockOtherOpenAiGetRequestWithStreamOption).toHaveBeenCalledWith(bodyString, true);
+    expect(mockOtherOpenAiGetRequestWithStreamOption).toHaveBeenCalledWith(
+      bodyString,
+      true,
+      'default-model'
+    );
     expect(mockOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled();
     expect(mockAzureAiGetRequestWithStreamOption).not.toHaveBeenCalled();
   });
```
```diff
@@ -75,7 +75,7 @@ export function getRequestWithStreamOption(
     case OpenAiProviderType.AzureAi:
       return azureAiGetRequestWithStreamOption(url, body, stream);
     case OpenAiProviderType.Other:
-      return otherOpenAiGetRequestWithStreamOption(body, stream);
+      return otherOpenAiGetRequestWithStreamOption(body, stream, defaultModel);
     default:
       return body;
   }
```
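
Only the `Other` provider branch needs the new argument here; per the summary above, the native OpenAI branch already applied the connector's default model, and this commit brings the "Other" provider to parity.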
