Skip to content

Commit

Permalink
Add OCI GenAI Cohere Chat integration
Browse files Browse the repository at this point in the history
Adds support for Oracle Cloud Infrastructure (OCI) Generative AI Cohere chat models
to expand Spring AI's cloud provider capabilities. This allows developers to use
OCI's managed Cohere models through both dedicated and on-demand serving modes.

The integration provides auto-configuration for simple setup while allowing full
customization of model parameters through OCICohereChatOptions. Teams can now
use OCI's Cohere models alongside other providers in Spring AI applications.

This change complements the existing OCI embedding support, offering a complete
set of GenAI capabilities for Oracle Cloud users.

Signed-off-by: Anders Swanson <[email protected]>
  • Loading branch information
anders-swanson authored and Mark Pollack committed Nov 6, 2024
1 parent 1cdec7b commit fb65ed0
Show file tree
Hide file tree
Showing 17 changed files with 1,212 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
import java.util.concurrent.atomic.AtomicInteger;

import com.oracle.bmc.generativeaiinference.GenerativeAiInference;
import com.oracle.bmc.generativeaiinference.model.DedicatedServingMode;
import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails;
import com.oracle.bmc.generativeaiinference.model.EmbedTextResult;
import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode;
import com.oracle.bmc.generativeaiinference.model.ServingMode;
import com.oracle.bmc.generativeaiinference.requests.EmbedTextRequest;
import io.micrometer.observation.ObservationRegistry;
Expand Down Expand Up @@ -128,15 +126,6 @@ private EmbeddingResponse embedAllWithContext(List<EmbedTextRequest> embedTextRe
return embeddingResponse;
}

/**
 * Resolves the OCI {@link ServingMode} to use for an embedding request.
 * @param embeddingOptions options carrying the serving mode name and model id
 * @return a dedicated serving mode (model value used as endpoint id) or an
 * on-demand serving mode (model value used as model id)
 * @throws IllegalArgumentException if the serving mode is not 'dedicated' or 'on-demand'
 */
private ServingMode servingMode(OCIEmbeddingOptions embeddingOptions) {
	switch (embeddingOptions.getServingMode()) {
		case "dedicated":
			return DedicatedServingMode.builder().endpointId(embeddingOptions.getModel()).build();
		case "on-demand":
			return OnDemandServingMode.builder().modelId(embeddingOptions.getModel()).build();
		default:
			throw new IllegalArgumentException(
					"unknown serving mode for OCI embedding model: " + embeddingOptions.getServingMode());
	}
}

