Skip to content

Commit

Permalink
Add support for Embedding an Azure OpenAI
Browse files Browse the repository at this point in the history
* add some logging
* improve integration tests with evaluators
  • Loading branch information
markpollack committed Aug 20, 2023
1 parent f2f81cd commit e9a6ae7
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 287 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public class AzureOpenAiClient implements AiClient {
private String model = "gpt-35-turbo";

/**
 * Creates the client around an Azure OpenAI SDK client instance.
 * @param msoftOpenAiClient the SDK client stored in {@code this.msoftOpenAiClient};
 * a {@code null} value is rejected by {@code Assert.notNull}
 */
public AzureOpenAiClient(OpenAIClient msoftOpenAiClient) {
// NOTE(review): the two checks below differ only in their message text; presumably
// they are the before/after lines of this diff hunk and only the second one remains
// in the actual file - confirm against the repository.
Assert.notNull(msoftOpenAiClient, "OpenAiClient must not be null");
Assert.notNull(msoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
this.msoftOpenAiClient = msoftOpenAiClient;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package org.springframework.ai.azure.openai.embedding;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.EmbeddingItem;
import com.azure.ai.openai.models.Embeddings;
import com.azure.ai.openai.models.EmbeddingsOptions;
import com.azure.ai.openai.models.EmbeddingsUsage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * {@link EmbeddingClient} that obtains embeddings through the Azure OpenAI SDK's
 * {@link OpenAIClient}.
 *
 * <p>
 * Every public {@code embed*} method issues exactly one {@code getEmbeddings} call,
 * using the model name fixed at construction time.
 */
public class AzureOpenAiEmbeddingClient implements EmbeddingClient {

    private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingClient.class);

    /** Model name used when the single-argument constructor is called. */
    private static final String DEFAULT_MODEL = "text-embedding-ada-002";

    private final OpenAIClient azureOpenAiClient;

    private final String model;

    /**
     * Creates a client that uses the default model name ({@value #DEFAULT_MODEL}).
     * @param azureOpenAiClient the Azure OpenAI SDK client; must not be {@code null}
     */
    public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient) {
        this(azureOpenAiClient, DEFAULT_MODEL);
    }

    /**
     * Creates a client that uses the given model name.
     * @param azureOpenAiClient the Azure OpenAI SDK client; must not be {@code null}
     * @param model value passed as the first argument of every {@code getEmbeddings}
     * call; must not be {@code null}
     * @throws IllegalArgumentException if either argument is {@code null}
     */
    public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, String model) {
        Assert.notNull(azureOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
        Assert.notNull(model, "Model must not be null");
        this.azureOpenAiClient = azureOpenAiClient;
        this.model = model;
    }

    /**
     * Embeds a single piece of text.
     * @param text the text to embed; must not be {@code null}
     * @return the embedding values extracted from the service response
     * @throws IllegalArgumentException if {@code text} is {@code null}
     */
    @Override
    public List<Double> embed(String text) {
        // Fail with a descriptive message instead of the bare NPE List.of(null) would raise.
        Assert.notNull(text, "Text must not be null");
        return extractEmbeddingsList(requestEmbeddings(List.of(text)));
    }

    /**
     * Embeds the content of a single document.
     * @param document the document whose {@code getContent()} is embedded; must not be
     * {@code null}
     * @return the embedding values extracted from the service response
     * @throws IllegalArgumentException if {@code document} is {@code null}
     */
    @Override
    public List<Double> embed(Document document) {
        Assert.notNull(document, "Document must not be null");
        return extractEmbeddingsList(requestEmbeddings(List.of(document.getContent())));
    }

    /**
     * Embeds several texts with one service call.
     * @param texts the texts to embed; must not be {@code null}
     * @return one embedding per item of the response, in the order the response lists
     * them
     * @throws IllegalArgumentException if {@code texts} is {@code null}
     */
    @Override
    public List<List<Double>> embed(List<String> texts) {
        return requestEmbeddings(texts).getData()
            .stream()
            .map(EmbeddingItem::getEmbedding)
            .collect(Collectors.toList());
    }

    /**
     * Embeds several texts and returns the embeddings together with usage metadata.
     * @param texts the texts to embed; must not be {@code null}
     * @return a response holding one {@link Embedding} per response item plus a metadata
     * map with the keys {@code model}, {@code prompt-tokens} and {@code total-tokens}
     * @throws IllegalArgumentException if {@code texts} is {@code null}
     */
    @Override
    public EmbeddingResponse embedForResponse(List<String> texts) {
        return generateEmbeddingResponse(requestEmbeddings(texts));
    }

    /**
     * Performs the single service round trip shared by all public methods, bracketed by
     * debug log statements.
     */
    private Embeddings requestEmbeddings(List<String> inputs) {
        Assert.notNull(inputs, "Texts must not be null");
        logger.debug("Retrieving embeddings");
        Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(this.model, new EmbeddingsOptions(inputs));
        logger.debug("Embeddings retrieved");
        return embeddings;
    }

    /**
     * Concatenates the vectors of every item in the response, in response order.
     * NOTE(review): callers pass a single input, so presumably the response holds one
     * item and the result is exactly that item's vector - confirm against the service
     * contract.
     */
    private List<Double> extractEmbeddingsList(Embeddings embeddings) {
        return embeddings.getData()
            .stream()
            .map(EmbeddingItem::getEmbedding)
            .flatMap(List::stream)
            .collect(Collectors.toList());
    }

    /** Converts the SDK response into the framework's {@link EmbeddingResponse}. */
    private EmbeddingResponse generateEmbeddingResponse(Embeddings embeddings) {
        List<Embedding> data = generateEmbeddingList(embeddings.getData());
        Map<String, Object> metadata = generateMetadata(this.model, embeddings.getUsage());
        return new EmbeddingResponse(data, metadata);
    }

    /** Builds the metadata map from the model name and the SDK usage figures. */
    private Map<String, Object> generateMetadata(String model, EmbeddingsUsage embeddingsUsage) {
        Map<String, Object> metadata = new HashMap<>();
        metadata.put("model", model);
        metadata.put("prompt-tokens", embeddingsUsage.getPromptTokens());
        // NOTE, not in API of AzureAI - metadata.put("completion-tokens",
        // embeddingsUsage.getCompletionTokens());
        metadata.put("total-tokens", embeddingsUsage.getTotalTokens());
        return metadata;
    }

    /**
     * Maps each SDK item to an {@link Embedding}, carrying over its vector and its
     * prompt index.
     */
    private List<Embedding> generateEmbeddingList(List<EmbeddingItem> nativeData) {
        List<Embedding> data = new ArrayList<>(nativeData.size());
        for (EmbeddingItem nativeDatum : nativeData) {
            data.add(new Embedding(nativeDatum.getEmbedding(), nativeDatum.getPromptIndex()));
        }
        return data;
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public List<Double> embed(String text) {
EmbeddingRequest embeddingRequest = EmbeddingRequest.builder().input(List.of(text)).model(this.model).build();
com.theokanning.openai.embedding.EmbeddingResult nativeEmbeddingResult = this.openAiService
.createEmbeddings(embeddingRequest);
return generateEmbeddingResult(nativeEmbeddingResult).getData().get(0).getEmbedding();
return generateEmbeddingResponse(nativeEmbeddingResult).getData().get(0).getEmbedding();
}

public List<Double> embed(Document document) {
Expand All @@ -45,7 +45,7 @@ public List<Double> embed(Document document) {
.build();
com.theokanning.openai.embedding.EmbeddingResult nativeEmbeddingResult = this.openAiService
.createEmbeddings(embeddingRequest);
return generateEmbeddingResult(nativeEmbeddingResult).getData().get(0).getEmbedding();
return generateEmbeddingResponse(nativeEmbeddingResult).getData().get(0).getEmbedding();
}

public List<List<Double>> embed(List<String> texts) {
Expand All @@ -58,26 +58,17 @@ public EmbeddingResponse embedForResponse(List<String> texts) {
EmbeddingRequest embeddingRequest = EmbeddingRequest.builder().input(texts).model(this.model).build();
com.theokanning.openai.embedding.EmbeddingResult nativeEmbeddingResult = this.openAiService
.createEmbeddings(embeddingRequest);
return generateEmbeddingResult(nativeEmbeddingResult);
return generateEmbeddingResponse(nativeEmbeddingResult);
}

private EmbeddingResponse generateEmbeddingResult(
private EmbeddingResponse generateEmbeddingResponse(
com.theokanning.openai.embedding.EmbeddingResult nativeEmbeddingResult) {
List<Embedding> data = generateEmbeddingList(nativeEmbeddingResult.getData());
Map<String, Object> metadata = generateMetadata(nativeEmbeddingResult.getModel(),
nativeEmbeddingResult.getUsage());
return new EmbeddingResponse(data, metadata);
}

/**
 * Builds the metadata map attached to an embedding response.
 * Keys written: "model", "prompt-tokens", "completion-tokens" and "total-tokens".
 * @param model the model name stored under "model"
 * @param usage source of the three token counts
 * @return a new mutable map holding the four entries
 */
private Map<String, Object> generateMetadata(String model, Usage usage) {
Map<String, Object> metadata = new HashMap<>();
metadata.put("model", model);
metadata.put("prompt-tokens", usage.getPromptTokens());
metadata.put("completion-tokens", usage.getCompletionTokens());
metadata.put("total-tokens", usage.getTotalTokens());
return metadata;
}

private List<Embedding> generateEmbeddingList(List<com.theokanning.openai.embedding.Embedding> nativeData) {
List<Embedding> data = new ArrayList<>();
for (com.theokanning.openai.embedding.Embedding nativeDatum : nativeData) {
Expand All @@ -89,4 +80,13 @@ private List<Embedding> generateEmbeddingList(List<com.theokanning.openai.embedd
return data;
}

/**
 * Builds the metadata map attached to an embedding response.
 * Keys written: "model", "prompt-tokens", "completion-tokens" and "total-tokens".
 * @param model the model name stored under "model"
 * @param usage source of the three token counts
 * @return a new mutable map holding the four entries
 */
private Map<String, Object> generateMetadata(String model, Usage usage) {
Map<String, Object> metadata = new HashMap<>();
metadata.put("model", model);
metadata.put("prompt-tokens", usage.getPromptTokens());
metadata.put("completion-tokens", usage.getCompletionTokens());
metadata.put("total-tokens", usage.getTotalTokens());
return metadata;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,12 @@ void acmeChain() {
VectorStoreRetriever vectorStoreRetriever = new VectorStoreRetriever(vectorStore);

logger.info("Retrieving relevant documents");
String userQuery = "How much does the SonicRide 8S cost?";
// "Tell me about the bike 'The SonicRide 8S'" ;
String userQuery = "What bike is good for city commuting?";

// "What bike is good for city commuting?";
// "Tell me more about the bike 'The SonicRide 8S'" ;
// "How much does the SonicRide 8S cost?";

// Eventually include metadata in query.
List<Document> similarDocuments = vectorStoreRetriever.retrieve(userQuery);
logger.info(String.format("Found %s relevant documents.", similarDocuments.size()));

Expand All @@ -100,9 +102,6 @@ void acmeChain() {

private Message getSystemMessage(List<Document> similarDocuments) {

// Would need to figure out which of the document's metadata fields to add from
// the loader; for now only the 'full description' is used.

String documents = similarDocuments.stream().map(entry -> entry.getContent()).collect(Collectors.joining("\n"));

SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemBikePrompt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void roleTest() {
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
AiResponse response = openAiClient.generate(prompt);
evaluateQuestionAndAnswer(request, response, false);
// needs fine tuning... evaluateQuestionAndAnswer(request, response, false);
}

@Test
Expand Down
Loading

0 comments on commit e9a6ae7

Please sign in to comment.