From fb65ed01bdbf0a9cdd083272100c888582a102ca Mon Sep 17 00:00:00 2001 From: Anders Swanson Date: Wed, 30 Oct 2024 08:42:59 -0700 Subject: [PATCH] Add OCI GenAI Cohere Chat integration Adds Oracle Cloud Infrastructure (OCI) Generative AI's Cohere chat model support to expand Spring AI's cloud provider capabilities. This allows developers to use OCI's managed Cohere models through both dedicated and on-demand serving modes. The integration provides auto-configuration for simple setup while allowing full customization of model parameters through OCICohereChatOptions. Teams can now use OCI's Cohere models alongside other providers in Spring AI applications. This change complements the existing OCI embedding support, offering a complete set of GenAI capabilities for Oracle Cloud users. Signed-off-by: Anders Swanson --- .../ai/oci/OCIEmbeddingModel.java | 14 +- .../ai/oci/ServingModeHelper.java | 52 +++ .../ai/oci/cohere/OCICohereChatModel.java | 249 +++++++++++++ .../ai/oci/cohere/OCICohereChatOptions.java | 346 ++++++++++++++++++ .../ai/oci/BaseEmbeddingModelTest.java | 43 +-- .../ai/oci/BaseOCIGenAITest.java | 61 +++ .../ai/oci/OCIEmbeddingModelIT.java | 4 +- .../ai/oci/cohere/OCICohereChatModelIT.java | 59 +++ pom.xml | 2 +- .../src/main/antora/modules/ROOT/nav.adoc | 3 +- .../ROOT/pages/api/chat/comparison.adoc | 36 +- .../pages/api/chat/oci-genai/cohere-chat.adoc | 208 +++++++++++ spring-ai-spring-boot-autoconfigure/pom.xml | 7 + .../genai/OCICohereChatModelProperties.java | 59 +++ .../oci/genai/OCIGenAiAutoConfiguration.java | 20 +- .../genai/OCIGenAIAutoConfigurationTest.java | 91 +++++ .../genai/OCIGenAiAutoConfigurationIT.java | 30 +- 17 files changed, 1212 insertions(+), 72 deletions(-) create mode 100644 models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/ServingModeHelper.java create mode 100644 models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java create mode 100644 
models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java create mode 100644 models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseOCIGenAITest.java create mode 100644 models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatModelIT.java create mode 100644 spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/oci-genai/cohere-chat.adoc create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCICohereChatModelProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAIAutoConfigurationTest.java diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java index e3658a226a..c8da46f622 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java @@ -22,10 +22,8 @@ import java.util.concurrent.atomic.AtomicInteger; import com.oracle.bmc.generativeaiinference.GenerativeAiInference; -import com.oracle.bmc.generativeaiinference.model.DedicatedServingMode; import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails; import com.oracle.bmc.generativeaiinference.model.EmbedTextResult; -import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode; import com.oracle.bmc.generativeaiinference.model.ServingMode; import com.oracle.bmc.generativeaiinference.requests.EmbedTextRequest; import io.micrometer.observation.ObservationRegistry; @@ -128,15 +126,6 @@ private EmbeddingResponse embedAllWithContext(List embedTextRe return embeddingResponse; } - private ServingMode servingMode(OCIEmbeddingOptions embeddingOptions) { - return switch 
(embeddingOptions.getServingMode()) { - case "dedicated" -> DedicatedServingMode.builder().endpointId(embeddingOptions.getModel()).build(); - case "on-demand" -> OnDemandServingMode.builder().modelId(embeddingOptions.getModel()).build(); - default -> throw new IllegalArgumentException( - "unknown serving mode for OCI embedding model: " + embeddingOptions.getServingMode()); - }; - } - private List createRequests(List inputs, OCIEmbeddingOptions embeddingOptions) { int size = inputs.size(); List requests = new ArrayList<>(); @@ -148,8 +137,9 @@ private List createRequests(List inputs, OCIEmbeddingO } private EmbedTextRequest createRequest(List inputs, OCIEmbeddingOptions embeddingOptions) { + ServingMode servingMode = ServingModeHelper.get(embeddingOptions.getServingMode(), embeddingOptions.getModel()); EmbedTextDetails embedTextDetails = EmbedTextDetails.builder() - .servingMode(servingMode(embeddingOptions)) + .servingMode(servingMode) .compartmentId(embeddingOptions.getCompartment()) .inputs(inputs) .truncate(Objects.requireNonNullElse(embeddingOptions.getTruncate(), EmbedTextDetails.Truncate.End)) diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/ServingModeHelper.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/ServingModeHelper.java new file mode 100644 index 0000000000..22d6a91d6e --- /dev/null +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/ServingModeHelper.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.oci; + +import com.oracle.bmc.generativeaiinference.model.DedicatedServingMode; +import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode; +import com.oracle.bmc.generativeaiinference.model.ServingMode; + +/** + * Helper class to load the OCI Gen AI + * {@link com.oracle.bmc.generativeaiinference.model.ServingMode} + * + * @author Anders Swanson + */ +public final class ServingModeHelper { + + private ServingModeHelper() { + } + + /** + * Retrieves a specific type of ServingMode based on the provided serving mode string. + * @param servingMode The serving mode as a string. Supported options are 'dedicated' + * and 'on-demand'. + * @param model The model identifier to be used with the serving mode. + * @return A ServingMode instance configured according to the provided parameters. + * @throws IllegalArgumentException If the specified serving mode is not supported. + */ + public static ServingMode get(String servingMode, String model) { + return switch (servingMode) { + case "dedicated" -> DedicatedServingMode.builder().endpointId(model).build(); + case "on-demand" -> OnDemandServingMode.builder().modelId(model).build(); + default -> throw new IllegalArgumentException(String.format( + "Unknown serving mode for OCI Gen AI: %s. 
Supported options are 'dedicated' and 'on-demand'", + servingMode)); + }; + } + +} diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java new file mode 100644 index 0000000000..5bd7756f88 --- /dev/null +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java @@ -0,0 +1,249 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.oci.cohere; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import com.oracle.bmc.generativeaiinference.GenerativeAiInference; +import com.oracle.bmc.generativeaiinference.model.BaseChatRequest; +import com.oracle.bmc.generativeaiinference.model.BaseChatResponse; +import com.oracle.bmc.generativeaiinference.model.ChatDetails; +import com.oracle.bmc.generativeaiinference.model.CohereChatBotMessage; +import com.oracle.bmc.generativeaiinference.model.CohereChatRequest; +import com.oracle.bmc.generativeaiinference.model.CohereChatResponse; +import com.oracle.bmc.generativeaiinference.model.CohereMessage; +import com.oracle.bmc.generativeaiinference.model.CohereSystemMessage; +import com.oracle.bmc.generativeaiinference.model.CohereToolCall; +import com.oracle.bmc.generativeaiinference.model.CohereToolMessage; +import com.oracle.bmc.generativeaiinference.model.CohereToolResult; +import com.oracle.bmc.generativeaiinference.model.CohereUserMessage; +import com.oracle.bmc.generativeaiinference.model.ServingMode; +import com.oracle.bmc.generativeaiinference.requests.ChatRequest; +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.observation.ChatModelObservationContext; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import 
org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.oci.ServingModeHelper; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * {@link ChatModel} implementation that uses the OCI GenAI Chat API. + * + * @author Anders Swanson + * @since 1.0.0 + */ +public class OCICohereChatModel implements ChatModel { + + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + + private static final Double DEFAULT_TEMPERATURE = 0.7; + + /** + * The {@link GenerativeAiInference} client used to interact with OCI GenAI service. + */ + private final GenerativeAiInference genAi; + + /** + * The configuration information for a chat completions request. + */ + private final OCICohereChatOptions defaultOptions; + + private final ObservationRegistry observationRegistry; + + /** + * Conventions to use for generating observations. 
+ */ + private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + public OCICohereChatModel(GenerativeAiInference genAi, OCICohereChatOptions options) { + this(genAi, options, null); + } + + public OCICohereChatModel(GenerativeAiInference genAi, OCICohereChatOptions options, + ObservationRegistry observationRegistry) { + Assert.notNull(genAi, "com.oracle.bmc.generativeaiinference.GenerativeAiInference must not be null"); + Assert.notNull(options, "OCIChatOptions must not be null"); + + this.genAi = genAi; + this.defaultOptions = options; + this.observationRegistry = observationRegistry; + } + + @Override + public ChatResponse call(Prompt prompt) { + ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(AiProvider.OCI_GENAI.value()) + .requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions) + .build(); + + return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + ChatResponse chatResponse = doChatRequest(prompt); + observationContext.setResponse(chatResponse); + return chatResponse; + }); + } + + @Override + public ChatOptions getDefaultOptions() { + return OCICohereChatOptions.fromOptions(this.defaultOptions); + } + + /** + * Use the provided convention for reporting observation data + * @param observationConvention The provided convention + */ + public void setObservationConvention(ChatModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + + private ChatResponse doChatRequest(Prompt prompt) { + OCICohereChatOptions options = mergeOptions(prompt.getOptions(), this.defaultOptions); + validateChatOptions(options); + + ChatResponseMetadata metadata 
= ChatResponseMetadata.builder() + .withModel(options.getModel()) + .withKeyValue("compartment", options.getCompartment()) + .build(); + return new ChatResponse(getGenerations(prompt, options), metadata); + + } + + private OCICohereChatOptions mergeOptions(ChatOptions chatOptions, OCICohereChatOptions defaultOptions) { + if (chatOptions instanceof OCICohereChatOptions override) { + OCICohereChatOptions dynamicOptions = ModelOptionsUtils.merge(override, defaultOptions, + OCICohereChatOptions.class); + + if (dynamicOptions != null) { + return dynamicOptions; + } + } + return defaultOptions; + } + + private void validateChatOptions(OCICohereChatOptions options) { + if (!StringUtils.hasText(options.getModel())) { + throw new IllegalArgumentException("Model is not set!"); + } + if (!StringUtils.hasText(options.getCompartment())) { + throw new IllegalArgumentException("Compartment is not set!"); + } + if (!StringUtils.hasText(options.getServingMode())) { + throw new IllegalArgumentException("ServingMode is not set!"); + } + } + + private List getGenerations(Prompt prompt, OCICohereChatOptions options) { + com.oracle.bmc.generativeaiinference.responses.ChatResponse cr = genAi + .chat(toCohereChatRequest(prompt, options)); + return toGenerations(cr, options); + + } + + private List toGenerations(com.oracle.bmc.generativeaiinference.responses.ChatResponse ociChatResponse, + OCICohereChatOptions options) { + BaseChatResponse cr = ociChatResponse.getChatResult().getChatResponse(); + if (cr instanceof CohereChatResponse resp) { + List generations = new ArrayList<>(); + ChatGenerationMetadata metadata = ChatGenerationMetadata.from(resp.getFinishReason().getValue(), null); + AssistantMessage message = new AssistantMessage(resp.getText(), Map.of()); + generations.add(new Generation(message, metadata)); + return generations; + } + throw new IllegalStateException(String.format("Unexpected chat response type: %s", cr.getClass().getName())); + } + + private ChatRequest 
toCohereChatRequest(Prompt prompt, OCICohereChatOptions options) { + List messages = prompt.getInstructions(); + Message message = messages.get(0); + List chatHistory = getCohereMessages(messages); + return newChatRequest(options, message, chatHistory); + } + + private List getCohereMessages(List messages) { + List chatHistory = new ArrayList<>(); + for (int i = 1; i < messages.size(); i++) { + Message message = messages.get(i); + switch (message.getMessageType()) { + case USER -> chatHistory.add(CohereUserMessage.builder().message(message.getContent()).build()); + case ASSISTANT -> chatHistory.add(CohereChatBotMessage.builder().message(message.getContent()).build()); + case SYSTEM -> chatHistory.add(CohereSystemMessage.builder().message(message.getContent()).build()); + case TOOL -> { + if (message instanceof ToolResponseMessage tm) { + chatHistory.add(toToolMessage(tm)); + } + } + } + } + return chatHistory; + } + + private CohereToolMessage toToolMessage(ToolResponseMessage tm) { + List results = tm.getResponses().stream().map(r -> { + CohereToolCall call = CohereToolCall.builder().name(r.name()).build(); + return CohereToolResult.builder().call(call).outputs(List.of(r.responseData())).build(); + }).toList(); + return CohereToolMessage.builder().toolResults(results).build(); + } + + private ChatRequest newChatRequest(OCICohereChatOptions options, Message message, List chatHistory) { + BaseChatRequest baseChatRequest = CohereChatRequest.builder() + .frequencyPenalty(options.getFrequencyPenalty()) + .presencePenalty(options.getPresencePenalty()) + .maxTokens(options.getMaxTokens()) + .topK(options.getTopK()) + .topP(options.getTopP()) + .temperature(Objects.requireNonNullElse(options.getTemperature(), DEFAULT_TEMPERATURE)) + .preambleOverride(options.getPreambleOverride()) + .stopSequences(options.getStopSequences()) + .documents(options.getDocuments()) + .tools(options.getTools()) + .chatHistory(chatHistory) + .message(message.getContent()) + .build(); + 
ServingMode servingMode = ServingModeHelper.get(options.getServingMode(), options.getModel()); + ChatDetails chatDetails = ChatDetails.builder() + .compartmentId(options.getCompartment()) + .servingMode(servingMode) + .chatRequest(baseChatRequest) + .build(); + return ChatRequest.builder().body$(chatDetails).build(); + } + +} diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java new file mode 100644 index 0000000000..c41a4ed692 --- /dev/null +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java @@ -0,0 +1,346 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.oci.cohere; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.oracle.bmc.generativeaiinference.model.CohereTool; + +import org.springframework.ai.chat.prompt.ChatOptions; + +/** + * The configuration information for OCI chat requests + * + * @author Anders Swanson + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class OCICohereChatOptions implements ChatOptions { + + @JsonProperty("model") + private String model; + + /** + * The maximum number of tokens to generate per request. 
+ */ + @JsonProperty("maxTokens") + private Integer maxTokens; + + /** + * The OCI Compartment to run chat requests in. + */ + @JsonProperty("compartment") + private String compartment; + + /** + * The serving mode of OCI Gen AI model used. May be "on-demand" or "dedicated". + */ + @JsonProperty("servingMode") + private String servingMode; + + /** + * The optional override to the chat model's prompt preamble. + */ + @JsonProperty("preambleOverride") + private String preambleOverride; + + /** + * The sample temperature, where higher values are more random, and lower values are + * more deterministic. + */ + @JsonProperty("temperature") + private Double temperature; + + /** + * The Top P parameter modifies the probability of tokens sampled. E.g., a value of + * 0.25 means only tokens from the top 25% probability mass will be considered. + */ + @JsonProperty("topP") + private Double topP; + + /** + * The Top K parameter limits the number of potential tokens considered at each step + * of text generation. E.g., a value of 5 means only the top 5 most probable tokens + * will be considered during each step of text generation. + */ + @JsonProperty("topK") + private Integer topK; + + /** + * The frequency penalty assigns a penalty to repeated tokens depending on how many + * times it has already appeared in the prompt or output. Higher values will reduce + * repeated tokens and outputs will be more random. + */ + @JsonProperty("frequencyPenalty") + private Double frequencyPenalty; + + /** + * The presence penalty assigns a penalty to each token when it appears in the output + * to encourage generating outputs with tokens that haven't been used. + */ + @JsonProperty("presencePenalty") + private Double presencePenalty; + + /** + * A collection of textual sequences that will end completions generation. + */ + @JsonProperty("stop") + private List stop; + + /** + * Documents for chat context. 
+ */ + @JsonProperty("documents") + private List documents; + + /** + * Tools for the chatbot. + */ + @JsonProperty("tools") + private List tools; + + public static OCICohereChatOptions fromOptions(OCICohereChatOptions fromOptions) { + return builder().withModel(fromOptions.model) + .withMaxTokens(fromOptions.maxTokens) + .withCompartment(fromOptions.compartment) + .withServingMode(fromOptions.servingMode) + .withPreambleOverride(fromOptions.preambleOverride) + .withTemperature(fromOptions.temperature) + .withTopP(fromOptions.topP) + .withTopK(fromOptions.topK) + .withStop(fromOptions.stop) + .withFrequencyPenalty(fromOptions.frequencyPenalty) + .withPresencePenalty(fromOptions.presencePenalty) + .withDocuments(fromOptions.documents) + .withTools(fromOptions.tools) + .build(); + } + + public static Builder builder() { + return new Builder(); + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + public String getPreambleOverride() { + return this.preambleOverride; + } + + public void setPreambleOverride(String preambleOverride) { + this.preambleOverride = preambleOverride; + } + + public String getServingMode() { + return this.servingMode; + } + + public void setServingMode(String servingMode) { + this.servingMode = servingMode; + } + + public String getCompartment() { + return this.compartment; + } + + public void setCompartment(String compartment) { + this.compartment = compartment; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public void setModel(String model) { + this.model = model; + } + + public List getStop() { + return this.stop; 
+ } + + public void setStop(List stop) { + this.stop = stop; + } + + public List getDocuments() { + return this.documents; + } + + public void setDocuments(List documents) { + this.documents = documents; + } + + public List getTools() { + return this.tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + /* + * ChatModel overrides. + */ + + @Override + public String getModel() { + return this.model; + } + + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + @Override + public List getStopSequences() { + return this.stop; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + @Override + public Integer getTopK() { + return this.topK; + } + + @Override + public Double getTopP() { + return this.topP; + } + + @Override + public ChatOptions copy() { + return fromOptions(this); + } + + public static class Builder { + + protected OCICohereChatOptions chatOptions; + + public Builder() { + this.chatOptions = new OCICohereChatOptions(); + } + + public Builder(OCICohereChatOptions chatOptions) { + this.chatOptions = chatOptions; + } + + public Builder withModel(String model) { + this.chatOptions.model = model; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.chatOptions.maxTokens = maxTokens; + return this; + } + + public Builder withCompartment(String compartment) { + this.chatOptions.compartment = compartment; + return this; + } + + public Builder withServingMode(String servingMode) { + this.chatOptions.servingMode = servingMode; + return this; + } + + public Builder withPreambleOverride(String preambleOverride) { + this.chatOptions.preambleOverride = preambleOverride; + return this; + } + + public Builder withTemperature(Double temperature) { + this.chatOptions.temperature = 
temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.chatOptions.topP = topP; + return this; + } + + public Builder withTopK(Integer topK) { + this.chatOptions.topK = topK; + return this; + } + + public Builder withFrequencyPenalty(Double frequencyPenalty) { + this.chatOptions.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withPresencePenalty(Double presencePenalty) { + this.chatOptions.presencePenalty = presencePenalty; + return this; + } + + public Builder withStop(List stop) { + this.chatOptions.stop = stop; + return this; + } + + public Builder withDocuments(List documents) { + this.chatOptions.documents = documents; + return this; + } + + public Builder withTools(List tools) { + this.chatOptions.tools = tools; + return this; + } + + public OCICohereChatOptions build() { + return this.chatOptions; + } + + } + +} diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java index 6ca87f2b11..77b71896b4 100644 --- a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java @@ -16,50 +16,23 @@ package org.springframework.ai.oci; -import java.io.IOException; -import java.nio.file.Paths; - -import com.oracle.bmc.Region; -import com.oracle.bmc.auth.ConfigFileAuthenticationDetailsProvider; -import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient; - -public class BaseEmbeddingModelTest { - - public static final String OCI_COMPARTMENT_ID_KEY = "OCI_COMPARTMENT_ID"; +public class BaseEmbeddingModelTest extends BaseOCIGenAITest { public static final String EMBEDDING_MODEL_V2 = "cohere.embed-english-light-v2.0"; public static final String EMBEDDING_MODEL_V3 = "cohere.embed-english-v3.0"; - private static final String 
CONFIG_FILE = Paths.get(System.getProperty("user.home"), ".oci", "config").toString(); - - private static final String PROFILE = "DEFAULT"; - - private static final String REGION = "us-chicago-1"; - - private static final String COMPARTMENT_ID = System.getenv(OCI_COMPARTMENT_ID_KEY); - /** * Create an OCIEmbeddingModel instance using a config file authentication provider. * @return OCIEmbeddingModel instance */ - public OCIEmbeddingModel get() { - try { - ConfigFileAuthenticationDetailsProvider authProvider = new ConfigFileAuthenticationDetailsProvider( - CONFIG_FILE, PROFILE); - GenerativeAiInferenceClient aiClient = GenerativeAiInferenceClient.builder() - .region(Region.valueOf(REGION)) - .build(authProvider); - OCIEmbeddingOptions options = OCIEmbeddingOptions.builder() - .withModel(EMBEDDING_MODEL_V2) - .withCompartment(COMPARTMENT_ID) - .withServingMode("on-demand") - .build(); - return new OCIEmbeddingModel(aiClient, options); - } - catch (IOException e) { - throw new RuntimeException(e); - } + public static OCIEmbeddingModel getEmbeddingModel() { + OCIEmbeddingOptions options = OCIEmbeddingOptions.builder() + .withModel(EMBEDDING_MODEL_V2) + .withCompartment(COMPARTMENT_ID) + .withServingMode("on-demand") + .build(); + return new OCIEmbeddingModel(getGenerativeAIClient(), options); } } diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseOCIGenAITest.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseOCIGenAITest.java new file mode 100644 index 0000000000..358563cb06 --- /dev/null +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseOCIGenAITest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.oci; + +import java.io.IOException; +import java.nio.file.Paths; + +import com.oracle.bmc.Region; +import com.oracle.bmc.auth.ConfigFileAuthenticationDetailsProvider; +import com.oracle.bmc.generativeaiinference.GenerativeAiInference; +import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient; +import org.springframework.ai.oci.cohere.OCICohereChatOptions; + +public class BaseOCIGenAITest { + + public static final String OCI_COMPARTMENT_ID_KEY = "OCI_COMPARTMENT_ID"; + + public static final String OCI_CHAT_MODEL_ID_KEY = "OCI_CHAT_MODEL_ID"; + + public static final String CONFIG_FILE = Paths.get(System.getProperty("user.home"), ".oci", "config").toString(); + + public static final String PROFILE = "DEFAULT"; + + public static final String REGION = "us-chicago-1"; + + public static final String COMPARTMENT_ID = System.getenv(OCI_COMPARTMENT_ID_KEY); + + public static final String CHAT_MODEL_ID = System.getenv(OCI_CHAT_MODEL_ID_KEY); + + public static GenerativeAiInference getGenerativeAIClient() { + try { + ConfigFileAuthenticationDetailsProvider authProvider = new ConfigFileAuthenticationDetailsProvider( + CONFIG_FILE, PROFILE); + return GenerativeAiInferenceClient.builder().region(Region.valueOf(REGION)).build(authProvider); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static OCICohereChatOptions.Builder options() { + return OCICohereChatOptions.builder() + .withModel(CHAT_MODEL_ID) + .withCompartment(COMPARTMENT_ID) + .withServingMode("on-demand"); + } + +} 
diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java index 586fbfddee..94059b5977 100644 --- a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/OCIEmbeddingModelIT.java @@ -29,9 +29,9 @@ @EnabledIfEnvironmentVariable(named = org.springframework.ai.oci.BaseEmbeddingModelTest.OCI_COMPARTMENT_ID_KEY, matches = ".+") -public class OCIEmbeddingModelIT extends BaseEmbeddingModelTest { +class OCIEmbeddingModelIT extends BaseEmbeddingModelTest { - private final OCIEmbeddingModel embeddingModel = get(); + private final OCIEmbeddingModel embeddingModel = getEmbeddingModel(); private final List content = List.of("How many states are in the USA?", "How many states are in India?"); diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatModelIT.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatModelIT.java new file mode 100644 index 0000000000..f12684caab --- /dev/null +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatModelIT.java @@ -0,0 +1,59 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.oci.cohere; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.oci.BaseOCIGenAITest; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.oci.BaseOCIGenAITest.OCI_CHAT_MODEL_ID_KEY; +import static org.springframework.ai.oci.BaseOCIGenAITest.OCI_COMPARTMENT_ID_KEY; + +@EnabledIfEnvironmentVariable(named = OCI_COMPARTMENT_ID_KEY, matches = ".+") +@EnabledIfEnvironmentVariable(named = OCI_CHAT_MODEL_ID_KEY, matches = ".+") +public class OCICohereChatModelIT extends BaseOCIGenAITest { + + private static final ChatModel chatModel = new OCICohereChatModel(getGenerativeAIClient(), options().build()); + + @Test + void chatSimple() { + String response = chatModel.call("Tell me a random fact about Canada"); + assertThat(response).isNotBlank(); + } + + @Test + void chatMessages() { + String response = chatModel.call(new UserMessage("Tell me a random fact about the Arctic Circle"), + new SystemMessage("You are a helpful assistant")); + assertThat(response).isNotBlank(); + } + + @Test + void chatPrompt() { + ChatResponse response = chatModel.call(new Prompt("What's the difference between Top P and Top K sampling?")); + assertThat(response).isNotNull(); + assertThat(response.getMetadata().getModel()).isEqualTo(CHAT_MODEL_ID); + assertThat(response.getResult()).isNotNull(); + assertThat(response.getResult().getOutput().getContent()).isNotBlank(); + } + +} diff --git a/pom.xml b/pom.xml index ccafa626d9..11ae13108f 100644 --- a/pom.xml +++ b/pom.xml @@ -183,7 +183,7 @@ 0.30.0 1.19.2 - 3.46.1 + 3.51.0 26.48.0 1.9.1 2.0.9 diff --git 
a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 60cdd51297..2f2c7d5d4a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -31,7 +31,8 @@ //// **** xref:api/chat/functions/moonshot-chat-functions.adoc[Function Calling] *** xref:api/chat/nvidia-chat.adoc[NVIDIA] *** xref:api/chat/ollama-chat.adoc[Ollama] -**** xref:api/chat/functions/ollama-chat-functions.adoc[Function Calling] +*** OCI Generative AI +**** xref:api/chat/oci-genai/cohere-chat.adoc[Cohere] *** xref:api/chat/openai-chat.adoc[OpenAI] **** xref:api/chat/functions/openai-chat-functions.adoc[Function Calling] *** xref:api/chat/qianfan-chat.adoc[QianFan] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/comparison.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/comparison.adoc index ae96a2c068..d8796c795c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/comparison.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/comparison.adoc @@ -19,24 +19,24 @@ This table compares various Chat Models supported by Spring AI, detailing their |==== | Provider | Multimodality ^| Tools/Functions ^| Streaming ^| Retry ^| Observability ^| Built-in JSON ^| Local ^| OpenAI API Compatible -| xref::api/chat/anthropic-chat.adoc[Anthropic Claude] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/azure-openai-chat.adoc[Azure OpenAI] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] -| xref::api/chat/vertexai-gemini-chat.adoc[Google VertexAI Gemini] | text, image, audio, 
video ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] +| xref::api/chat/anthropic-chat.adoc[Anthropic Claude] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/azure-openai-chat.adoc[Azure OpenAI] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] +| xref::api/chat/vertexai-gemini-chat.adoc[Google VertexAI Gemini] | text, image, audio, video ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] | xref::api/chat/groq-chat.adoc[Groq (OpenAI-proxy)] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] -| xref::api/chat/huggingface.adoc[HuggingFace] | text ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/mistralai-chat.adoc[Mistral AI] | text ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] -| xref::api/chat/minimax-chat.adoc[MiniMax] | text ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| 
image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| +| xref::api/chat/huggingface.adoc[HuggingFace] | text ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/mistralai-chat.adoc[Mistral AI] | text ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] +| xref::api/chat/minimax-chat.adoc[MiniMax] | text ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| | xref::api/chat/moonshot-chat.adoc[Moonshot AI] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| -| xref::api/chat/nvidia-chat.adoc[NVIDIA (OpenAI-proxy)] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] -| xref::api/chat/ollama-chat.adoc[Ollama] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] -| xref::api/chat/openai-chat.adoc[OpenAI] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] -| xref::api/chat/qianfan-chat.adoc[QianFan] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| 
image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/zhipuai-chat.adoc[ZhiPu AI] | text ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/watsonx-ai-chat.adoc[Watsonx.AI] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/bedrock/bedrock-cohere.adoc[Amazon Bedrock/Cohere] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/bedrock/bedrock-jurassic2.adoc[Amazon Bedrock/Jurassic] | text ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/bedrock/bedrock-llama.adoc[Amazon Bedrock/Llama] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/bedrock/bedrock-titan.adoc[Amazon Bedrock/Titan] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] -| xref::api/chat/bedrock/bedrock-anthropic3.adoc[Amazon Bedrock/Anthropic 3] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| 
xref::api/chat/nvidia-chat.adoc[NVIDIA (OpenAI-proxy)] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] +| xref::api/chat/oci-genai/cohere-chat.adoc[OCI GenAI/Cohere] | text ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/ollama-chat.adoc[Ollama] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] +| xref::api/chat/openai-chat.adoc[OpenAI] | text, image ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] +| xref::api/chat/qianfan-chat.adoc[QianFan] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/zhipuai-chat.adoc[ZhiPu AI] | text ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/watsonx-ai-chat.adoc[Watsonx.AI] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/bedrock/bedrock-cohere.adoc[Amazon Bedrock/Cohere] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| 
image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/bedrock/bedrock-jurassic2.adoc[Amazon Bedrock/Jurassic] | text ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/bedrock/bedrock-llama.adoc[Amazon Bedrock/Llama] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/bedrock/bedrock-titan.adoc[Amazon Bedrock/Titan] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] +| xref::api/chat/bedrock/bedrock-anthropic3.adoc[Amazon Bedrock/Anthropic 3] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] |==== - diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/oci-genai/cohere-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/oci-genai/cohere-chat.adoc new file mode 100644 index 0000000000..5cfd220c42 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/oci-genai/cohere-chat.adoc @@ -0,0 +1,208 @@ += OCI GenAI Cohere Chat + +https://www.oracle.com/artificial-intelligence/generative-ai/generative-ai-service/[OCI GenAI Service] offers generative AI chat with on-demand models, or dedicated AI clusters. 
+ +The https://docs.oracle.com/en-us/iaas/Content/generative-ai/chat-models.htm[OCI Chat Models Page] and https://docs.oracle.com/en-us/iaas/Content/generative-ai/use-playground-embed.htm[OCI Generative AI Playground] provide detailed information about using and hosting chat models on OCI. + +== Prerequisites + +You will need an active https://signup.oraclecloud.com/[Oracle Cloud Infrastructure (OCI)] account to use the OCI GenAI Cohere Chat client. The client offers four different ways to connect: simple authentication with a user and private key, workload identity, instance principal, or OCI configuration file authentication. + +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. +Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the OCI GenAI Cohere Chat Client. +To enable it, add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- +<dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-oci-genai-spring-boot-starter</artifactId> +</dependency> +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-oci-genai-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Chat Properties + +==== Connection Properties + +The prefix `spring.ai.oci.genai` is the property prefix to configure the connection to OCI GenAI. 
+ +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.oci.genai.authenticationType | The type of authentication to use when authenticating to OCI. May be `file`, `instance-principal`, `workload-identity`, or `simple`. | file +| spring.ai.oci.genai.region | OCI service region. | us-chicago-1 +| spring.ai.oci.genai.tenantId | OCI tenant OCID, used when authenticating with `simple` auth. | - +| spring.ai.oci.genai.userId | OCI user OCID, used when authenticating with `simple` auth. | - +| spring.ai.oci.genai.fingerprint | Private key fingerprint, used when authenticating with `simple` auth. | - +| spring.ai.oci.genai.privateKey | Private key content, used when authenticating with `simple` auth. | - +| spring.ai.oci.genai.passPhrase | Optional private key passphrase, used when authenticating with `simple` auth and a passphrase protected private key. | - +| spring.ai.oci.genai.file | Path to OCI config file. Used when authenticating with `file` auth. | <user's home directory>/.oci/config +| spring.ai.oci.genai.profile | OCI profile name. Used when authenticating with `file` auth. | DEFAULT +| spring.ai.oci.genai.endpoint | Optional OCI GenAI endpoint. | - + +|==== + + +==== Configuration Properties + +The prefix `spring.ai.oci.genai.cohere.chat` is the property prefix that configures the `ChatModel` implementation for OCI GenAI Cohere Chat. + +[cols="3,5,1", stripes=even] +|==== +| Property | Description | Default + +| spring.ai.oci.genai.cohere.chat.enabled | Enable OCI GenAI Cohere chat model. | true +| spring.ai.oci.genai.cohere.chat.options.model | Model OCID or endpoint | - +| spring.ai.oci.genai.cohere.chat.options.compartment | Model compartment OCID. | - +| spring.ai.oci.genai.cohere.chat.options.servingMode | The model serving mode to be used. May be `on-demand` or `dedicated`. 
| on-demand +| spring.ai.oci.genai.cohere.chat.options.preambleOverride | Override the chat model's prompt preamble | - +| spring.ai.oci.genai.cohere.chat.options.temperature | Inference temperature | 0.7 +| spring.ai.oci.genai.cohere.chat.options.topP | Top P parameter | - +| spring.ai.oci.genai.cohere.chat.options.topK | Top K parameter | - +| spring.ai.oci.genai.cohere.chat.options.frequencyPenalty | Higher values will reduce repeated tokens and outputs will be more random. | - +| spring.ai.oci.genai.cohere.chat.options.presencePenalty | Higher values encourage generating outputs with tokens that haven't been used. | - +| spring.ai.oci.genai.cohere.chat.options.stop | List of textual sequences that will end completions generation. | - +| spring.ai.oci.genai.cohere.chat.options.documents | List of documents used in chat context. | - +|==== + +TIP: All properties prefixed with `spring.ai.oci.genai.cohere.chat.options` can be overridden at runtime by adding a request specific <<chat-options>> to the `Prompt` call. + +== Runtime Options [[chat-options]] + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java[OCICohereChatOptions.java] provides model configurations, such as the model to use, the temperature, the frequency penalty, etc. + +On start-up, the default options can be configured with the `OCICohereChatModel(api, options)` constructor or the `spring.ai.oci.genai.cohere.chat.options.*` properties. + +At run-time you can override the default options by adding new, request specific, options to the `Prompt` call. 
+For example, to override the default model and temperature for a specific request: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "Generate the names of 5 famous pirates.", + OCICohereChatOptions.builder() + .withModel("my-model-ocid") + .withCompartment("my-compartment-ocid") + .withTemperature(0.5) + .build() + )); +---- + +== Sample Controller + +https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-oci-genai-spring-boot-starter` to your pom (or gradle) dependencies. + +Add an `application.properties` file, under the `src/main/resources` directory, to enable and configure the OCI GenAI Cohere chat model: + +[source,application.properties] +---- +spring.ai.oci.genai.authenticationType=file +spring.ai.oci.genai.file=/path/to/oci/config/file +spring.ai.oci.genai.cohere.chat.options.compartment=my-compartment-ocid +spring.ai.oci.genai.cohere.chat.options.servingMode=on-demand +spring.ai.oci.genai.cohere.chat.options.model=my-chat-model-ocid +---- + +TIP: Replace the `file`, `compartment`, and `model` with your values from your OCI account. + +This will create an `OCICohereChatModel` implementation that you can inject into your class. +Here is an example of a simple `@RestController` class that uses the chat model for text generation. 
+ +[source,java] +---- +@RestController +public class ChatController { + + private final OCICohereChatModel chatModel; + + @Autowired + public ChatController(OCICohereChatModel chatModel) { + this.chatModel = chatModel; + } + + @GetMapping("/ai/generate") + public Map<String, String> generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + return Map.of("generation", chatModel.call(message)); + } + + @GetMapping("/ai/generateStream") + public Flux<ChatResponse> generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + var prompt = new Prompt(new UserMessage(message)); + return chatModel.stream(prompt); + } +} +---- + +== Manual Configuration +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java[OCICohereChatModel] implements the `ChatModel` interface and uses the OCI Java SDK to connect to the OCI GenAI service. + +Add the `spring-ai-oci-genai` dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- +<dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-oci-genai</artifactId> +</dependency> +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-oci-genai' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. 
+ +Next, create an `OCICohereChatModel` and use it for text generation: + +[source,java] +---- +var CONFIG_FILE = Paths.get(System.getProperty("user.home"), ".oci", "config").toString(); +var COMPARTMENT_ID = System.getenv("OCI_COMPARTMENT_ID"); +var MODEL_ID = System.getenv("OCI_CHAT_MODEL_ID"); + +ConfigFileAuthenticationDetailsProvider authProvider = new ConfigFileAuthenticationDetailsProvider( + CONFIG_FILE, + "DEFAULT" +); +var genAi = GenerativeAiInferenceClient.builder() + .region(Region.valueOf("us-chicago-1")) + .build(authProvider); + +var chatModel = new OCICohereChatModel(genAi, OCICohereChatOptions.builder() + .withModel(MODEL_ID) + .withCompartment(COMPARTMENT_ID) + .withServingMode("on-demand") + .build()); + +ChatResponse response = chatModel.call( + new Prompt("Generate the names of 5 famous pirates.")); +---- + +The `OCICohereChatOptions` provides the configuration information for the chat requests. +The `OCICohereChatOptions.Builder` is a fluent options builder. diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 425674e77f..4ccc3c233e 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -407,6 +407,13 @@ <scope>test</scope> </dependency> + <dependency> + <groupId>com.oracle.oci.sdk</groupId> + <artifactId>oci-java-sdk-common</artifactId> + <version>${oci-sdk-version}</version> + <scope>test</scope> + </dependency> + <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-test</artifactId> diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCICohereChatModelProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCICohereChatModelProperties.java new file mode 100644 index 0000000000..58c7c8ef97 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCICohereChatModelProperties.java @@ -0,0 +1,59 @@ +/* + * Copyright 2023-2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.oci.genai; + +import org.springframework.ai.oci.cohere.OCICohereChatOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** Configuration properties for the OCI GenAI Cohere chat model, bound under {@link #CONFIG_PREFIX}. + * @author Anders Swanson + */ +@ConfigurationProperties(OCICohereChatModelProperties.CONFIG_PREFIX) +public class OCICohereChatModelProperties { + + public static final String CONFIG_PREFIX = "spring.ai.oci.genai.cohere.chat"; + + private static final String DEFAULT_SERVING_MODE = ServingMode.ON_DEMAND.getMode(); + + private static final Double DEFAULT_TEMPERATURE = 0.7; + + private boolean enabled = true; /* on by default, matching matchIfMissing = true on the chat model bean */ + + @NestedConfigurationProperty + private OCICohereChatOptions options = OCICohereChatOptions.builder() + .withServingMode(DEFAULT_SERVING_MODE) + .withTemperature(DEFAULT_TEMPERATURE) + .build(); + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public OCICohereChatOptions getOptions() { + return this.options; + } + + public void setOptions(OCICohereChatOptions options) { + this.options = options; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java 
b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java index 681ee71f3c..d974607025 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfiguration.java @@ -28,8 +28,12 @@ import com.oracle.bmc.auth.okeworkloadidentity.OkeWorkloadIdentityAuthenticationDetailsProvider; import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient; import com.oracle.bmc.retrier.RetryConfiguration; +import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.oci.OCIEmbeddingModel; +import org.springframework.ai.oci.cohere.OCICohereChatModel; +import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -43,7 +47,8 @@ */ @AutoConfiguration @ConditionalOnClass({ GenerativeAiInferenceClient.class, OCIEmbeddingModel.class }) -@EnableConfigurationProperties({ OCIConnectionProperties.class, OCIEmbeddingModelProperties.class }) +@EnableConfigurationProperties({ OCIConnectionProperties.class, OCIEmbeddingModelProperties.class, + OCICohereChatModelProperties.class, }) public class OCIGenAiAutoConfiguration { private static BasicAuthenticationDetailsProvider authenticationProvider(OCIConnectionProperties properties) @@ -89,4 +94,18 @@ public OCIEmbeddingModel ociEmbeddingModel(GenerativeAiInferenceClient generativ return new OCIEmbeddingModel(generativeAiClient, properties.getEmbeddingOptions()); } + /** Cohere chat model built from the bound options; falls back to ObservationRegistry.NOOP. */ + @Bean + @ConditionalOnProperty(prefix = OCICohereChatModelProperties.CONFIG_PREFIX, name = 
"enabled", havingValue = "true", + matchIfMissing = true) + public OCICohereChatModel ociChatModel(GenerativeAiInferenceClient generativeAiClient, + OCICohereChatModelProperties properties, ObjectProvider<ObservationRegistry> observationRegistry, + ObjectProvider<ChatModelObservationConvention> observationConvention) { + var chatModel = new OCICohereChatModel(generativeAiClient, properties.getOptions(), + observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); + observationConvention.ifAvailable(chatModel::setObservationConvention); + + return chatModel; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAIAutoConfigurationTest.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAIAutoConfigurationTest.java new file mode 100644 index 0000000000..397a113f94 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAIAutoConfigurationTest.java @@ -0,0 +1,91 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.autoconfigure.oci.genai; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.KeyPair; +import java.security.KeyPairGenerator; + +import com.oracle.bmc.http.client.pki.Pem; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.springframework.ai.oci.cohere.OCICohereChatModel; +import org.springframework.ai.oci.cohere.OCICohereChatOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +class OCIGenAIAutoConfigurationTest { + + @Test + void setProperties(@TempDir Path tempDir) throws Exception { + Path tmp = tempDir.resolve("my-key.pem"); + createPrivateKey(tmp); + ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.oci.genai.authenticationType=simple", + "spring.ai.oci.genai.userId=my-user", + "spring.ai.oci.genai.tenantId=my-tenant", + "spring.ai.oci.genai.fingerprint=xyz", + "spring.ai.oci.genai.privateKey=" + tmp.toAbsolutePath(), + "spring.ai.oci.genai.region=us-ashburn-1", + "spring.ai.oci.genai.cohere.chat.options.compartment=my-compartment", + "spring.ai.oci.genai.cohere.chat.options.servingMode=dedicated", + "spring.ai.oci.genai.cohere.chat.options.model=my-model", + "spring.ai.oci.genai.cohere.chat.options.maxTokens=1000", + "spring.ai.oci.genai.cohere.chat.options.temperature=0.5", + "spring.ai.oci.genai.cohere.chat.options.topP=0.8", + "spring.ai.oci.genai.cohere.chat.options.maxTokens=1000", + "spring.ai.oci.genai.cohere.chat.options.frequencyPenalty=0.1", + "spring.ai.oci.genai.cohere.chat.options.presencePenalty=0.2" + // @formatter:on + ).withConfiguration(AutoConfigurations.of(OCIGenAiAutoConfiguration.class)); + + contextRunner.run(context -> { + OCICohereChatModel chatModel = context.getBean(OCICohereChatModel.class); 
+ assertThat(chatModel).isNotNull(); + OCICohereChatOptions options = (OCICohereChatOptions) chatModel.getDefaultOptions(); + assertThat(options.getCompartment()).isEqualTo("my-compartment"); + assertThat(options.getModel()).isEqualTo("my-model"); + assertThat(options.getServingMode()).isEqualTo("dedicated"); + assertThat(options.getMaxTokens()).isEqualTo(1000); + assertThat(options.getTemperature()).isEqualTo(0.5); + assertThat(options.getTopP()).isEqualTo(0.8); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.1); + assertThat(options.getPresencePenalty()).isEqualTo(0.2); + + OCIConnectionProperties props = context.getBean(OCIConnectionProperties.class); + assertThat(props.getAuthenticationType()).isEqualTo(OCIConnectionProperties.AuthenticationType.SIMPLE); + assertThat(props.getUserId()).isEqualTo("my-user"); + assertThat(props.getTenantId()).isEqualTo("my-tenant"); + assertThat(props.getFingerprint()).isEqualTo("xyz"); + assertThat(props.getPrivateKey()).isEqualTo(tmp.toAbsolutePath().toString()); + assertThat(props.getRegion()).isEqualTo("us-ashburn-1"); + + }); + } + + private void createPrivateKey(Path tmp) throws Exception { + KeyPairGenerator gen = KeyPairGenerator.getInstance("RSA"); + gen.initialize(2048); + KeyPair keyPair = gen.generateKeyPair(); + byte[] encoded = Pem.encoder().encode(keyPair.getPrivate()); + Files.write(tmp, encoded); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfigurationIT.java index d23681cad1..fd282f51b3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/oci/genai/OCIGenAiAutoConfigurationIT.java @@ -25,6 +25,7 @@ import 
org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.oci.OCIEmbeddingModel; +import org.springframework.ai.oci.cohere.OCICohereChatModel; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -35,11 +36,15 @@ public class OCIGenAiAutoConfigurationIT { public static final String COMPARTMENT_ID_KEY = "OCI_COMPARTMENT_ID"; + public static final String OCI_CHAT_MODEL_ID_KEY = "OCI_CHAT_MODEL_ID"; + private final String CONFIG_FILE = Paths.get(System.getProperty("user.home"), ".oci", "config").toString(); private final String COMPARTMENT_ID = System.getenv(COMPARTMENT_ID_KEY); - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( + private final String CHAT_MODEL_ID = System.getenv(OCI_CHAT_MODEL_ID_KEY); + + private final ApplicationContextRunner embeddingContextRunner = new ApplicationContextRunner().withPropertyValues( // @formatter:off "spring.ai.oci.genai.authenticationType=file", "spring.ai.oci.genai.file=" + this.CONFIG_FILE, @@ -49,9 +54,19 @@ public class OCIGenAiAutoConfigurationIT { // @formatter:on ).withConfiguration(AutoConfigurations.of(OCIGenAiAutoConfiguration.class)); + private final ApplicationContextRunner cohereChatContextRunner = new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.oci.genai.authenticationType=file", + "spring.ai.oci.genai.file=" + CONFIG_FILE, + "spring.ai.oci.genai.cohere.chat.options.compartment=" + COMPARTMENT_ID, + "spring.ai.oci.genai.cohere.chat.options.servingMode=on-demand", + "spring.ai.oci.genai.cohere.chat.options.model=" + CHAT_MODEL_ID + // @formatter:on + ).withConfiguration(AutoConfigurations.of(OCIGenAiAutoConfiguration.class)); + @Test void embeddings() { - this.contextRunner.run(context -> { + this.embeddingContextRunner.run(context -> { 
OCIEmbeddingModel embeddingModel = context.getBean(OCIEmbeddingModel.class); assertThat(embeddingModel).isNotNull(); EmbeddingResponse response = embeddingModel @@ -61,4 +76,15 @@ void embeddings() { }); } + @Test + @EnabledIfEnvironmentVariable(named = OCIGenAiAutoConfigurationIT.OCI_CHAT_MODEL_ID_KEY, matches = ".+") + void cohereChat() { + this.cohereChatContextRunner.run(context -> { + OCICohereChatModel chatModel = context.getBean(OCICohereChatModel.class); + assertThat(chatModel).isNotNull(); + String response = chatModel.call("How many states are in the United States of America?"); + assertThat(response).isNotBlank(); + }); + } + }