diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index fc4da3e8bd..3213555145 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import java.util.stream.Collectors; import io.micrometer.observation.Observation; @@ -54,10 +55,12 @@ import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.ChatModelObservationSupport; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.Content; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; @@ -77,6 +80,7 @@ * @author Mariusz Bernacki * @author Thomas Vitale * @author Claudio Silva Junior + * @author John Blum * @since 1.0.0 */ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel { @@ -209,28 +213,31 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul @Override public ChatResponse call(Prompt prompt) { + ChatCompletionRequest request = createRequest(prompt, false); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + Supplier observationContext = () -> ChatModelObservationContext.builder() .prompt(prompt) .provider(AnthropicApi.PROVIDER_NAME) .requestOptions(buildRequestOptions(request)) .build(); - ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> { + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, + this.observationRegistry); + + ChatResponse response = observation.observe(() -> { - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.anthropicApi.chatCompletionEntity(request)); + ResponseEntity completionEntity = this.retryTemplate + .execute(ctx -> this.anthropicApi.chatCompletionEntity(request)); - ChatResponse chatResponse = toChatResponse(completionEntity.getBody()); + ChatResponse chatResponse = toChatResponse(completionEntity.getBody()); - observationContext.setResponse(chatResponse); + ChatModelObservationSupport.getObservationContext(observation) + .ifPresent(context -> context.setResponse(chatResponse)); - return chatResponse; - }); + return chatResponse; + }); if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null && this.isToolCall(response, Set.of("tool_use"))) { @@ -243,17 +250,19 @@ public ChatResponse call(Prompt prompt) { @Override public Flux stream(Prompt prompt) { + return Flux.deferContextual(contextView -> { + ChatCompletionRequest request = createRequest(prompt, true); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(AnthropicApi.PROVIDER_NAME) + Supplier observationContext = () -> ChatModelObservationContext.builder() .requestOptions(buildRequestOptions(request)) + .provider(AnthropicApi.PROVIDER_NAME) + .prompt(prompt) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( - this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, this.observationRegistry); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); @@ -276,7 +285,8 @@ public Flux stream(Prompt prompt) { .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on - return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse); + return new MessageAggregator().aggregate(chatResponseFlux, + ChatModelObservationSupport.setChatResponseInObservationContext(observation)); }); } @@ -408,7 +418,7 @@ else if (message.getMessageType() == MessageType.TOOL) { String systemPrompt = prompt.getInstructions() .stream() .filter(m -> m.getMessageType() == MessageType.SYSTEM) - .map(m -> m.getContent()) + .map(Content::getContent) .collect(Collectors.joining(System.lineSeparator())); ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages, diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 203981d902..7f9de605d6 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -25,7 +25,9 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; @@ -78,6 +80,7 @@ import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.ChatModelObservationSupport; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; @@ -87,6 +90,7 @@ import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.util.ValueUtils; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -195,24 +199,24 @@ public AzureOpenAiChatOptions getDefaultOptions() { @Override public ChatResponse call(Prompt prompt) { - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + Supplier observationContext = () -> ChatModelObservationContext.builder() .prompt(prompt) .provider(AiProvider.AZURE_OPENAI.value()) .requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions) .build(); - ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> { - ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); - options.setStream(false); - - ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options); - ChatResponse chatResponse = toChatResponse(chatCompletions); - observationContext.setResponse(chatResponse); - return chatResponse; - }); + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, + this.observationRegistry); + + ChatResponse response = observation.observe(() -> { + ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt).setStream(false); + ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options); + ChatResponse chatResponse = toChatResponse(chatCompletions); + ChatModelObservationSupport.getObservationContext(observation) + .ifPresent(context -> context.setResponse(chatResponse)); + return chatResponse; + }); if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { @@ -229,24 +233,28 @@ && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS public Flux stream(Prompt prompt) { return Flux.deferContextual(contextView -> { - ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); - options.setStream(true); + + ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt).setStream(true); Flux chatCompletionsStream = this.openAIAsyncClient .getChatCompletionsStream(options.getModel(), options); + // @formatter:off // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. - ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); + // TODO: Why is roleMap not used? I am guessing it should have served the same + // purpose as the roleMap in OpenAiChatModel.stream(:Prompt) + // @formatter:on + ConcurrentMap roleMap = new ConcurrentHashMap<>(); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) + Supplier observationContext = () -> ChatModelObservationContext.builder() + .requestOptions(ValueUtils.defaultIfNull(prompt.getOptions(), this.defaultOptions)) .provider(AiProvider.AZURE_OPENAI.value()) - .requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions) + .prompt(prompt) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( - this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, this.observationRegistry); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); @@ -295,7 +303,8 @@ public Flux stream(Prompt prompt) { .doFinally(s -> observation.stop()) .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); - return new MessageAggregator().aggregate(flux, observationContext::setResponse); + return new MessageAggregator().aggregate(flux, + ChatModelObservationSupport.setChatResponseInObservationContext(observation)); }); }); diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index f32fe5d3c5..63f7777fa0 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -22,6 +22,8 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Supplier; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; @@ -46,6 +48,7 @@ import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.ChatModelObservationSupport; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; @@ -77,6 +80,7 @@ * backed by {@link MiniMaxApi}. * * @author Geng Rong + * @author John Blum * @see ChatModel * @see StreamingChatModel * @see MiniMaxApi @@ -209,62 +213,66 @@ private static Generation buildGeneration(Choice choice, Map met @Override public ChatResponse call(Prompt prompt) { + ChatCompletionRequest request = createRequest(prompt, false); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + Supplier observationContext = () -> ChatModelObservationContext.builder() .prompt(prompt) .provider(MiniMaxApiConstants.PROVIDER_NAME) .requestOptions(buildRequestOptions(request)) .build(); - ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> { + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, + this.observationRegistry); - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.miniMaxApi.chatCompletionEntity(request)); + ChatResponse response = observation.observe(() -> { - var chatCompletion = completionEntity.getBody(); + ResponseEntity completionEntity = this.retryTemplate + .execute(ctx -> this.miniMaxApi.chatCompletionEntity(request)); - if (chatCompletion == null) { - logger.warn("No chat completion returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + var chatCompletion = completionEntity.getBody(); - List choices = chatCompletion.choices(); - if (choices == null) { - logger.warn("No choices returned for prompt: {}, because: {}}", prompt, - chatCompletion.baseResponse().message()); - return new ChatResponse(List.of()); - } + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } - List generations = choices.stream().map(choice -> { - // @formatter:off - // if the choice is a web search tool call, return last message of choice.messages - ChatCompletionMessage message = null; - if (choice.message() != null) { - message = choice.message(); - } - else if (!CollectionUtils.isEmpty(choice.messages())) { - // the MiniMax web search messages result is ['user message','assistant tool call', 'tool call', 'assistant message'] - // so the last message is the assistant message - message = choice.messages().get(choice.messages().size() - 1); - } - Map metadata = Map.of( - "id", chatCompletion.id(), - "role", message != null && message.role() != null ? message.role().name() : "", - "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); - // @formatter:on - return buildGeneration(message, choice.finishReason(), metadata); - }).toList(); + List choices = chatCompletion.choices(); - ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); + if (choices == null) { + logger.warn("No choices returned for prompt: {}, because: {}}", prompt, + chatCompletion.baseResponse().message()); + return new ChatResponse(List.of()); + } + + List generations = choices.stream().map(choice -> { + // @formatter:off + // if the choice is a web search tool call, return last message of choice.messages + ChatCompletionMessage message = null; + if (choice.message() != null) { + message = choice.message(); + } + else if (!CollectionUtils.isEmpty(choice.messages())) { + // the MiniMax web search messages result is ['user message','assistant tool call', 'tool call', 'assistant message'] + // so the last message is the assistant message + message = choice.messages().get(choice.messages().size() - 1); + } + Map metadata = Map.of( + "id", chatCompletion.id(), + "role", message != null && message.role() != null ? message.role().name() : "", + "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); + // @formatter:on + return buildGeneration(message, choice.finishReason(), metadata); + }).toList(); - observationContext.setResponse(chatResponse); + ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); - return chatResponse; - }); + ChatModelObservationSupport.getObservationContext(observation) + .ifPresent(context -> context.setResponse(chatResponse)); + + return chatResponse; + }); if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) { @@ -284,24 +292,26 @@ public ChatOptions getDefaultOptions() { @Override public Flux stream(Prompt prompt) { + return Flux.deferContextual(contextView -> { + ChatCompletionRequest request = createRequest(prompt, true); Flux completionChunks = this.retryTemplate - .execute(ctx -> this.miniMaxApi.chatCompletionStream(request)); + .execute(retryContext -> this.miniMaxApi.chatCompletionStream(request)); // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. - ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); + ConcurrentMap roleMap = new ConcurrentHashMap<>(); - final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(MiniMaxApiConstants.PROVIDER_NAME) + Supplier observationContext = () -> ChatModelObservationContext.builder() .requestOptions(buildRequestOptions(request)) + .provider(MiniMaxApiConstants.PROVIDER_NAME) + .prompt(prompt) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( - this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, this.observationRegistry); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); @@ -309,46 +319,48 @@ public Flux stream(Prompt prompt) { // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse // the function call handling logic. Flux chatResponse = completionChunks.map(this::chunkToChatCompletion) - .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> { + .switchMap(chatCompletion -> Mono.just(chatCompletion).map(justChatCompletion -> { try { @SuppressWarnings("null") - String id = chatCompletion2.id(); + String id = justChatCompletion.id(); // @formatter:off - List generations = chatCompletion2.choices().stream().map(choice -> { - if (choice.message().role() != null) { - roleMap.putIfAbsent(id, choice.message().role().name()); - } - Map metadata = Map.of( - "id", chatCompletion2.id(), - "role", roleMap.getOrDefault(id, ""), - "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); - return buildGeneration(choice, metadata); - }).toList(); - return new ChatResponse(generations, from(chatCompletion2)); + List generations = justChatCompletion.choices().stream().map(choice -> { + Role role = choice.message().role(); + if (role != null) { + roleMap.putIfAbsent(id, role.name()); + } + Map metadata = Map.of( + "id", id, + "role", roleMap.getOrDefault(id, ""), + "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); + return buildGeneration(choice, metadata); + }).toList(); + // @formatter:on + return new ChatResponse(generations, from(justChatCompletion)); } catch (Exception e) { - logger.error("Error processing chat completion", e); - return new ChatResponse(List.of()); - } - })); + logger.error("Error processing chat completion", e); + return new ChatResponse(List.of()); + } + })); Flux flux = chatResponse.flatMap(response -> { - if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, - Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) { - var toolCallConversation = handleToolCalls(prompt, response); - // Recursively call the stream method with the tool call message - // conversation that contains the call responses. - return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); - } - return Flux.just(response); - }) - .doOnError(observation::error) - .doFinally(signalType -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); - // @formatter:on - - return new MessageAggregator().aggregate(flux, observationContext::setResponse); + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, + Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) { + var toolCallConversation = handleToolCalls(prompt, response); + // Recursively call the stream method with the tool call message + // conversation that contains the call responses. + return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); + } + return Flux.just(response); + }) + .doOnError(observation::error) + .doFinally(signalType -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + + return new MessageAggregator().aggregate(flux, + ChatModelObservationSupport.setChatResponseInObservationContext(observation)); }); } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index 25657ec39b..fb1278aa03 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -21,6 +21,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Supplier; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; @@ -44,6 +45,7 @@ import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.ChatModelObservationSupport; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; @@ -75,6 +77,7 @@ * @author Grogdunn * @author Thomas Vitale * @author luocongqiu + * @author John Blum * @since 1.0.0 */ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatModel { @@ -161,44 +164,46 @@ public ChatResponse call(Prompt prompt) { MistralAiApi.ChatCompletionRequest request = createRequest(prompt, false); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + Supplier observationContext = () -> ChatModelObservationContext.builder() .prompt(prompt) .provider(MistralAiApi.PROVIDER_NAME) .requestOptions(buildRequestOptions(request)) .build(); - ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> { + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, + this.observationRegistry); - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.mistralAiApi.chatCompletionEntity(request)); + ChatResponse response = observation.observe(() -> { - ChatCompletion chatCompletion = completionEntity.getBody(); + ResponseEntity completionEntity = this.retryTemplate + .execute(ctx -> this.mistralAiApi.chatCompletionEntity(request)); - if (chatCompletion == null) { - logger.warn("No chat completion returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + ChatCompletion chatCompletion = completionEntity.getBody(); - List generations = chatCompletion.choices().stream().map(choice -> { - // @formatter:off + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + List generations = chatCompletion.choices().stream().map(choice -> { + // @formatter:off Map metadata = Map.of( "id", chatCompletion.id() != null ? chatCompletion.id() : "", "index", choice.index(), "role", choice.message().role() != null ? choice.message().role().name() : "", "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); // @formatter:on - return buildGeneration(choice, metadata); - }).toList(); + return buildGeneration(choice, metadata); + }).toList(); - ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); + ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); - observationContext.setResponse(chatResponse); + ChatModelObservationSupport.getObservationContext(observation) + .ifPresent(context -> context.setResponse(chatResponse)); - return chatResponse; - }); + return chatResponse; + }); if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), @@ -214,17 +219,19 @@ && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALL @Override public Flux stream(Prompt prompt) { + return Flux.deferContextual(contextView -> { + var request = createRequest(prompt, true); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(MistralAiApi.PROVIDER_NAME) + Supplier observationContext = () -> ChatModelObservationContext.builder() .requestOptions(buildRequestOptions(request)) + .provider(MistralAiApi.PROVIDER_NAME) + .prompt(prompt) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( - this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, this.observationRegistry); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); @@ -284,11 +291,12 @@ public Flux stream(Prompt prompt) { } }) .doOnError(observation::error) - .doFinally(s -> observation.stop()) + .doFinally(signalType -> observation.stop()) .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on; - return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse); + return new MessageAggregator().aggregate(chatResponseFlux, + ChatModelObservationSupport.setChatResponseInObservationContext(observation)); }); } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java index aa76d5fa9c..22edc0a2b3 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java @@ -21,6 +21,8 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Supplier; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; @@ -45,6 +47,7 @@ import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.ChatModelObservationSupport; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; @@ -73,6 +76,7 @@ /** * @author Geng Rong + * @author John Blum */ public class MoonshotChatModel extends AbstractToolCallSupport implements ChatModel, StreamingChatModel { @@ -177,35 +181,37 @@ private static Generation buildGeneration(Choice choice, Map met @Override public ChatResponse call(Prompt prompt) { + ChatCompletionRequest request = createRequest(prompt, false); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + Supplier observationContext = () -> ChatModelObservationContext.builder() .prompt(prompt) .provider(MoonshotConstants.PROVIDER_NAME) .requestOptions(buildRequestOptions(request)) .build(); - ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> { - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.moonshotApi.chatCompletionEntity(request)); + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, + this.observationRegistry); - var chatCompletion = completionEntity.getBody(); + ChatResponse response = observation.observe(() -> { + ResponseEntity completionEntity = this.retryTemplate + .execute(ctx -> this.moonshotApi.chatCompletionEntity(request)); - if (chatCompletion == null) { - logger.warn("No chat completion returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + var chatCompletion = completionEntity.getBody(); - List choices = chatCompletion.choices(); - if (choices == null) { - logger.warn("No choices returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } - List generations = choices.stream().map(choice -> { + List choices = chatCompletion.choices(); + if (choices == null) { + logger.warn("No choices returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + List generations = choices.stream().map(choice -> { // @formatter:off Map metadata = Map.of( "id", chatCompletion.id(), @@ -213,15 +219,16 @@ public ChatResponse call(Prompt prompt) { "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "" ); // @formatter:on - return buildGeneration(choice, metadata); - }).toList(); + return buildGeneration(choice, metadata); + }).toList(); - ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); + ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); - observationContext.setResponse(chatResponse); + ChatModelObservationSupport.getObservationContext(observation) + .ifPresent(context -> context.setResponse(chatResponse)); - return chatResponse; - }); + return chatResponse; + }); if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS.name(), @@ -241,7 +248,9 @@ public ChatOptions getDefaultOptions() { @Override public Flux stream(Prompt prompt) { + return Flux.deferContextual(contextView -> { + ChatCompletionRequest request = createRequest(prompt, true); Flux completionChunks = this.retryTemplate @@ -249,16 +258,16 @@ public Flux stream(Prompt prompt) { // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. - ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); + ConcurrentMap roleMap = new ConcurrentHashMap<>(); - final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(MoonshotConstants.PROVIDER_NAME) + Supplier observationContext = () -> ChatModelObservationContext.builder() .requestOptions(buildRequestOptions(request)) + .provider(MoonshotConstants.PROVIDER_NAME) + .prompt(prompt) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( - this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, this.observationRegistry); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); @@ -308,7 +317,8 @@ public Flux stream(Prompt prompt) { .doFinally(signalType -> observation.stop()) .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); - return new MessageAggregator().aggregate(flux, observationContext::setResponse); + return new MessageAggregator().aggregate(flux, + ChatModelObservationSupport.setChatResponseInObservationContext(observation)); }); } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index f4fcd722f1..498ff6fba8 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; @@ -41,6 +42,7 @@ import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.ChatModelObservationSupport; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; @@ -74,6 +76,7 @@ * @author Christian Tzolov * @author luocongqiu * @author Thomas Vitale + * @author John Blum * @since 1.0.0 */ public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel { @@ -130,42 +133,46 @@ public ChatResponse call(Prompt prompt) { OllamaApi.ChatRequest request = ollamaChatRequest(prompt, false); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + Supplier observationContext = () -> ChatModelObservationContext.builder() .prompt(prompt) .provider(OllamaApi.PROVIDER_NAME) .requestOptions(buildRequestOptions(request)) .build(); - ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> { + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, + this.observationRegistry); - OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request); + ChatResponse response = observation.observe(() -> { - List toolCalls = ollamaResponse.message().toolCalls() == null ? List.of() - : ollamaResponse.message() - .toolCalls() - .stream() - .map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(), - ModelOptionsUtils.toJsonString(toolCall.function().arguments()))) - .toList(); + OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request); - var assistantMessage = new AssistantMessage(ollamaResponse.message().content(), Map.of(), toolCalls); + List toolCalls = ollamaResponse.message().toolCalls() == null ? List.of() + : ollamaResponse.message() + .toolCalls() + .stream() + .map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(), + ModelOptionsUtils.toJsonString(toolCall.function().arguments()))) + .toList(); - ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL; - if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) { - generationMetadata = ChatGenerationMetadata.from(ollamaResponse.doneReason(), null); - } + var assistantMessage = new AssistantMessage(ollamaResponse.message().content(), Map.of(), toolCalls); - var generator = new Generation(assistantMessage, generationMetadata); - ChatResponse chatResponse = new ChatResponse(List.of(generator), from(ollamaResponse)); + ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL; - observationContext.setResponse(chatResponse); + if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) { + generationMetadata = ChatGenerationMetadata.from(ollamaResponse.doneReason(), null); + } - return chatResponse; + var generator = new Generation(assistantMessage, generationMetadata); - }); + ChatResponse chatResponse = new ChatResponse(List.of(generator), from(ollamaResponse)); + + ChatModelObservationSupport.getObservationContext(observation) + .ifPresent(context -> context.setResponse(chatResponse)); + + return chatResponse; + + }); if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null && isToolCall(response, Set.of("stop"))) { @@ -180,17 +187,19 @@ && isToolCall(response, Set.of("stop"))) { @Override public Flux stream(Prompt prompt) { + return Flux.deferContextual(contextView -> { + OllamaApi.ChatRequest request = ollamaChatRequest(prompt, true); - final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(OllamaApi.PROVIDER_NAME) + Supplier observationContext = () -> ChatModelObservationContext.builder() .requestOptions(buildRequestOptions(request)) + .provider(OllamaApi.PROVIDER_NAME) + .prompt(prompt) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( - this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, this.observationRegistry); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); @@ -236,13 +245,12 @@ public Flux stream(Prompt prompt) { } }) .doOnError(observation::error) - .doFinally(s -> - observation.stop() - ) + .doFinally(signalType -> observation.stop()) .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on - return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse); + return new MessageAggregator().aggregate(chatResponseFlux, + ChatModelObservationSupport.setChatResponseInObservationContext(observation)); }); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index f21a064ee6..b3cacb80ba 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -17,13 +17,16 @@ package org.springframework.ai.openai; import java.util.ArrayList; -import java.util.Base64; +import java.util.Collections; +import java.util.EnumSet; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Supplier; import java.util.stream.Collectors; import io.micrometer.observation.Observation; @@ -51,6 +54,7 @@ import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.ChatModelObservationSupport; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; @@ -71,13 +75,12 @@ import org.springframework.ai.openai.metadata.OpenAiUsage; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.util.ValueUtils; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import org.springframework.util.MimeType; import org.springframework.util.MultiValueMap; -import org.springframework.util.StringUtils; /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI} @@ -104,6 +107,9 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + private static final EnumSet USER_SYSTEM_MESSAGE_TYPE_SET = EnumSet.of(MessageType.USER, + MessageType.SYSTEM); + /** * The default options used for the chat completion requests. */ @@ -213,55 +219,55 @@ public ChatResponse call(Prompt prompt) { ChatCompletionRequest request = createRequest(prompt, false); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(OpenAiApiConstants.PROVIDER_NAME) + Supplier observationContext = () -> ChatModelObservationContext.builder() .requestOptions(buildRequestOptions(request)) + .provider(OpenAiApiConstants.PROVIDER_NAME) + .prompt(prompt) .build(); - ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> { + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, + this.observationRegistry); - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt))); + ChatResponse response = observation.observe(() -> { - var chatCompletion = completionEntity.getBody(); + ResponseEntity completionEntity = this.retryTemplate + .execute(context -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt))); - if (chatCompletion == null) { - logger.warn("No chat completion returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + var chatCompletion = completionEntity.getBody(); - List choices = chatCompletion.choices(); - if (choices == null) { - logger.warn("No choices returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(Collections.emptyList()); + } + + List choices = chatCompletion.choices(); + + if (choices == null) { + logger.warn("No choices returned for prompt: {}", prompt); + return new ChatResponse(Collections.emptyList()); + } - List generations = choices.stream().map(choice -> { // @formatter:off - Map metadata = Map.of( - "id", chatCompletion.id() != null ? chatCompletion.id() : "", - "role", choice.message().role() != null ? choice.message().role().name() : "", - "index", choice.index(), - "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "", - "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : ""); - // @formatter:on - return buildGeneration(choice, metadata); - }).toList(); + List generations = choices.stream().map(choice -> { + Map metadata = choiceMetadata(choice, chatCompletion.id(), + ValueUtils.defaultToEmptyString(choice.message().role(), OpenAiApi.ChatCompletionMessage.Role::name)); - // Non function calling. - RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); + return buildGeneration(choice, metadata); + }).toList(); + // @formatter:on - ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody(), rateLimit)); + // Non function calling. + RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); - observationContext.setResponse(chatResponse); + ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody(), rateLimit)); - return chatResponse; + ChatModelObservationSupport.getObservationContext(observation) + .ifPresent(context -> context.setResponse(chatResponse)); - }); + return chatResponse; + + }); if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), @@ -277,7 +283,9 @@ && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.n @Override public Flux stream(Prompt prompt) { + return Flux.deferContextual(contextView -> { + ChatCompletionRequest request = createRequest(prompt, true); Flux completionChunks = this.openAiApi.chatCompletionStream(request, @@ -285,16 +293,16 @@ public Flux stream(Prompt prompt) { // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. - ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); + ConcurrentMap roleMap = new ConcurrentHashMap<>(); - final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(OpenAiApiConstants.PROVIDER_NAME) + Supplier observationContext = () -> ChatModelObservationContext.builder() .requestOptions(buildRequestOptions(request)) + .provider(OpenAiApiConstants.PROVIDER_NAME) + .prompt(prompt) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( - this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, this.observationRegistry); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); @@ -302,27 +310,22 @@ public Flux stream(Prompt prompt) { // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse // the function call handling logic. Flux chatResponse = completionChunks.map(this::chunkToChatCompletion) - .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> { + .switchMap(chatCompletion -> Mono.just(chatCompletion).map(resolvedChatCompletion -> { try { @SuppressWarnings("null") - String id = chatCompletion2.id(); + String id = resolvedChatCompletion.id(); - List generations = chatCompletion2.choices().stream().map(choice -> { // @formatter:off + List generations = resolvedChatCompletion.choices().stream().map(choice -> { if (choice.message().role() != null) { roleMap.putIfAbsent(id, choice.message().role().name()); } - Map metadata = Map.of( - "id", chatCompletion2.id(), - "role", roleMap.getOrDefault(id, ""), - "index", choice.index(), - "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "", - "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : ""); + + Map metadata = choiceMetadata(choice, id, roleMap.getOrDefault(id, "")); return buildGeneration(choice, metadata); }).toList(); - // @formatter:on - return new ChatResponse(generations, from(chatCompletion2, null)); + return new ChatResponse(generations, from(resolvedChatCompletion, null)); } catch (Exception e) { logger.error("Error processing chat completion", e); @@ -332,29 +335,44 @@ public Flux stream(Prompt prompt) { })); // @formatter:off - Flux flux = chatResponse.flatMap(response -> { - - if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - OpenAiApi.ChatCompletionFinishReason.STOP.name()))) { - var toolCallConversation = handleToolCalls(prompt, response); - // Recursively call the stream method with the tool call message - // conversation that contains the call responses. - return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); - } - else { - return Flux.just(response); - } - }) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + Flux flux = chatResponse + .flatMap(response -> { + if (!isProxyToolCalls(prompt, this.defaultOptions) + && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + OpenAiApi.ChatCompletionFinishReason.STOP.name()))) { + var toolCallConversation = handleToolCalls(prompt, response); + // Recursively call the stream method with the tool call message + // conversation that contains the call responses. + return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); + } + else { + return Flux.just(response); + } + }) + .doOnError(observation::error) + .doFinally(signalType -> observation.stop()) + .contextWrite(context -> context.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on - return new MessageAggregator().aggregate(flux, observationContext::setResponse); - + return new MessageAggregator().aggregate(flux, + ChatModelObservationSupport.setChatResponseInObservationContext(observation)); }); } + private Map choiceMetadata(Choice choice, String id, String roleName) { + + // @formatter:off + return Map.of( + "id", ValueUtils.defaultToEmptyString(id), + "role", roleName, + "index", choice.index(), + "finishReason", ValueUtils.defaultToEmptyString(choice.finishReason(), + OpenAiApi.ChatCompletionFinishReason::name), + "refusal", ValueUtils.defaultToEmptyString(choice.message().refusal()) + ); + // @formatter:on + } + private MultiValueMap getAdditionalHttpHeaders(Prompt prompt) { Map headers = new HashMap<>(this.defaultOptions.getHttpHeaders()); @@ -416,7 +434,8 @@ private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionC ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { List chatCompletionMessages = prompt.getInstructions().stream().map(message -> { - if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) { + MessageType messageType = message.getMessageType(); + if (USER_SYSTEM_MESSAGE_TYPE_SET.contains(messageType)) { Object content = message.getContent(); if (message instanceof UserMessage userMessage) { if (!CollectionUtils.isEmpty(userMessage.getMedia())) { @@ -425,8 +444,7 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { contentList.addAll(userMessage.getMedia() .stream() - .map(media -> new MediaContent(new MediaContent.ImageUrl( - this.fromMediaData(media.getMimeType(), media.getData())))) + .map(OpenAiApi.MediaConverter.INSTANCE::convert) .toList()); content = contentList; @@ -436,7 +454,7 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { return List.of(new ChatCompletionMessage(content, ChatCompletionMessage.Role.valueOf(message.getMessageType().name()))); } - else if (message.getMessageType() == MessageType.ASSISTANT) { + else if (MessageType.ASSISTANT.equals(messageType)) { var assistantMessage = (AssistantMessage) message; List toolCalls = null; if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { @@ -445,18 +463,26 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { return new ToolCall(toolCall.id(), toolCall.type(), function); }).toList(); } - return List.of(new ChatCompletionMessage(assistantMessage.getContent(), - ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null)); + return List.of(ChatCompletionMessage.builder() + .rawContent(assistantMessage.getContent()) + .role(ChatCompletionMessage.Role.ASSISTANT) + .toolCalls(toolCalls) + .build()); } - else if (message.getMessageType() == MessageType.TOOL) { + else if (MessageType.TOOL.equals(messageType)) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; toolMessage.getResponses() - .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id")); + .forEach(response -> Assert.notNull(response.id(), "ToolResponseMessage must have an id")); + return toolMessage.getResponses() .stream() - .map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(), - tr.id(), null, null)) + .map(toolResponse -> ChatCompletionMessage.builder() + .rawContent(toolResponse.responseData()) + .role(ChatCompletionMessage.Role.TOOL) + .name(toolResponse.name()) + .toolCallId(toolResponse.id()) + .build()) .toList(); } else { @@ -493,7 +519,6 @@ else if (prompt.getOptions() instanceof OpenAiChatOptions) { // Add the enabled functions definitions to the request's tools parameter. if (!CollectionUtils.isEmpty(enabledToolsToUse)) { - request = ModelOptionsUtils.merge( OpenAiChatOptions.builder().withTools(this.getFunctionTools(enabledToolsToUse)).build(), request, ChatCompletionRequest.class); @@ -508,22 +533,6 @@ else if (prompt.getOptions() instanceof OpenAiChatOptions) { return request; } - private String fromMediaData(MimeType mimeType, Object mediaContentData) { - if (mediaContentData instanceof byte[] bytes) { - // Assume the bytes are an image. So, convert the bytes to a base64 encoded - // following the prefix pattern. - return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes)); - } - else if (mediaContentData instanceof String text) { - // Assume the text is a URLs or a base64 encoded image prefixed by the user. - return text; - } - else { - throw new IllegalArgumentException( - "Unsupported media data type: " + mediaContentData.getClass().getSimpleName()); - } - } - private List getFunctionTools(Set functionNames) { return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> { var function = new OpenAiApi.FunctionTool.Function(functionCallback.getDescription(), diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 173fbf866c..46e548cb15 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -16,6 +16,7 @@ package org.springframework.ai.openai.api; +import java.util.Base64; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; @@ -29,16 +30,19 @@ import reactor.core.publisher.Mono; import org.springframework.ai.model.ChatModelDescription; +import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.retry.RetryUtils; import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MimeType; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; @@ -56,6 +60,7 @@ * @author Mariusz Bernacki * @author Thomas Vitale * @author David Frizelle + * @author John Blum */ public class OpenAiApi { @@ -984,6 +989,10 @@ public ChatCompletionMessage(Object content, Role role) { this(content, role, null, null, null, null); } + public static Builder builder() { + return new ChatCompletionMessage.Builder(); + } + /** * Get message content as String. */ @@ -1112,6 +1121,57 @@ public record ChatCompletionFunction(// @formatter:off @JsonProperty("arguments") String arguments) { // @formatter:on } + public static class Builder { + + private Object rawContent; + + private Role role; + + private String name; + + private String toolCallId; + + private String refusal; + + private List toolCalls; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder rawContent(Object rawContent) { + this.rawContent = rawContent; + return this; + } + + public Builder refusal(String refusal) { + this.refusal = refusal; + return this; + } + + public Builder role(Role role) { + this.role = role; + return this; + } + + public Builder toolCallId(String toolCallId) { + this.toolCallId = toolCallId; + return this; + } + + public Builder toolCalls(List toolCalls) { + this.toolCalls = List.copyOf(toolCalls); + return this; + } + + public ChatCompletionMessage build() { + return new ChatCompletionMessage(this.rawContent, this.role, this.name, this.toolCallId, this.toolCalls, + this.refusal); + } + + } + } /** @@ -1391,4 +1451,38 @@ public record EmbeddingList(// @formatter:off @JsonProperty("usage") Usage usage) { // @formatter:on } + public static class MediaConverter implements Converter { + + public static final MediaConverter INSTANCE = new MediaConverter(); + + @Override + public ChatCompletionMessage.MediaContent convert(Media media) { + String url = fromMediaData(media.getMimeType(), media.getData()); + ChatCompletionMessage.MediaContent.ImageUrl imageUrl = new ChatCompletionMessage.MediaContent.ImageUrl(url); + return new ChatCompletionMessage.MediaContent(imageUrl); + } + + private String fromMediaData(MimeType mimeType, Object mediaContentData) { + + if (mediaContentData instanceof byte[] bytes) { + // Assume the bytes are an image. So, convert the bytes to a base64 + // encoded + // following the prefix pattern. + return String.format("data:%s;base64,%s", mimeType.toString(), + Base64.getEncoder().encodeToString(bytes)); + } + else if (mediaContentData instanceof String text) { + // Assume the text is a URLs or a base64 encoded image prefixed by the + // user. + return text; + } + else { + throw new IllegalArgumentException( + "Unsupported media data type: " + mediaContentData.getClass().getSimpleName()); + } + + } + + } + } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java index 944a1c4e1f..adcc0c5ea3 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java @@ -19,6 +19,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; @@ -39,6 +40,7 @@ import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.ChatModelObservationSupport; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; @@ -62,6 +64,7 @@ * backed by {@link QianFanApi}. * * @author Geng Rong + * @author John Blum * @see ChatModel * @see StreamingChatModel * @see QianFanApi @@ -155,56 +158,59 @@ public ChatResponse call(Prompt prompt) { ChatCompletionRequest request = createRequest(prompt, false); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + Supplier observationContext = () -> ChatModelObservationContext.builder() .prompt(prompt) .provider(QianFanConstants.PROVIDER_NAME) .requestOptions(buildRequestOptions(request)) .build(); - return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> { - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.qianFanApi.chatCompletionEntity(request)); + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, + this.observationRegistry); - var chatCompletion = completionEntity.getBody(); - if (chatCompletion == null) { - logger.warn("No chat completion returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + return observation.observe(() -> { + ResponseEntity completionEntity = this.retryTemplate + .execute(ctx -> this.qianFanApi.chatCompletionEntity(request)); + + var chatCompletion = completionEntity.getBody(); + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } // @formatter:off - Map metadata = Map.of( - "id", chatCompletion.id(), - "role", Role.ASSISTANT - ); - // @formatter:on - - var assistantMessage = new AssistantMessage(chatCompletion.result(), metadata); - List generations = Collections.singletonList(new Generation(assistantMessage)); - ChatResponse chatResponse = new ChatResponse(generations, from(chatCompletion, request.model())); - observationContext.setResponse(chatResponse); - return chatResponse; - }); + Map metadata = Map.of( + "id", chatCompletion.id(), + "role", Role.ASSISTANT + ); + // @formatter:on + + var assistantMessage = new AssistantMessage(chatCompletion.result(), metadata); + List generations = Collections.singletonList(new Generation(assistantMessage)); + ChatResponse chatResponse = new ChatResponse(generations, from(chatCompletion, request.model())); + ChatModelObservationSupport.getObservationContext(observation) + .ifPresent(context -> context.setResponse(chatResponse)); + return chatResponse; + }); } @Override public Flux stream(Prompt prompt) { return Flux.deferContextual(contextView -> { + ChatCompletionRequest request = createRequest(prompt, true); var completionChunks = this.qianFanApi.chatCompletionStream(request); - final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(QianFanConstants.PROVIDER_NAME) + Supplier observationContext = () -> ChatModelObservationContext.builder() .requestOptions(buildRequestOptions(request)) + .provider(QianFanConstants.PROVIDER_NAME) + .prompt(prompt) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( - this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, this.observationRegistry); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); @@ -225,7 +231,9 @@ public Flux stream(Prompt prompt) { .doOnError(observation::error) .doFinally(s -> observation.stop()) .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); - return new MessageAggregator().aggregate(chatResponse, observationContext::setResponse); + + return new MessageAggregator().aggregate(chatResponse, + ChatModelObservationSupport.setChatResponseInObservationContext(observation)); }); } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 956225dbd9..d69f4025ea 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -64,6 +65,7 @@ import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.ChatModelObservationSupport; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; @@ -90,6 +92,7 @@ * @author Chris Turchin * @author Mark Pollack * @author Soby Chacko + * @author John Blum * @since 0.8.1 */ public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements ChatModel, DisposableBean { @@ -285,33 +288,35 @@ public ChatResponse call(Prompt prompt) { VertexAiGeminiChatOptions vertexAiGeminiChatOptions = vertexAiGeminiChatOptions(prompt); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + Supplier observationContext = () -> ChatModelObservationContext.builder() .prompt(prompt) .provider(VertexAiGeminiConstants.PROVIDER_NAME) .requestOptions(vertexAiGeminiChatOptions) .build(); - ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> this.retryTemplate.execute(context -> { + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, + this.observationRegistry); - var geminiRequest = createGeminiRequest(prompt, vertexAiGeminiChatOptions); + ChatResponse response = observation.observe(() -> this.retryTemplate.execute(context -> { - GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest); + var geminiRequest = createGeminiRequest(prompt, vertexAiGeminiChatOptions); - List generations = generateContentResponse.getCandidatesList() - .stream() - .map(this::responseCandiateToGeneration) - .flatMap(List::stream) - .toList(); + GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest); + + List generations = generateContentResponse.getCandidatesList() + .stream() + .map(this::responseCandiateToGeneration) + .flatMap(List::stream) + .toList(); - ChatResponse chatResponse = new ChatResponse(generations, - toChatResponseMetadata(generateContentResponse)); + ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(generateContentResponse)); - observationContext.setResponse(chatResponse); - return chatResponse; - })); + ChatModelObservationSupport.getObservationContext(observation) + .ifPresent(it -> it.setResponse(chatResponse)); + + return chatResponse; + })); if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(FinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, response); @@ -326,20 +331,23 @@ public ChatResponse call(Prompt prompt) { @Override public Flux stream(Prompt prompt) { + return Flux.deferContextual(contextView -> { + VertexAiGeminiChatOptions vertexAiGeminiChatOptions = vertexAiGeminiChatOptions(prompt); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(VertexAiGeminiConstants.PROVIDER_NAME) + Supplier observationContext = () -> ChatModelObservationContext.builder() .requestOptions(vertexAiGeminiChatOptions) + .provider(VertexAiGeminiConstants.PROVIDER_NAME) + .prompt(prompt) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( - this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, this.observationRegistry); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); + var request = createGeminiRequest(prompt, vertexAiGeminiChatOptions); try { @@ -369,7 +377,8 @@ public Flux stream(Prompt prompt) { .doFinally(s -> observation.stop()) .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); - return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse); + return new MessageAggregator().aggregate(chatResponseFlux, + ChatModelObservationSupport.setChatResponseInObservationContext(observation)); }); } catch (Exception e) { @@ -545,14 +554,12 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) { private List toGeminiContent(List instrucitons) { - List contents = instrucitons.stream() + return instrucitons.stream() .map(message -> Content.newBuilder() .setRole(toGeminiMessageType(message.getMessageType()).getValue()) .addAllParts(messageToGeminiParts(message)) .build()) .toList(); - - return contents; } private List getFunctionTools(Set functionNames) { diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 7da150c62e..b3a9f5a4be 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -23,6 +23,8 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Supplier; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; @@ -48,6 +50,7 @@ import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.ChatModelObservationSupport; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; @@ -81,6 +84,7 @@ * backed by {@link ZhiPuAiApi}. * * @author Geng Rong + * @author John Blum * @see ChatModel * @see StreamingChatModel * @see ZhiPuAiApi @@ -194,32 +198,34 @@ private static Generation buildGeneration(Choice choice, Map met @Override public ChatResponse call(Prompt prompt) { + ChatCompletionRequest request = createRequest(prompt, false); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + Supplier observationContext = () -> ChatModelObservationContext.builder() .prompt(prompt) .provider(ZhiPuApiConstants.PROVIDER_NAME) .requestOptions(buildRequestOptions(request)) .build(); - ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> { + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, + this.observationRegistry); - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.zhiPuAiApi.chatCompletionEntity(request)); + ChatResponse response = observation.observe(() -> { - var chatCompletion = completionEntity.getBody(); + ResponseEntity completionEntity = this.retryTemplate + .execute(ctx -> this.zhiPuAiApi.chatCompletionEntity(request)); - if (chatCompletion == null) { - logger.warn("No chat completion returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + var chatCompletion = completionEntity.getBody(); + + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } - List choices = chatCompletion.choices(); + List choices = chatCompletion.choices(); - List generations = choices.stream().map(choice -> { + List generations = choices.stream().map(choice -> { // @formatter:off Map metadata = Map.of( "id", chatCompletion.id(), @@ -227,15 +233,17 @@ public ChatResponse call(Prompt prompt) { "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "" ); // @formatter:on - return buildGeneration(choice, metadata); - }).toList(); + return buildGeneration(choice, metadata); + }).toList(); + + ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); - ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); + ChatModelObservationSupport.getObservationContext(observation) + .ifPresent(context -> context.setResponse(chatResponse)); - observationContext.setResponse(chatResponse); + return chatResponse; + }); - return chatResponse; - }); if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, response); @@ -254,7 +262,9 @@ public ChatOptions getDefaultOptions() { @Override public Flux stream(Prompt prompt) { + return Flux.deferContextual(contextView -> { + ChatCompletionRequest request = createRequest(prompt, true); Flux completionChunks = this.retryTemplate @@ -262,16 +272,16 @@ public Flux stream(Prompt prompt) { // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. - ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); + ConcurrentMap roleMap = new ConcurrentHashMap<>(); - final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(ZhiPuApiConstants.PROVIDER_NAME) + Supplier observationContext = () -> ChatModelObservationContext.builder() .requestOptions(buildRequestOptions(request)) + .provider(ZhiPuApiConstants.PROVIDER_NAME) + .prompt(prompt) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( - this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, this.observationRegistry); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); @@ -319,7 +329,8 @@ public Flux stream(Prompt prompt) { .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on - return new MessageAggregator().aggregate(flux, observationContext::setResponse); + return new MessageAggregator().aggregate(flux, + ChatModelObservationSupport.setChatResponseInObservationContext(observation)); }); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationSupport.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationSupport.java new file mode 100644 index 0000000000..d75ce9c1f0 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationSupport.java @@ -0,0 +1,55 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.observation; + +import java.util.Optional; +import java.util.function.Consumer; + +import io.micrometer.observation.Observation; + +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.lang.Nullable; + +/** + * Support class for processing {@link ChatModel} Micrometer {@link Observation + * Observations}. + * + * @author John Blum + * @see ChatModel + * @see Observation + * @since 1.0.0 + */ +public abstract class ChatModelObservationSupport { + + public static Optional getObservationContext(@Nullable Observation observation) { + + // Avoid unnecessary construction of an Optional if Observations are not enabled + // (aka NOOP). + return Observation.NOOP.equals(observation) ? Optional.empty() + : Optional.ofNullable(observation) + .map(Observation::getContext) + .filter(ChatModelObservationContext.class::isInstance) + .map(ChatModelObservationContext.class::cast); + } + + public static Consumer setChatResponseInObservationContext(@Nullable Observation observation) { + return chatResponse -> getObservationContext(observation) + .ifPresent(context -> context.setResponse(chatResponse)); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/util/ValueUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/util/ValueUtils.java new file mode 100644 index 0000000000..47ac1bb073 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/util/ValueUtils.java @@ -0,0 +1,47 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.util; + +import java.util.function.Function; + +import org.springframework.util.StringUtils; + +/** + * Abstract utility class for process values. + * + * @author John Blum + * @since 1.0.0 + */ +@SuppressWarnings("unused") +public abstract class ValueUtils { + + protected static final String EMPTY_STRING = ""; + + public static T defaultIfNull(T value, T defaultValue) { + return value != null ? value : defaultValue; + } + + public static String defaultToEmptyString(T target, Function transform) { + String value = target != null ? transform.apply(target) : null; + return defaultToEmptyString(value); + } + + public static String defaultToEmptyString(String value) { + return StringUtils.hasText(value) ? value : EMPTY_STRING; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java index 025f8e600f..bbc9f4d4c9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java @@ -18,7 +18,9 @@ import java.util.List; import java.util.Optional; +import java.util.function.Supplier; +import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import org.springframework.ai.document.Document; @@ -28,6 +30,7 @@ /** * @author Christian Tzolov + * @author John Blum * @since 1.0.0 */ public abstract class AbstractObservationVectorStore implements VectorStore { @@ -48,45 +51,65 @@ public AbstractObservationVectorStore(ObservationRegistry observationRegistry, @Override public void add(List documents) { - VectorStoreObservationContext observationContext = this - .createObservationContextBuilder(VectorStoreObservationContext.Operation.ADD.value()) - .build(); + Supplier observationContext = observationContextSupplier( + VectorStoreObservationContext.Operation.ADD); VectorStoreObservationDocumentation.AI_VECTOR_STORE - .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, this.observationRegistry) - .observe(() -> this.doAdd(documents)); + .observe(() -> doAdd(documents)); } @Override + @SuppressWarnings("all") public Optional delete(List deleteDocIds) { - VectorStoreObservationContext observationContext = this - .createObservationContextBuilder(VectorStoreObservationContext.Operation.DELETE.value()) - .build(); + Supplier observationContext = observationContextSupplier( + VectorStoreObservationContext.Operation.DELETE); return VectorStoreObservationDocumentation.AI_VECTOR_STORE - .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, this.observationRegistry) - .observe(() -> this.doDelete(deleteDocIds)); + .observe(() -> doDelete(deleteDocIds)); } @Override + @SuppressWarnings("all") public List similaritySearch(SearchRequest request) { - VectorStoreObservationContext searchObservationContext = this - .createObservationContextBuilder(VectorStoreObservationContext.Operation.QUERY.value()) - .withQueryRequest(request) - .build(); + Supplier observationContext = observationContextSupplier( + VectorStoreObservationContext.Operation.QUERY, builder -> builder.withQueryRequest(request)); - return VectorStoreObservationDocumentation.AI_VECTOR_STORE - .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, - () -> searchObservationContext, this.observationRegistry) - .observe(() -> { - var documents = this.doSimilaritySearch(request); - searchObservationContext.setQueryResponse(documents); - return documents; - }); + Observation observation = VectorStoreObservationDocumentation.AI_VECTOR_STORE.observation( + this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext, + this.observationRegistry); + + return observation.observe(() -> { + var documents = doSimilaritySearch(request); + getObservationContext(observation).ifPresent(context -> context.setQueryResponse(documents)); + return documents; + }); + } + + private Optional getObservationContext(@Nullable Observation observation) { + + return Optional.ofNullable(observation) + .map(Observation::getContext) + .filter(VectorStoreObservationContext.class::isInstance) + .map(VectorStoreObservationContext.class::cast); + } + + private Supplier observationContextSupplier( + VectorStoreObservationContext.Operation operation) { + + return observationContextSupplier(operation, VectorStoreObservationContextBuilderCustomizer.IDENTITY); + } + + private Supplier observationContextSupplier( + VectorStoreObservationContext.Operation operation, + VectorStoreObservationContextBuilderCustomizer customizer) { + + return () -> customizer.customize(createObservationContextBuilder(operation.value())).build(); } public abstract void doAdd(List documents); @@ -97,4 +120,20 @@ public List similaritySearch(SearchRequest request) { public abstract VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName); + @FunctionalInterface + @SuppressWarnings("unused") + protected interface VectorStoreObservationContextBuilderCustomizer { + + VectorStoreObservationContextBuilderCustomizer IDENTITY = builder -> builder; + + VectorStoreObservationContext.Builder customize(VectorStoreObservationContext.Builder builder); + + default VectorStoreObservationContextBuilderCustomizer andThen( + @Nullable VectorStoreObservationContextBuilderCustomizer customizer) { + + return customizer != null ? builder -> customizer.customize(this.customize(builder)) : this; + } + + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationSupportTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationSupportTests.java new file mode 100644 index 0000000000..421fdb58af --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationSupportTests.java @@ -0,0 +1,142 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.observation; + +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import io.micrometer.observation.Observation; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Unit Tests for {@link ChatModelObservationSupport}. + * + * @author John Blum + */ +public class ChatModelObservationSupportTests { + + @Test + void getObservationContextIsNullSafe() { + assertThat(ChatModelObservationSupport.getObservationContext(null)).isEmpty(); + } + + @Test + void getObservationContextForNonChat() { + assertThat(ChatModelObservationSupport.getObservationContext(Observation.NOOP)).isEmpty(); + } + + @Test + void getObservationContextForChat() { + + ChatModelObservation mockObservation = spy(ChatModelObservation.class); + + assertThat(ChatModelObservationSupport.getObservationContext(mockObservation).orElse(null)) + .isInstanceOf(ChatModelObservationContext.class); + + verify(mockObservation, times(1)).getContext(); + verifyNoMoreInteractions(mockObservation); + } + + @Test + void setsChatResponseInObservationContext() { + + ChatModelObservation mockObservation = spy(ChatModelObservation.class); + ChatResponse mockChatResponse = mock(ChatResponse.class); + + Consumer chatResponseConsumer = ChatModelObservationSupport + .setChatResponseInObservationContext(mockObservation); + + assertThat(chatResponseConsumer).isNotNull(); + + chatResponseConsumer.accept(mockChatResponse); + + assertThat(mockObservation.getContext()).isNotNull() + .asInstanceOf(InstanceOfAssertFactories.type(ChatModelObservationContext.class)) + .extracting(ChatModelObservationContext::getResponse) + .isSameAs(mockChatResponse); + + verifyNoInteractions(mockChatResponse); + } + + @Test + void doesNotSetChatResponseInObservationContext() { + + ChatResponse mockChatResponse = mock(ChatResponse.class); + Observation mockObservation = mock(Observation.class); + Observation.Context mockContext = mock(Observation.Context.class); + + doReturn(mockContext).when(mockObservation).getContext(); + + Consumer chatResponseConsumer = ChatModelObservationSupport + .setChatResponseInObservationContext(mockObservation); + + assertThat(chatResponseConsumer).isNotNull(); + + chatResponseConsumer.accept(mockChatResponse); + + verifyNoInteractions(mockChatResponse, mockContext); + } + + @Test + void setChatResponseInObservationContextIsNullSafe() { + + ChatResponse mockChatResponse = mock(ChatResponse.class); + + Consumer chatResponseConsumer = ChatModelObservationSupport + .setChatResponseInObservationContext(null); + + assertThat(chatResponseConsumer).isNotNull(); + + chatResponseConsumer.accept(mockChatResponse); + + verifyNoInteractions(mockChatResponse); + } + + static abstract class ChatModelObservation implements Observation { + + private final AtomicReference contextRef = new AtomicReference<>(null); + + @Override + public Context getContext() { + return this.contextRef.updateAndGet(context -> context != null ? context : getContextSupplier().get()); + } + + static Supplier getContextSupplier() { + return () -> { + ChatOptions mockChatOptions = mock(ChatOptions.class); + return new ChatModelObservationContext(new Prompt("This is a test"), "TestProvider", mockChatOptions); + }; + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/util/ValueUtilsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/util/ValueUtilsTests.java new file mode 100644 index 0000000000..3c837dd6a6 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/util/ValueUtilsTests.java @@ -0,0 +1,71 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.util; + +import java.util.stream.Stream; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit Tests for {@link ValueUtils}. + * + * @author John Blum + */ +public class ValueUtilsTests { + + @Test + @SuppressWarnings("all") + void defaultIfNullReturnsValue() { + assertThat(ValueUtils.defaultIfNull("test", "mock")).isEqualTo("test"); + } + + @Test + void defaultIfNullReturnsDefaultValue() { + assertThat(ValueUtils.defaultIfNull(null, "mock")).isEqualTo("mock"); + } + + @Test + void defaultToEmptyStringReturnsString() { + assertThat(ValueUtils.defaultToEmptyString("test")).isEqualTo("test"); + } + + @Test + void defaultToEmptyStringReturnsEmptyString() { + Stream.of(" ", "", null) + .forEach(string -> assertThat(ValueUtils.defaultToEmptyString(string)).isEqualTo(ValueUtils.EMPTY_STRING)); + } + + @Test + void defaultToEmptyStringWithFunctionReturnsValue() { + assertThat(ValueUtils.defaultToEmptyString(Named.from("test"), Named::name)).isEqualTo("test"); + } + + @Test + void defaultToEmptyStringWithFunctionReturnsEmptyString() { + assertThat(ValueUtils.defaultToEmptyString(null, Named::name)).isEqualTo(ValueUtils.EMPTY_STRING); + } + + record Named(String name) { + + static Named from(String name) { + return new Named(name); + } + } + +}