feat: add support for new cohere command r models #118

Open · wants to merge 2 commits into base: main

Changes from 1 commit
@@ -245,6 +245,22 @@ export class BedrockRuntimeServiceExtension implements ServiceExtension {
       if (requestBody.top_p !== undefined) {
         spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TOP_P] = requestBody.top_p;
       }
+    } else if (modelId.includes('cohere.command-r')) {
+      if (requestBody.max_tokens !== undefined) {
+        spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MAX_TOKENS] = requestBody.max_tokens;
+      }
+      if (requestBody.temperature !== undefined) {
+        spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TEMPERATURE] = requestBody.temperature;
+      }
+      if (requestBody.p !== undefined) {
+        spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TOP_P] = requestBody.p;
+      }
+      if (requestBody.message !== undefined) {
+        // NOTE: We approximate the token count since this value is not directly available in the body
+        // According to Bedrock docs they use (total_chars / 6) to approximate token count for pricing.
+        // https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html
+        spanAttributes[AwsSpanProcessingUtil.GEN_AI_USAGE_INPUT_TOKENS] = Math.ceil(requestBody.message.length / 6);
+      }
Comment on lines +258 to +263 (Contributor Author):

According to the docs, this data should be available in the response body. However, when logging the response body in the implementation, the data is not actually there. As a result, I decided to stay with this token-approximation approach.

[Screenshot: 2024-11-06 at 9:38 AM]
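As a rough illustration of the heuristic this branch relies on, a minimal sketch (the helper name is hypothetical, not part of the PR):

// Sketch of the Bedrock pricing heuristic applied above:
// approximate tokens as total characters divided by 6, rounded up.
function approximateTokenCount(text: string): number {
  return Math.ceil(text.length / 6);
}

// Example: the 59-character prompt used in the tests below
// yields Math.ceil(59 / 6) === 10 approximated input tokens.
approximateTokenCount("Describe the purpose of a 'hello world' program in one line"); // 10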

     } else if (modelId.includes('cohere.command')) {
       if (requestBody.max_tokens !== undefined) {
         spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MAX_TOKENS] = requestBody.max_tokens;

Contributor Author:

I'm not sure if we should go ahead and remove support for the old Cohere Command model. According to the docs, EOL should not be until 2025, but we are already getting 404s for this model.
@@ -255,6 +271,9 @@ export class BedrockRuntimeServiceExtension implements ServiceExtension {
       if (requestBody.p !== undefined) {
         spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TOP_P] = requestBody.p;
       }
+      if (requestBody.prompt !== undefined) {
+        spanAttributes[AwsSpanProcessingUtil.GEN_AI_USAGE_INPUT_TOKENS] = Math.ceil(requestBody.prompt.length / 6);
+      }
     } else if (modelId.includes('ai21.jamba')) {
       if (requestBody.max_tokens !== undefined) {
         spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MAX_TOKENS] = requestBody.max_tokens;
@@ -329,13 +348,18 @@ export class BedrockRuntimeServiceExtension implements ServiceExtension {
       if (responseBody.stop_reason !== undefined) {
         span.setAttribute(AwsSpanProcessingUtil.GEN_AI_RESPONSE_FINISH_REASONS, [responseBody.stop_reason]);
       }
-    } else if (currentModelId.includes('cohere.command')) {
-      if (responseBody.prompt !== undefined) {
+    } else if (currentModelId.includes('cohere.command-r')) {
+      console.log('Response Body:', responseBody);
+      if (responseBody.text !== undefined) {
         // NOTE: We approximate the token count since this value is not directly available in the body
         // According to Bedrock docs they use (total_chars / 6) to approximate token count for pricing.
         // https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html
-        span.setAttribute(AwsSpanProcessingUtil.GEN_AI_USAGE_INPUT_TOKENS, Math.ceil(responseBody.prompt.length / 6));
+        span.setAttribute(AwsSpanProcessingUtil.GEN_AI_USAGE_OUTPUT_TOKENS, Math.ceil(responseBody.text.length / 6));
       }
+      if (responseBody.finish_reason !== undefined) {
+        span.setAttribute(AwsSpanProcessingUtil.GEN_AI_RESPONSE_FINISH_REASONS, [responseBody.finish_reason]);
+      }
+    } else if (currentModelId.includes('cohere.command')) {
       if (responseBody.generations?.[0]?.text !== undefined) {
         span.setAttribute(
           AwsSpanProcessingUtil.GEN_AI_USAGE_OUTPUT_TOKENS,

Contributor Author:

The prompt is only available in the JavaScript implementation because of a special data model defined in an upstream OTel package. This makes it possible to approximate the input token usage from the response body.

However, this is not possible in the Java implementation, as there is no special data model wrapping the inputs into the response body. As a result, I decided to move this approximation logic strictly to the request body to keep the implementation logic consistent between languages.

[Screenshot: 2024-11-06 at 9:42 AM]
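For readers comparing the two branches, the shape of the deserialized Command R response that the JS instrumentation observes can be sketched as below. This shape is an assumption inferred from the mocked response in the tests that follow, not a documented API:

// Assumed shape of the Command R response body as seen in JS.
// The upstream OTel instrumentation's data model folds the original
// request (commandInput) back into the response, which is why `prompt`
// was reachable here but not in the Java implementation.
interface CohereCommandRResponseBody {
  text?: string;          // generated completion
  finish_reason?: string; // e.g. 'COMPLETE'
  prompt?: string;        // echoed input (JS-only convenience)
  request?: {
    commandInput?: { modelId: string };
  };
}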
@@ -517,6 +517,60 @@ describe('BedrockRuntime', () => {
     expect(invokeModelSpan.kind).toBe(SpanKind.CLIENT);
   });

+  it('Add Cohere Command R model attributes to span', async () => {
+    const modelId: string = 'cohere.command-r-v1:0';
+    const prompt: string = "Describe the purpose of a 'hello world' program in one line";
+    const nativeRequest: any = {
+      message: prompt,
+      max_tokens: 512,
+      temperature: 0.5,
+      p: 0.65,
+    };
+    const mockRequestBody: string = JSON.stringify(nativeRequest);
+    const mockResponseBody: any = {
+      finish_reason: 'COMPLETE',
+      text: 'test-generation-text',
+      prompt: prompt,
+      request: {
+        commandInput: {
+          modelId: modelId,
+        },
+      },
+    };
+
+    nock(`https://bedrock-runtime.${region}.amazonaws.com`)
+      .post(`/model/${encodeURIComponent(modelId)}/invoke`)
+      .reply(200, mockResponseBody);
+
+    await bedrock
+      .invokeModel({
+        modelId: modelId,
+        body: mockRequestBody,
+      })
+      .catch((err: any) => {
+        console.log('error', err);
+      });
+
+    const testSpans: ReadableSpan[] = getTestSpans();
+    const invokeModelSpans: ReadableSpan[] = testSpans.filter((s: ReadableSpan) => {
+      return s.name === 'BedrockRuntime.InvokeModel';
+    });
+    expect(invokeModelSpans.length).toBe(1);
+    const invokeModelSpan = invokeModelSpans[0];
+    expect(invokeModelSpan.attributes[AWS_ATTRIBUTE_KEYS.AWS_BEDROCK_AGENT_ID]).toBeUndefined();
+    expect(invokeModelSpan.attributes[AWS_ATTRIBUTE_KEYS.AWS_BEDROCK_KNOWLEDGE_BASE_ID]).toBeUndefined();
+    expect(invokeModelSpan.attributes[AWS_ATTRIBUTE_KEYS.AWS_BEDROCK_DATA_SOURCE_ID]).toBeUndefined();
+    expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_SYSTEM]).toBe('aws_bedrock');
+    expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MODEL]).toBe(modelId);
+    expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MAX_TOKENS]).toBe(512);
+    expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TEMPERATURE]).toBe(0.5);
+    expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TOP_P]).toBe(0.65);
+    expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_USAGE_INPUT_TOKENS]).toBe(10);
+    expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_USAGE_OUTPUT_TOKENS]).toBe(4);
+    expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_RESPONSE_FINISH_REASONS]).toEqual(['COMPLETE']);
+    expect(invokeModelSpan.kind).toBe(SpanKind.CLIENT);
+  });
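Note on the expected usage values: the prompt above is 59 characters and 'test-generation-text' is 20 characters, so the chars/6 heuristic gives Math.ceil(59 / 6) === 10 input tokens and Math.ceil(20 / 6) === 4 output tokens.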

   it('Add Meta Llama model attributes to span', async () => {
     const modelId: string = 'meta.llama2-13b-chat-v1';
     const prompt: string = 'Describe the purpose of an interpreter program in one line.';