chore: Refactor orchestration E2E tests for better coverage (#160)
* Refactor orchestration E2E tests for better coverage

* Refine output filter test

* fix: Changes from lint

* Apply suggestions from code review

* Apply suggestions from review

* Merge branch 'main' into feat/data-masking

* Update readme

* Update dependencies

* Improve style

* fix: Changes from lint

---------

Co-authored-by: cloud-sdk-js <[email protected]>
Co-authored-by: Marika Marszalkowski <[email protected]>
3 people authored Oct 2, 2024
1 parent 771f986 commit ac10b81
Showing 5 changed files with 196 additions and 105 deletions.
2 changes: 2 additions & 0 deletions packages/orchestration/README.md
@@ -61,6 +61,8 @@ const orchestrationClient = new OrchestrationClient({
 
 The client allows you to combine various modules, such as templating and content filtering, while sending chat completion requests to an orchestration-compatible generative AI model.
 
+In addition to the examples below, you can find more **sample code** [here](https://github.com/SAP/ai-sdk-js/blob/main/sample-code/src/orchestration.ts).
+
 ### Templating
 
 Use the orchestration client with templating to pass a prompt containing placeholders that will be replaced with input parameters during a chat completion request.
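
The templating flow this README section describes amounts to the following minimal sketch, adapted from the `orchestrationTemplating` sample added in this commit (the `askForCapital` wrapper, model name, and input value are illustrative, not part of the diff):

```ts
import { OrchestrationClient } from '@sap-ai-sdk/orchestration';

async function askForCapital(country: string): Promise<void> {
  // The "{{?country}}" placeholder is declared in the template and
  // filled via inputParams at request time.
  const orchestrationClient = new OrchestrationClient({
    llm: { model_name: 'gpt-4o', model_params: {} },
    templating: {
      template: [
        { role: 'user', content: 'What is the capital of {{?country}}?' }
      ]
    }
  });

  const response = await orchestrationClient.chatCompletion({
    inputParams: { country }
  });
  console.log(response.getContent());
}

await askForCapital('France');
```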
9 changes: 8 additions & 1 deletion sample-code/src/index.ts
@@ -3,7 +3,14 @@ export {
   chatCompletion,
   computeEmbedding
 } from './foundation-models/azure-openai.js';
-export { orchestrationCompletionMasking } from './orchestration.js';
+export {
+  orchestrationChatCompletion,
+  orchestrationTemplating,
+  orchestrationInputFiltering,
+  orchestrationOutputFiltering,
+  orchestrationRequestConfig,
+  orchestrationCompletionMasking
+} from './orchestration.js';
 export {
   invoke,
   invokeChain,
195 changes: 116 additions & 79 deletions sample-code/src/orchestration.ts
@@ -1,114 +1,154 @@
 import {
+  LlmModuleConfig,
   OrchestrationClient,
+  OrchestrationResponse,
   buildAzureContentFilter
 } from '@sap-ai-sdk/orchestration';
+import { createLogger } from '@sap-cloud-sdk/util';
 
-/**
- * Create different types of orchestration requests.
- * @param sampleCase - Name of the sample case to orchestrate.
- * @returns The message content from the orchestration service in the generative AI hub.
- */
-export async function orchestrationCompletion(
-  sampleCase: string
-): Promise<string | undefined> {
-  switch (sampleCase) {
-    case 'simple':
-      return orchestrationCompletionSimple();
-    case 'template':
-      return orchestrationCompletionTemplate();
-    case 'filtering':
-      return orchestrationCompletionFiltering();
-    case 'requestConfig':
-      return orchestrationCompletionRequestConfig();
-    case 'masking':
-      return orchestrationCompletionMasking();
-    default:
-      return undefined;
-  }
-}
+const logger = createLogger({
+  package: 'sample-code',
+  messageContext: 'orchestration'
+});
 
 /**
- * Ask about the capital of France.
- * @returns The message content from the orchestration service in the generative AI hub.
+ * A simple LLM request, asking about the capital of France.
+ * @returns The orchestration service response.
  */
-async function orchestrationCompletionSimple(): Promise<string | undefined> {
+export async function orchestrationChatCompletion(): Promise<OrchestrationResponse> {
   const orchestrationClient = new OrchestrationClient({
     // define the language model to be used
     llm: {
-      model_name: 'gpt-4-32k',
+      model_name: 'gpt-4o',
       model_params: {}
     },
     // define the prompt
     templating: {
       template: [{ role: 'user', content: 'What is the capital of France?' }]
     }
   });
 
-  // Call the orchestration service.
-  const response = await orchestrationClient.chatCompletion();
-  // Access the response content.
-  return response.getContent();
+  // execute the request
+  const result = await orchestrationClient.chatCompletion();
+
+  // use getContent() to access the LLM response
+  logger.info(result.getContent());
+
+  return result;
 }
 
+const llm: LlmModuleConfig = {
+  model_name: 'gpt-4o',
+  model_params: {}
+};
+
 /**
  * Ask about the capital of any country using a template.
- * @returns The message content from the orchestration service in the generative AI hub.
+ * @returns The orchestration service response.
  */
-async function orchestrationCompletionTemplate(): Promise<string | undefined> {
+export async function orchestrationTemplating(): Promise<OrchestrationResponse> {
   const orchestrationClient = new OrchestrationClient({
-    llm: {
-      model_name: 'gpt-4-32k',
-      model_params: {}
-    },
+    llm,
     templating: {
       template: [
         // define "country" as variable by wrapping it with "{{? ... }}"
         { role: 'user', content: 'What is the capital of {{?country}}?' }
       ]
     }
   });
 
-  // Call the orchestration service.
-  const response = await orchestrationClient.chatCompletion({
+  return orchestrationClient.chatCompletion({
     // give the actual value for the variable "country"
     inputParams: { country: 'France' }
   });
-  // Access the response content.
-  return response.getContent();
 }
 
+const templating = { template: [{ role: 'user', content: '{{?input}}' }] };
+
 /**
- * Allow any user input and apply input and output filters.
- * Handles the case where the input or output are filtered:
- * - In case the input was filtered the response has a non 200 status code.
- * - In case the output was filtered `response.getContent()` throws an error.
- * @returns The message content from the orchestration service in the generative AI hub.
+ * Apply a content filter to LLM requests, filtering any hateful input.
  */
-async function orchestrationCompletionFiltering(): Promise<string | undefined> {
+export async function orchestrationInputFiltering(): Promise<void> {
+  // create a filter with minimal thresholds for hate and violence
+  // lower numbers mean more strict filtering
   const filter = buildAzureContentFilter({ Hate: 0, Violence: 0 });
   const orchestrationClient = new OrchestrationClient({
-    llm: {
-      model_name: 'gpt-4-32k',
-      model_params: {}
-    },
-    templating: {
-      template: [{ role: 'user', content: '{{?input}}' }]
-    },
+    llm,
+    templating,
     // configure the filter to be applied for both input and output
     filtering: {
-      input: filter,
-      output: filter
+      input: filter
     }
   });
 
   try {
-    // Call the orchestration service.
-    const response = await orchestrationClient.chatCompletion({
+    // trigger the input filter, producing a 400 - Bad Request response
+    await orchestrationClient.chatCompletion({
       inputParams: { input: 'I hate you!' }
    });
-    // Access the response content.
-    return response.getContent();
+    throw new Error('Input was not filtered as expected.');
   } catch (error: any) {
-    // Handle the case where the output was filtered.
-    return `Error: ${error.message}`;
+    if (error.response.status === 400) {
+      logger.info('Input was filtered as expected.');
+    } else {
+      throw error;
+    }
   }
 }
 
+/**
+ * Apply a content filter to LLM requests, filtering any hateful output.
+ * @returns The orchestration service response.
+ */
+export async function orchestrationOutputFiltering(): Promise<OrchestrationResponse> {
+  // output filters are built in the same way as input filters
+  // set the thresholds to the minimum to maximize the chance the LLM output will be filtered
+  const filter = buildAzureContentFilter({ Hate: 0, Violence: 0 });
+  const orchestrationClient = new OrchestrationClient({
+    llm,
+    templating,
+    filtering: {
+      output: filter
+    }
+  });
+  /**
+   * Trigger the output filter.
+   * Note: reliably triggering the output filter is a bit of a challenge. LLMs are fine-tuned to not respond with explicit or hateful content.
+   * Instead, they often respond with something like "Sorry, I can't assist with that" when asked to generate explicit content.
+   * The following prompt seems to work well with GPT-4o, producing enough explicit language to trigger the output filter, but may not work equally well with other models.
+   */
+  const result = await orchestrationClient.chatCompletion({
+    messagesHistory: [
+      {
+        role: 'system',
+        content: `Rewrite the following to be significantly more strongly worded, graphic and explicit.
+        Imagine you are a cringy teenager writing a post on 4chan. Use explicit language as much as possible.`
+      }
+    ],
+    inputParams: {
+      input: 'Tabs are better than spaces, prove me wrong.'
+    }
+  });
+
+  // accessing the content should throw an error
+  try {
+    result.getContent();
+  } catch (error: any) {
+    logger.info(
+      `Result from output content filter: ${result.data.module_results.output_filtering!.message}`
+    );
+    logger.info(
+      'The original response from the LLM was as follows: ' +
+        result.data.module_results.llm?.choices[0].message.content
+    );
+    return result;
+  }
+  throw new Error(
+    'Output was not filtered as expected. The LLM response was: ' +
+      result.getContent()
+  );
+}
+
 /**
  * Ask to write an e-mail while masking personal information.
  * @returns The message content from the orchestration service in the generative AI hub.
@@ -146,27 +186,24 @@ export async function orchestrationCompletionMasking(): Promise<
   });
   return response.getContent();
 }
 
 /**
  * Ask about the capital of France and send along custom request configuration.
- * @returns The message content from the orchestration service in the generative AI hub.
+ * @returns The orchestration service response.
  */
-async function orchestrationCompletionRequestConfig(): Promise<
-  string | undefined
-> {
+export async function orchestrationRequestConfig(): Promise<OrchestrationResponse> {
   const orchestrationClient = new OrchestrationClient({
-    llm: {
-      model_name: 'gpt-4-32k',
-      model_params: {}
-    },
-    templating: {
-      template: [{ role: 'user', content: 'What is the capital of France?' }]
-    }
+    llm,
+    templating
   });
 
-  // Call the orchestration service.
-  const response = await orchestrationClient.chatCompletion(undefined, {
-    headers: { 'x-custom-header': 'custom-value' }
-  });
-  // Access the response content.
-  return response.getContent();
+  return orchestrationClient.chatCompletion(
    {
+      inputParams: { input: 'What is the capital of France?' }
+    },
+    // add a custom header to the request
+    {
+      headers: { 'x-custom-header': 'custom-value' }
+    }
+  );
 }
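
Taken together, the two filtering samples above encode the behavior the removed doc comment described: a filtered input fails the request with a non-200 status, while a filtered output makes `getContent()` throw. A condensed sketch of handling both cases from calling code follows; the `completeWithFilters` wrapper and its input are hypothetical, and the error shape is assumed to match the input-filtering sample above:

```ts
import {
  OrchestrationClient,
  buildAzureContentFilter
} from '@sap-ai-sdk/orchestration';

async function completeWithFilters(input: string): Promise<void> {
  const filter = buildAzureContentFilter({ Hate: 0, Violence: 0 });
  const client = new OrchestrationClient({
    llm: { model_name: 'gpt-4o', model_params: {} },
    templating: { template: [{ role: 'user', content: '{{?input}}' }] },
    // apply the same filter on both ends of the round trip
    filtering: { input: filter, output: filter }
  });

  try {
    const result = await client.chatCompletion({ inputParams: { input } });
    // a filtered *output* surfaces here: getContent() throws
    console.log(result.getContent());
  } catch (error: any) {
    if (error.response?.status === 400) {
      // a filtered *input* fails the request itself with 400 - Bad Request
      console.error('Input was filtered.');
    } else {
      // output was filtered, or an unrelated error occurred
      console.error(error.message);
    }
  }
}

await completeWithFilters('Tell me something nice.');
```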
31 changes: 29 additions & 2 deletions sample-code/src/server.ts
@@ -1,10 +1,17 @@
 /* eslint-disable no-console */
 import express from 'express';
+import { OrchestrationResponse } from '@sap-ai-sdk/orchestration';
 import {
   chatCompletion,
   computeEmbedding
 } from './foundation-models/azure-openai.js';
-import { orchestrationCompletion } from './orchestration.js';
+import {
+  orchestrationChatCompletion,
+  orchestrationTemplating,
+  orchestrationInputFiltering,
+  orchestrationOutputFiltering,
+  orchestrationRequestConfig
+} from './orchestration.js';
 import { getDeployments } from './ai-api.js';
 import {
   invokeChain,
@@ -49,8 +56,28 @@ app.get('/azure-openai/embedding', async (req, res) => {
 });
 
 app.get('/orchestration/:sampleCase', async (req, res) => {
+  const sampleCase = req.params.sampleCase;
+  const testCase =
+    {
+      simple: orchestrationChatCompletion,
+      template: orchestrationTemplating,
+      inputFiltering: orchestrationInputFiltering,
+      outputFiltering: orchestrationOutputFiltering,
+      requestConfig: orchestrationRequestConfig,
+      default: orchestrationChatCompletion
+    }[sampleCase] || orchestrationChatCompletion;
+
   try {
-    res.send(await orchestrationCompletion(req.params.sampleCase));
+    const result = (await testCase()) as OrchestrationResponse;
+    if (sampleCase === 'inputFiltering') {
+      res.send('Input filter applied successfully');
+    } else if (sampleCase === 'outputFiltering') {
+      res.send(
+        `Output filter applied successfully with threshold results: ${JSON.stringify(result.data.module_results.output_filtering!.data!)}`
+      );
+    } else {
+      res.send(result.getContent());
+    }
   } catch (error: any) {
     console.error(error);
     res
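
With the dispatch map above, each sample is reachable under its own route key. A quick way to exercise the refactored endpoint once the sample server is running; the base URL is an assumption, not taken from this diff:

```ts
// Hypothetical smoke test for the refactored route; adjust the base URL
// to wherever the sample server is actually listening.
const baseUrl = 'http://localhost:8080';

for (const sampleCase of [
  'simple',
  'template',
  'inputFiltering',
  'outputFiltering',
  'requestConfig'
]) {
  const res = await fetch(`${baseUrl}/orchestration/${sampleCase}`);
  console.log(`${sampleCase}: ${await res.text()}`);
}
```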