private List<EmbedTextRequest> createRequests(List<String> inputs, OCIEmbeddingOptions embeddingOptions) {
int size = inputs.size();
List<EmbedTextRequest> requests = new ArrayList<>();
Expand All @@ -148,8 +137,9 @@ private List<EmbedTextRequest> createRequests(List<String> inputs, OCIEmbeddingO
}

private EmbedTextRequest createRequest(List<String> inputs, OCIEmbeddingOptions embeddingOptions) {
ServingMode servingMode = ServingModeHelper.get(this.options.getServingMode(), this.options.getModel());
EmbedTextDetails embedTextDetails = EmbedTextDetails.builder()
.servingMode(servingMode(embeddingOptions))
.servingMode(servingMode)
.compartmentId(embeddingOptions.getCompartment())
.inputs(inputs)
.truncate(Objects.requireNonNullElse(embeddingOptions.getTruncate(), EmbedTextDetails.Truncate.End))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.oci;

import com.oracle.bmc.generativeaiinference.model.DedicatedServingMode;
import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode;
import com.oracle.bmc.generativeaiinference.model.ServingMode;

/**
* Helper class to load the OCI Gen AI
* {@link com.oracle.bmc.generativeaiinference.model.ServingMode}
*
* @author Anders Swanson
*/
public final class ServingModeHelper {

	// Utility class: no instances.
	private ServingModeHelper() {
	}

	/**
	 * Builds the OCI Gen AI {@link ServingMode} matching the given mode name.
	 * @param servingMode the serving mode name; supported values are 'dedicated'
	 * and 'on-demand'
	 * @param model the model identifier: used as the endpoint id for dedicated
	 * serving, or the model id for on-demand serving
	 * @return the configured {@link ServingMode}
	 * @throws IllegalArgumentException if {@code servingMode} is not a supported value
	 */
	public static ServingMode get(String servingMode, String model) {
		switch (servingMode) {
			case "dedicated":
				return DedicatedServingMode.builder().endpointId(model).build();
			case "on-demand":
				return OnDemandServingMode.builder().modelId(model).build();
			default:
				throw new IllegalArgumentException(String.format(
						"Unknown serving mode for OCI Gen AI: %s. Supported options are 'dedicated' and 'on-demand'",
						servingMode));
		}
	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.oci.cohere;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import com.oracle.bmc.generativeaiinference.GenerativeAiInference;
import com.oracle.bmc.generativeaiinference.model.BaseChatRequest;
import com.oracle.bmc.generativeaiinference.model.BaseChatResponse;
import com.oracle.bmc.generativeaiinference.model.ChatDetails;
import com.oracle.bmc.generativeaiinference.model.CohereChatBotMessage;
import com.oracle.bmc.generativeaiinference.model.CohereChatRequest;
import com.oracle.bmc.generativeaiinference.model.CohereChatResponse;
import com.oracle.bmc.generativeaiinference.model.CohereMessage;
import com.oracle.bmc.generativeaiinference.model.CohereSystemMessage;
import com.oracle.bmc.generativeaiinference.model.CohereToolCall;
import com.oracle.bmc.generativeaiinference.model.CohereToolMessage;
import com.oracle.bmc.generativeaiinference.model.CohereToolResult;
import com.oracle.bmc.generativeaiinference.model.CohereUserMessage;
import com.oracle.bmc.generativeaiinference.model.ServingMode;
import com.oracle.bmc.generativeaiinference.requests.ChatRequest;
import io.micrometer.observation.ObservationRegistry;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.oci.ServingModeHelper;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* {@link ChatModel} implementation that uses the OCI GenAI Chat API.
*
* @author Anders Swanson
* @since 1.0.0
*/
public class OCICohereChatModel implements ChatModel {

	// Convention used for observations when no custom convention is set.
	private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

	// Fallback sampling temperature applied when the merged options do not set one
	// (see newChatRequest).
	private static final Double DEFAULT_TEMPERATURE = 0.7;

	/**
	 * The {@link GenerativeAiInference} client used to interact with OCI GenAI service.
	 */
	private final GenerativeAiInference genAi;

	/**
	 * The configuration information for a chat completions request.
	 */
	private final OCICohereChatOptions defaultOptions;

	// May be null when the two-argument constructor is used.
	// NOTE(review): verify the observation API tolerates a null registry.
	private final ObservationRegistry observationRegistry;

	/**
	 * Conventions to use for generating observations.
	 */
	private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

	/**
	 * Creates a chat model without an observation registry.
	 * @param genAi the OCI GenAI inference client; must not be null
	 * @param options the default chat options; must not be null
	 */
	public OCICohereChatModel(GenerativeAiInference genAi, OCICohereChatOptions options) {
		this(genAi, options, null);
	}

	/**
	 * Creates a chat model.
	 * @param genAi the OCI GenAI inference client; must not be null
	 * @param options the default chat options; must not be null
	 * @param observationRegistry registry for chat observations; may be null
	 */
	public OCICohereChatModel(GenerativeAiInference genAi, OCICohereChatOptions options,
			ObservationRegistry observationRegistry) {
		Assert.notNull(genAi, "com.oracle.bmc.generativeaiinference.GenerativeAiInference must not be null");
		Assert.notNull(options, "OCIChatOptions must not be null");

		this.genAi = genAi;
		this.defaultOptions = options;
		this.observationRegistry = observationRegistry;
	}

	/**
	 * Sends the prompt to the OCI GenAI Cohere chat API, wrapping the call in a
	 * chat-model observation.
	 * @param prompt the prompt to send
	 * @return the model's chat response
	 */
	@Override
	public ChatResponse call(Prompt prompt) {
		// Prompt-level options take precedence over the model defaults for observation
		// reporting only; the actual request merge happens in doChatRequest.
		ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
			.prompt(prompt)
			.provider(AiProvider.OCI_GENAI.value())
			.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
			.build();

		return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
			.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
					this.observationRegistry)
			.observe(() -> {
				ChatResponse chatResponse = doChatRequest(prompt);
				observationContext.setResponse(chatResponse);
				return chatResponse;
			});
	}

	/**
	 * Returns a copy of the default options (via
	 * {@code OCICohereChatOptions.fromOptions}), so callers cannot mutate this
	 * model's configuration.
	 */
	@Override
	public ChatOptions getDefaultOptions() {
		return OCICohereChatOptions.fromOptions(this.defaultOptions);
	}

	/**
	 * Use the provided convention for reporting observation data
	 * @param observationConvention The provided convention
	 */
	public void setObservationConvention(ChatModelObservationConvention observationConvention) {
		Assert.notNull(observationConvention, "observationConvention cannot be null");
		this.observationConvention = observationConvention;
	}

	// Merges prompt options with the defaults, validates them, performs the chat
	// call, and assembles the ChatResponse with model/compartment metadata.
	private ChatResponse doChatRequest(Prompt prompt) {
		OCICohereChatOptions options = mergeOptions(prompt.getOptions(), this.defaultOptions);
		validateChatOptions(options);

		ChatResponseMetadata metadata = ChatResponseMetadata.builder()
			.withModel(options.getModel())
			.withKeyValue("compartment", options.getCompartment())
			.build();
		return new ChatResponse(getGenerations(prompt, options), metadata);

	}

	// Combines per-prompt options with the model defaults. Non-OCICohereChatOptions
	// prompt options are ignored and the defaults are returned unchanged.
	// NOTE(review): assumes ModelOptionsUtils.merge gives precedence to the first
	// argument (the per-prompt override) — confirm against its contract.
	private OCICohereChatOptions mergeOptions(ChatOptions chatOptions, OCICohereChatOptions defaultOptions) {
		if (chatOptions instanceof OCICohereChatOptions override) {
			OCICohereChatOptions dynamicOptions = ModelOptionsUtils.merge(override, defaultOptions,
					OCICohereChatOptions.class);

			if (dynamicOptions != null) {
				return dynamicOptions;
			}
		}
		return defaultOptions;
	}

	// Fails fast with IllegalArgumentException when any of the three required
	// settings (model, compartment, serving mode) is missing or blank.
	private void validateChatOptions(OCICohereChatOptions options) {
		if (!StringUtils.hasText(options.getModel())) {
			throw new IllegalArgumentException("Model is not set!");
		}
		if (!StringUtils.hasText(options.getCompartment())) {
			throw new IllegalArgumentException("Compartment is not set!");
		}
		if (!StringUtils.hasText(options.getServingMode())) {
			throw new IllegalArgumentException("ServingMode is not set!");
		}
	}

	// Executes the chat call against the OCI client and converts the result into
	// Spring AI Generations.
	private List<Generation> getGenerations(Prompt prompt, OCICohereChatOptions options) {
		com.oracle.bmc.generativeaiinference.responses.ChatResponse cr = genAi
			.chat(toCohereChatRequest(prompt, options));
		return toGenerations(cr, options);

	}

	// Maps the OCI response to a single-element Generation list carrying the text
	// and finish reason. Only CohereChatResponse payloads are supported; any other
	// BaseChatResponse subtype is a programming error.
	private List<Generation> toGenerations(com.oracle.bmc.generativeaiinference.responses.ChatResponse ociChatResponse,
			OCICohereChatOptions options) {
		BaseChatResponse cr = ociChatResponse.getChatResult().getChatResponse();
		if (cr instanceof CohereChatResponse resp) {
			List<Generation> generations = new ArrayList<>();
			ChatGenerationMetadata metadata = ChatGenerationMetadata.from(resp.getFinishReason().getValue(), null);
			AssistantMessage message = new AssistantMessage(resp.getText(), Map.of());
			generations.add(new Generation(message, metadata));
			return generations;
		}
		throw new IllegalStateException(String.format("Unexpected chat response type: %s", cr.getClass().getName()));
	}

	// Builds the OCI ChatRequest from the prompt.
	// NOTE(review): the FIRST prompt instruction is sent as the current chat
	// message and the remaining instructions become chat history — confirm this
	// ordering matches callers' expectations (many chat APIs treat the LAST
	// message as current).
	private ChatRequest toCohereChatRequest(Prompt prompt, OCICohereChatOptions options) {
		List<Message> messages = prompt.getInstructions();
		Message message = messages.get(0);
		List<CohereMessage> chatHistory = getCohereMessages(messages);
		return newChatRequest(options, message, chatHistory);
	}

	// Converts messages[1..n] into Cohere chat-history entries by message type.
	// TOOL messages that are not ToolResponseMessage instances are silently
	// skipped.
	private List<CohereMessage> getCohereMessages(List<Message> messages) {
		List<CohereMessage> chatHistory = new ArrayList<>();
		for (int i = 1; i < messages.size(); i++) {
			Message message = messages.get(i);
			switch (message.getMessageType()) {
				case USER -> chatHistory.add(CohereUserMessage.builder().message(message.getContent()).build());
				case ASSISTANT -> chatHistory.add(CohereChatBotMessage.builder().message(message.getContent()).build());
				case SYSTEM -> chatHistory.add(CohereSystemMessage.builder().message(message.getContent()).build());
				case TOOL -> {
					if (message instanceof ToolResponseMessage tm) {
						chatHistory.add(toToolMessage(tm));
					}
				}
			}
		}
		return chatHistory;
	}

	// Converts a Spring AI tool-response message into a Cohere tool message; each
	// response becomes a CohereToolResult whose call carries the tool name and
	// whose outputs carry the response data.
	private CohereToolMessage toToolMessage(ToolResponseMessage tm) {
		List<CohereToolResult> results = tm.getResponses().stream().map(r -> {
			CohereToolCall call = CohereToolCall.builder().name(r.name()).build();
			return CohereToolResult.builder().call(call).outputs(List.of(r.responseData())).build();
		}).toList();
		return CohereToolMessage.builder().toolResults(results).build();
	}

	// Assembles the final OCI ChatRequest: Cohere request parameters from the
	// options (DEFAULT_TEMPERATURE substituted when unset), the serving mode
	// resolved via ServingModeHelper, and the compartment id.
	private ChatRequest newChatRequest(OCICohereChatOptions options, Message message, List<CohereMessage> chatHistory) {
		BaseChatRequest baseChatRequest = CohereChatRequest.builder()
			.frequencyPenalty(options.getFrequencyPenalty())
			.presencePenalty(options.getPresencePenalty())
			.maxTokens(options.getMaxTokens())
			.topK(options.getTopK())
			.topP(options.getTopP())
			.temperature(Objects.requireNonNullElse(options.getTemperature(), DEFAULT_TEMPERATURE))
			.preambleOverride(options.getPreambleOverride())
			.stopSequences(options.getStopSequences())
			.documents(options.getDocuments())
			.tools(options.getTools())
			.chatHistory(chatHistory)
			.message(message.getContent())
			.build();
		ServingMode servingMode = ServingModeHelper.get(options.getServingMode(), options.getModel());
		ChatDetails chatDetails = ChatDetails.builder()
			.compartmentId(options.getCompartment())
			.servingMode(servingMode)
			.chatRequest(baseChatRequest)
			.build();
		return ChatRequest.builder().body$(chatDetails).build();
	}

}
Loading

0 comments on commit fb65ed0

Please sign in to comment.