Skip to content

Commit

Permalink
Only construct Observation.Context for Chat and Vector ops on Model o…
Browse files Browse the repository at this point in the history
…bservations.

* Simplify and polish code in OpenAI using the API with Spring constructs and Builders.
* Simplify common code expressions used in AI provider ChatModels with ValueUtils.
* Apply whitespace to improve readability.
  • Loading branch information
jxblum committed Nov 3, 2024
1 parent 0d2d4b7 commit dc799cd
Show file tree
Hide file tree
Showing 16 changed files with 945 additions and 405 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import io.micrometer.observation.Observation;
Expand Down Expand Up @@ -54,10 +55,12 @@
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.ChatModelObservationSupport;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Content;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
Expand All @@ -77,6 +80,7 @@
* @author Mariusz Bernacki
* @author Thomas Vitale
* @author Claudio Silva Junior
* @author John Blum
* @since 1.0.0
*/
public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel {
Expand Down Expand Up @@ -209,28 +213,31 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul

@Override
public ChatResponse call(Prompt prompt) {

ChatCompletionRequest request = createRequest(prompt, false);

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
Supplier<ChatModelObservationContext> observationContext = () -> ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AnthropicApi.PROVIDER_NAME)
.requestOptions(buildRequestOptions(request))
.build();

ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext,
this.observationRegistry);

ChatResponse response = observation.observe(() -> {

ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));
ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));

ChatResponse chatResponse = toChatResponse(completionEntity.getBody());
ChatResponse chatResponse = toChatResponse(completionEntity.getBody());

observationContext.setResponse(chatResponse);
ChatModelObservationSupport.getObservationContext(observation)
.ifPresent(context -> context.setResponse(chatResponse));

return chatResponse;
});
return chatResponse;
});

if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null
&& this.isToolCall(response, Set.of("tool_use"))) {
Expand All @@ -243,17 +250,19 @@ public ChatResponse call(Prompt prompt) {

@Override
public Flux<ChatResponse> stream(Prompt prompt) {

return Flux.deferContextual(contextView -> {

ChatCompletionRequest request = createRequest(prompt, true);

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AnthropicApi.PROVIDER_NAME)
Supplier<ChatModelObservationContext> observationContext = () -> ChatModelObservationContext.builder()
.requestOptions(buildRequestOptions(request))
.provider(AnthropicApi.PROVIDER_NAME)
.prompt(prompt)
.build();

Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext,
this.observationRegistry);

observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
Expand All @@ -276,7 +285,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
// @formatter:on

return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
return new MessageAggregator().aggregate(chatResponseFlux,
ChatModelObservationSupport.setChatResponseInObservationContext(observation));
});
}

Expand Down Expand Up @@ -408,7 +418,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
String systemPrompt = prompt.getInstructions()
.stream()
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
.map(m -> m.getContent())
.map(Content::getContent)
.collect(Collectors.joining(System.lineSeparator()));

ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
Expand Down Expand Up @@ -78,6 +80,7 @@
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.ChatModelObservationSupport;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
Expand All @@ -87,6 +90,7 @@
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.util.ValueUtils;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

Expand Down Expand Up @@ -195,24 +199,24 @@ public AzureOpenAiChatOptions getDefaultOptions() {
@Override
public ChatResponse call(Prompt prompt) {

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
Supplier<ChatModelObservationContext> observationContext = () -> ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AiProvider.AZURE_OPENAI.value())
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
.build();

ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
options.setStream(false);

ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
ChatResponse chatResponse = toChatResponse(chatCompletions);
observationContext.setResponse(chatResponse);
return chatResponse;
});
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext,
this.observationRegistry);

ChatResponse response = observation.observe(() -> {
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt).setStream(false);
ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
ChatResponse chatResponse = toChatResponse(chatCompletions);
ChatModelObservationSupport.getObservationContext(observation)
.ifPresent(context -> context.setResponse(chatResponse));
return chatResponse;
});

if (!isProxyToolCalls(prompt, this.defaultOptions)
&& isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
Expand All @@ -229,24 +233,28 @@ && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS
public Flux<ChatResponse> stream(Prompt prompt) {

return Flux.deferContextual(contextView -> {
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
options.setStream(true);

ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt).setStream(true);

Flux<ChatCompletions> chatCompletionsStream = this.openAIAsyncClient
.getChatCompletionsStream(options.getModel(), options);

// @formatter:off
// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
// TODO: Why is roleMap not used? I am guessing it should have served the same
// purpose as the roleMap in OpenAiChatModel.stream(:Prompt)
// @formatter:on
ConcurrentMap<String, String> roleMap = new ConcurrentHashMap<>();

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
Supplier<ChatModelObservationContext> observationContext = () -> ChatModelObservationContext.builder()
.requestOptions(ValueUtils.defaultIfNull(prompt.getOptions(), this.defaultOptions))
.provider(AiProvider.AZURE_OPENAI.value())
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
.prompt(prompt)
.build();

Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext,
this.observationRegistry);

observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
Expand Down Expand Up @@ -295,7 +303,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));

return new MessageAggregator().aggregate(flux, observationContext::setResponse);
return new MessageAggregator().aggregate(flux,
ChatModelObservationSupport.setChatResponseInObservationContext(observation));
});

});
Expand Down
Loading

0 comments on commit dc799cd

Please sign in to comment.