Only construct Observation.Context for Chat and Vector ops on Model observations #1661

Open: jxblum wants to merge 1 commit into main.
AnthropicChatModel.java

@@ -22,6 +22,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;

 import io.micrometer.observation.Observation;
@@ -54,10 +55,12 @@
 import org.springframework.ai.chat.observation.ChatModelObservationContext;
 import org.springframework.ai.chat.observation.ChatModelObservationConvention;
 import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
+import org.springframework.ai.chat.observation.support.ChatModelObservationSupport;
 import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
 import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
 import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.model.Content;
 import org.springframework.ai.model.ModelOptionsUtils;
 import org.springframework.ai.model.function.FunctionCallback;
 import org.springframework.ai.model.function.FunctionCallbackContext;
@@ -77,6 +80,7 @@
  * @author Mariusz Bernacki
  * @author Thomas Vitale
  * @author Claudio Silva Junior
+ * @author John Blum
  * @since 1.0.0
  */
 public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel {
@@ -209,28 +213,31 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul

 	@Override
 	public ChatResponse call(Prompt prompt) {
+
 		ChatCompletionRequest request = createRequest(prompt, false);

-		ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
+		Supplier<ChatModelObservationContext> observationContext = () -> ChatModelObservationContext.builder()
jonatan-ivanov (Member) commented on Nov 4, 2024:

    How about calling it observationContextSupplier? (There are a few other occasions.)

    The point of passing the context via a Supplier when you create an Observation is exactly this: when observations are "off" (noop), the supplier is never called, so no unnecessary context objects are created. 👍🏼

    In this particular case, though, since the request can be different for every call, this replaces creating a new context with creating a new Supplier, which I think might be more lightweight. But there is a bit of additional complexity through the noop check + instanceof + cast + Optional, so which one is more lightweight I'm not sure; only JMH can tell, I guess.
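For reference, a minimal standalone sketch of the laziness described above, using Micrometer's Observation.createNotStarted(name, contextSupplier, registry) overload (the observation name and printouts here are invented for the demo):

    import java.util.function.Supplier;

    import io.micrometer.observation.Observation;
    import io.micrometer.observation.ObservationRegistry;

    class LazyContextSketch {

        public static void main(String[] args) {
            Supplier<Observation.Context> contextSupplier = () -> {
                System.out.println("allocating context"); // never prints with a noop registry
                return new Observation.Context();
            };

            // With ObservationRegistry.NOOP, createNotStarted(..) returns a noop
            // Observation without invoking the supplier, so no context is allocated.
            Observation observation = Observation.createNotStarted("demo.operation", contextSupplier,
                    ObservationRegistry.NOOP);

            observation.observe(() -> System.out.println("observed work still runs"));
        }
    }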

jxblum (Contributor, Author) commented on Nov 5, 2024:

    I am fine with whatever name we use. In some cases (such as wrapper types or wrapper-like types, e.g. Supplier or Optional), I simply use, or prefer, the name of the thing it wraps.

    My reasoning is similar in effect to List<User> users vs. List<User> userList. I like users, particularly if the collection type might change (e.g. to Set).

 			.prompt(prompt)
 			.provider(AnthropicApi.PROVIDER_NAME)
 			.requestOptions(buildRequestOptions(request))
 			.build();

-		ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
-			.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
-					this.observationRegistry)
-			.observe(() -> {
+		Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
+				this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext,
+				this.observationRegistry);
+
+		ChatResponse response = observation.observe(() -> {

-				ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
-					.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));
+			ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
+				.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));

-				ChatResponse chatResponse = toChatResponse(completionEntity.getBody());
+			ChatResponse chatResponse = toChatResponse(completionEntity.getBody());

-				observationContext.setResponse(chatResponse);
+			ChatModelObservationSupport.getObservationContext(observation)
Member commented:

    You can do this instead (might be simpler):

    if (!observation.isNoop()) {
        ((ChatModelObservationContext) observation.getContext()).setResponse(chatResponse);
    }
jxblum (Contributor, Author) commented on Nov 6, 2024:

    I generally agree, but I'd like to see something like this in Micrometer:

    interface Observation {

        static boolean isNoop(Observation observation) {
            return observation == null || observation.isNoop();
        }

        static boolean isNotNoop(Observation observation) {
            return !isNoop(observation);
        }

        // ...
    }
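With such helpers (a proposal only; Observation.isNoop(Observation) and Observation.isNotNoop(Observation) do not exist in Micrometer today), the guarded cast at this call site would collapse to:

    // Hypothetical usage of the proposed static helpers.
    if (Observation.isNotNoop(observation)) {
        ((ChatModelObservationContext) observation.getContext()).setResponse(chatResponse);
    }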

+				.ifPresent(context -> context.setResponse(chatResponse));

-				return chatResponse;
-			});
+			return chatResponse;
+		});

 		if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null
 				&& this.isToolCall(response, Set.of("tool_use"))) {
@@ -243,17 +250,19 @@ public ChatResponse call(Prompt prompt) {

 	@Override
 	public Flux<ChatResponse> stream(Prompt prompt) {
+
 		return Flux.deferContextual(contextView -> {
+
 			ChatCompletionRequest request = createRequest(prompt, true);

-			ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
-				.prompt(prompt)
-				.provider(AnthropicApi.PROVIDER_NAME)
+			Supplier<ChatModelObservationContext> observationContext = () -> ChatModelObservationContext.builder()
 				.requestOptions(buildRequestOptions(request))
+				.provider(AnthropicApi.PROVIDER_NAME)
+				.prompt(prompt)
 				.build();

 			Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
-					this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+					this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext,
 					this.observationRegistry);

 			observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
@@ -276,7 +285,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 				.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
 			// @formatter:on

-			return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
+			return new MessageAggregator().aggregate(chatResponseFlux,
+					ChatModelObservationSupport.setChatResponseInObservationContext(observation));
 		});
 	}
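The ChatModelObservationSupport helper referenced in this diff is introduced by the PR, but its source is not part of this excerpt. From the two call sites, a plausible sketch might look as follows (the method bodies are an assumption, not the PR's actual code):

    // Hypothetical reconstruction from the call sites in this diff, not the PR's source.
    package org.springframework.ai.chat.observation.support;

    import java.util.Optional;
    import java.util.function.Consumer;

    import io.micrometer.observation.Observation;

    import org.springframework.ai.chat.model.ChatResponse;
    import org.springframework.ai.chat.observation.ChatModelObservationContext;

    public abstract class ChatModelObservationSupport {

        // Yields the typed context only for a real (non-noop) observation whose
        // context actually is a ChatModelObservationContext.
        public static Optional<ChatModelObservationContext> getObservationContext(Observation observation) {
            return Optional.ofNullable(observation)
                .filter(it -> !it.isNoop())
                .map(Observation::getContext)
                .filter(ChatModelObservationContext.class::isInstance)
                .map(ChatModelObservationContext.class::cast);
        }

        // Adapts the lookup above to the Consumer<ChatResponse> callback that
        // MessageAggregator.aggregate(..) invokes with the aggregated response.
        public static Consumer<ChatResponse> setChatResponseInObservationContext(Observation observation) {
            return chatResponse -> getObservationContext(observation)
                .ifPresent(context -> context.setResponse(chatResponse));
        }
    }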

@@ -408,7 +418,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
 		String systemPrompt = prompt.getInstructions()
 			.stream()
 			.filter(m -> m.getMessageType() == MessageType.SYSTEM)
-			.map(m -> m.getContent())
+			.map(Content::getContent)
 			.collect(Collectors.joining(System.lineSeparator()));

 		ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages,
AzureOpenAiChatModel.java

@@ -25,7 +25,9 @@
 import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Supplier;

 import com.azure.ai.openai.OpenAIAsyncClient;
 import com.azure.ai.openai.OpenAIClient;
@@ -78,6 +80,7 @@
 import org.springframework.ai.chat.observation.ChatModelObservationContext;
 import org.springframework.ai.chat.observation.ChatModelObservationConvention;
 import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
+import org.springframework.ai.chat.observation.support.ChatModelObservationSupport;
 import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
 import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
@@ -87,6 +90,7 @@
 import org.springframework.ai.model.function.FunctionCallbackContext;
 import org.springframework.ai.model.function.FunctionCallingOptions;
 import org.springframework.ai.observation.conventions.AiProvider;
+import org.springframework.ai.util.ValueUtils;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
@@ -195,24 +199,24 @@ public AzureOpenAiChatOptions getDefaultOptions() {

 	@Override
 	public ChatResponse call(Prompt prompt) {

-		ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
+		Supplier<ChatModelObservationContext> observationContext = () -> ChatModelObservationContext.builder()
 			.prompt(prompt)
 			.provider(AiProvider.AZURE_OPENAI.value())
 			.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
 			.build();

-		ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
-			.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
-					this.observationRegistry)
-			.observe(() -> {
-				ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
-				options.setStream(false);
-
-				ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
-				ChatResponse chatResponse = toChatResponse(chatCompletions);
-				observationContext.setResponse(chatResponse);
-				return chatResponse;
-			});
+		Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
+				this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext,
+				this.observationRegistry);
+
+		ChatResponse response = observation.observe(() -> {
+			ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt).setStream(false);
+			ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
+			ChatResponse chatResponse = toChatResponse(chatCompletions);
+			ChatModelObservationSupport.getObservationContext(observation)
+				.ifPresent(context -> context.setResponse(chatResponse));
+			return chatResponse;
+		});

 		if (!isProxyToolCalls(prompt, this.defaultOptions)
 				&& isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
@@ -229,24 +233,28 @@ && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS

 	public Flux<ChatResponse> stream(Prompt prompt) {

 		return Flux.deferContextual(contextView -> {
-			ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
-			options.setStream(true);
+
+			ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt).setStream(true);

 			Flux<ChatCompletions> chatCompletionsStream = this.openAIAsyncClient
 				.getChatCompletionsStream(options.getModel(), options);

+			// @formatter:off
 			// For chunked responses, only the first chunk contains the choice role.
 			// The rest of the chunks with same ID share the same role.
-			ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
+			// TODO: Why is roleMap not used? I am guessing it should have served the same
+			// purpose as the roleMap in OpenAiChatModel.stream(:Prompt)
+			// @formatter:on
+			ConcurrentMap<String, String> roleMap = new ConcurrentHashMap<>();

-			ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
-				.prompt(prompt)
+			Supplier<ChatModelObservationContext> observationContext = () -> ChatModelObservationContext.builder()
+				.requestOptions(ValueUtils.defaultIfNull(prompt.getOptions(), this.defaultOptions))
 				.provider(AiProvider.AZURE_OPENAI.value())
-				.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
+				.prompt(prompt)
 				.build();

 			Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
-					this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+					this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext,
 					this.observationRegistry);

 			observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
@@ -295,7 +303,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 				.doFinally(s -> observation.stop())
 				.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));

-			return new MessageAggregator().aggregate(flux, observationContext::setResponse);
+			return new MessageAggregator().aggregate(flux,
+					ChatModelObservationSupport.setChatResponseInObservationContext(observation));
 		});

 	});
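ValueUtils is likewise new in this PR and its source is not shown in the excerpt. From its single call site above, defaultIfNull presumably mirrors the null-coalescing ternary it replaces (and commons-lang's ObjectUtils.defaultIfNull); a minimal sketch under that assumption:

    // Hypothetical reconstruction from the call site above, not the PR's source.
    package org.springframework.ai.util;

    public abstract class ValueUtils {

        // Returns value when non-null, otherwise the given default.
        public static <T> T defaultIfNull(T value, T defaultValue) {
            return value != null ? value : defaultValue;
        }
    }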