diff --git a/README.md b/README.md
index f4e50f4..5f45140 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,7 @@
## 支持的平台
+ OpenAi
+ Zhipu
++ DeepSeek
+ 待添加
## 支持的服务
@@ -19,11 +20,13 @@
+ 支持流式输出。支持函数调用参数输出
+ 轻松使用Tool Calls
+ 支持多个函数同时调用(智谱不支持)
++ 支持stream_options,流式输出直接获取token usage
+ 内置向量数据库支持: Pinecone
+ 使用Tika读取文件
+ Token统计`TikTokensUtil.java`
## 更新日志
++ [2024-08-29] 新增对DeepSeek平台的支持、新增stream_options可以直接统计usage、新增错误拦截器`ErrorInterceptor.java`、发布0.3.0版本.
+ [2024-08-29] 修改SseListener以兼容智谱函数调用
+ [2024-08-28] 添加token统计、添加智谱AI的Chat服务、优化函数调用可以支持多轮多函数。
+ [2024-08-17] 增强SseListener监听器功能。发布0.2.0版本。
@@ -32,11 +35,11 @@
## 导入
### Gradle
```groovy
-implementation group: 'io.github.lnyo-cly', name: 'ai4j', version: '0.1.0'
+implementation group: 'io.github.lnyo-cly', name: 'ai4j', version: '0.3.0'
```
```groovy
-implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', version: '0.1.0'
+implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', version: '0.3.0'
```
@@ -46,7 +49,7 @@ implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', ver
io.github.lnyo-cly
ai4j
- 0.2.0
+ 0.3.0
```
@@ -55,7 +58,7 @@ implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', ver
io.github.lnyo-cly
ai4j-spring-boot-stater
- 0.2.0
+ 0.3.0
```
@@ -89,7 +92,7 @@ implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', ver
}
```
-#### Spring获取
+### Spring获取
```yml
# 国内访问默认需要代理
ai:
@@ -98,11 +101,20 @@ ai:
okhttp:
proxy-port: 10809
proxy-url: "127.0.0.1"
+ zhipu:
+ api-key: "xxx"
+ #other...
```
```java
+// 注入Ai服务
@Autowired
private AiService aiService;
+
+// 获取需要的服务实例
+IChatService chatService = aiService.getChatService(PlatformType.OPENAI);
+IEmbeddingService embeddingService = aiService.getEmbeddingService(PlatformType.OPENAI);
+// ......
```
## Chat服务
diff --git a/ai4j-spring-boot-stater/pom.xml b/ai4j-spring-boot-stater/pom.xml
index f171a50..a1edb8e 100644
--- a/ai4j-spring-boot-stater/pom.xml
+++ b/ai4j-spring-boot-stater/pom.xml
@@ -6,10 +6,10 @@
io.github.lnyo-cly
ai4j-spring-boot-stater
jar
- 0.2.0
+ 0.3.0
ai4j-spring-boot-stater
- ai4j-spring-boot-stater
+ 为ai4j所提供的spring-starter,便于接入spring项目。关于ai4j: 整合多平台大模型,如OpenAi、Zhipu(ChatGLM)、DeepSeek等等,提供统一的输入输出(对齐OpenAi),优化函数调用(Tool Call),优化RAG调用、支持向量数据库(Pinecone),并且支持JDK1.8,为用户提供快速整合AI的能力。
diff --git a/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/AiConfigAutoConfiguration.java b/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/AiConfigAutoConfiguration.java
index 1b4fbf1..d128b08 100644
--- a/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/AiConfigAutoConfiguration.java
+++ b/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/AiConfigAutoConfiguration.java
@@ -1,8 +1,10 @@
package io.github.lnyocly.ai4j;
+import io.github.lnyocly.ai4j.config.DeepSeekConfig;
import io.github.lnyocly.ai4j.config.OpenAiConfig;
import io.github.lnyocly.ai4j.config.PineconeConfig;
import io.github.lnyocly.ai4j.config.ZhipuConfig;
+import io.github.lnyocly.ai4j.interceptor.ErrorInterceptor;
import io.github.lnyocly.ai4j.service.PlatformType;
import io.github.lnyocly.ai4j.service.factor.AiService;
import io.github.lnyocly.ai4j.vector.service.PineconeService;
@@ -27,21 +29,24 @@
OpenAiConfigProperties.class,
OkHttpConfigProperties.class,
PineconeConfigProperties.class,
- ZhipuConfigProperties.class})
+ ZhipuConfigProperties.class,
+ DeepSeekConfigProperties.class})
public class AiConfigAutoConfiguration {
private final OkHttpConfigProperties okHttpConfigProperties;
private final OpenAiConfigProperties openAiConfigProperties;
private final PineconeConfigProperties pineconeConfigProperties;
private final ZhipuConfigProperties zhipuConfigProperties;
+ private final DeepSeekConfigProperties deepSeekConfigProperties;
private io.github.lnyocly.ai4j.service.Configuration configuration = new io.github.lnyocly.ai4j.service.Configuration();
- public AiConfigAutoConfiguration(OkHttpConfigProperties okHttpConfigProperties, OpenAiConfigProperties openAiConfigProperties, PineconeConfigProperties pineconeConfigProperties, ZhipuConfigProperties zhipuConfigProperties) {
+ public AiConfigAutoConfiguration(OkHttpConfigProperties okHttpConfigProperties, OpenAiConfigProperties openAiConfigProperties, PineconeConfigProperties pineconeConfigProperties, ZhipuConfigProperties zhipuConfigProperties, DeepSeekConfigProperties deepSeekConfigProperties) {
this.okHttpConfigProperties = okHttpConfigProperties;
this.openAiConfigProperties = openAiConfigProperties;
this.pineconeConfigProperties = pineconeConfigProperties;
this.zhipuConfigProperties = zhipuConfigProperties;
+ this.deepSeekConfigProperties = deepSeekConfigProperties;
}
@Bean
@@ -60,6 +65,7 @@ private void init() {
initOpenAiConfig();
initPineconeConfig();
initZhipuConfig();
+ initDeepSeekConfig();
}
private void initOkHttp() {
@@ -75,6 +81,7 @@ private void initOkHttp() {
OkHttpClient okHttpClient = new OkHttpClient
.Builder()
.addInterceptor(httpLoggingInterceptor)
+ .addInterceptor(new ErrorInterceptor())
.connectTimeout(okHttpConfigProperties.getConnectTimeout(), okHttpConfigProperties.getTimeUnit())
.writeTimeout(okHttpConfigProperties.getWriteTimeout(), okHttpConfigProperties.getTimeUnit())
.readTimeout(okHttpConfigProperties.getReadTimeout(), okHttpConfigProperties.getTimeUnit())
@@ -115,6 +122,13 @@ private void initPineconeConfig() {
configuration.setPineconeConfig(pineconeConfig);
}
+ private void initDeepSeekConfig(){
+ DeepSeekConfig deepSeekConfig = new DeepSeekConfig();
+ deepSeekConfig.setApiHost(deepSeekConfigProperties.getApiHost());
+ deepSeekConfig.setApiKey(deepSeekConfigProperties.getApiKey());
+ deepSeekConfig.setChat_completion(deepSeekConfigProperties.getChat_completion());
+ configuration.setDeepSeekConfig(deepSeekConfig);
+ }
}
diff --git a/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/DeepSeekConfigProperties.java b/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/DeepSeekConfigProperties.java
new file mode 100644
index 0000000..3f3f96b
--- /dev/null
+++ b/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/DeepSeekConfigProperties.java
@@ -0,0 +1,40 @@
+package io.github.lnyocly.ai4j;
+
+import org.springframework.boot.context.properties.ConfigurationProperties;
+
+/**
+ * @Author cly
+ * @Description Configuration properties for the DeepSeek platform (bound to prefix "ai.deepseek")
+ * @Date 2024/8/29 15:01
+ */
+@ConfigurationProperties(prefix = "ai.deepseek")
+public class DeepSeekConfigProperties {
+
+ private String apiHost = "https://api.deepseek.com/";
+ private String apiKey = "";
+ private String chat_completion = "chat/completions";
+
+ public String getApiHost() {
+ return apiHost;
+ }
+
+ public void setApiHost(String apiHost) {
+ this.apiHost = apiHost;
+ }
+
+ public String getChat_completion() {
+ return chat_completion;
+ }
+
+ public void setChat_completion(String chat_completion) {
+ this.chat_completion = chat_completion;
+ }
+
+ public String getApiKey() {
+ return apiKey;
+ }
+
+ public void setApiKey(String apiKey) {
+ this.apiKey = apiKey;
+ }
+}
diff --git a/ai4j/pom.xml b/ai4j/pom.xml
index 79c88bc..068ac9c 100644
--- a/ai4j/pom.xml
+++ b/ai4j/pom.xml
@@ -7,10 +7,10 @@
io.github.lnyo-cly
ai4j
jar
- 0.2.0
+ 0.3.0
ai4j
- ai4j基础组件项目
+ 整合多平台大模型,如OpenAi、Zhipu(ChatGLM)、DeepSeek等等,提供统一的输入输出(对齐OpenAi),优化函数调用(Tool Call),优化RAG调用、支持向量数据库(Pinecone),并且支持JDK1.8,为用户提供快速整合AI的能力。
@@ -78,6 +78,20 @@
logback-classic
1.2.3
+
+
+ org.slf4j
+ slf4j-api
+ 1.7.30
+
+
+
+ org.slf4j
+ slf4j-log4j12
+ 1.7.30
+
+
+
org.reflections
reflections
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/config/DeepSeekConfig.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/config/DeepSeekConfig.java
new file mode 100644
index 0000000..002c00d
--- /dev/null
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/config/DeepSeekConfig.java
@@ -0,0 +1,21 @@
+package io.github.lnyocly.ai4j.config;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+/**
+ * @Author cly
+ * @Description DeepSeek 配置文件
+ * @Date 2024/8/29 10:31
+ */
+
+@Data
+@NoArgsConstructor
+@AllArgsConstructor
+public class DeepSeekConfig {
+
+ private String apiHost = "https://api.deepseek.com/";
+ private String apiKey = "";
+ private String chat_completion = "chat/completions";
+}
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/interceptor/ErrorInterceptor.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/interceptor/ErrorInterceptor.java
new file mode 100644
index 0000000..eda7e63
--- /dev/null
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/interceptor/ErrorInterceptor.java
@@ -0,0 +1,39 @@
+package io.github.lnyocly.ai4j.interceptor;
+
+import io.github.lnyocly.ai4j.exception.CommonException;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.Interceptor;
+import okhttp3.Request;
+import okhttp3.Response;
+import org.jetbrains.annotations.NotNull;
+
+import java.io.IOException;
+
+/**
+ * @Author cly
+ * @Description 错误处理器
+ * @Date 2024/8/29 14:55
+ */
+@Slf4j
+public class ErrorInterceptor implements Interceptor {
+ @NotNull
+ @Override
+ public Response intercept(@NotNull Chain chain) throws IOException {
+ Request original = chain.request();
+
+ Response response = chain.proceed(original);
+
+ if(!response.isSuccessful()){
+ //response.close();
+ String errorMsg = response.body().string();
+
+ log.error("AI服务请求异常:{}", errorMsg);
+ throw new CommonException(errorMsg);
+
+
+ }
+
+
+ return response;
+ }
+}
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/listener/SseListener.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/listener/SseListener.java
index 13f735b..a1d52bb 100644
--- a/ai4j/src/main/java/io/github/lnyocly/ai4j/listener/SseListener.java
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/listener/SseListener.java
@@ -1,8 +1,10 @@
package io.github.lnyocly.ai4j.listener;
import com.alibaba.fastjson2.JSON;
+import io.github.lnyocly.ai4j.exception.CommonException;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletionResponse;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatMessage;
+import io.github.lnyocly.ai4j.platform.openai.chat.entity.Choice;
import io.github.lnyocly.ai4j.platform.openai.chat.enums.ChatMessageType;
import io.github.lnyocly.ai4j.platform.openai.tool.ToolCall;
import io.github.lnyocly.ai4j.platform.openai.usage.Usage;
@@ -60,7 +62,7 @@ public abstract class SseListener extends EventSourceListener {
* 花费token
*/
@Getter
- private Usage usage = null;
+ private final Usage usage = new Usage();
@Setter
@Getter
@@ -82,26 +84,38 @@ public abstract class SseListener extends EventSourceListener {
@Override
public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
- log.error("流式输出异常 onFailure ");
+
countDownLatch.countDown();
}
@Override
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
if ("[DONE]".equalsIgnoreCase(data)) {
- log.info("模型会话 [DONE]");
+ //log.info("模型会话 [DONE]");
return;
}
ChatCompletionResponse chatCompletionResponse = JSON.parseObject(data, ChatCompletionResponse.class);
- ChatMessage responseMessage = chatCompletionResponse.getChoices().get(0).getDelta();
+ // 统计token,当设置include_usage = true时,最后一条消息会携带usage, 其他消息中usage为null
+ Usage currUsage = chatCompletionResponse.getUsage();
+ if(currUsage != null){
+ usage.setPromptTokens(usage.getPromptTokens() + currUsage.getPromptTokens());
+ usage.setCompletionTokens(usage.getCompletionTokens() + currUsage.getCompletionTokens());
+ usage.setTotalTokens(usage.getTotalTokens() + currUsage.getTotalTokens());
+ }
+
+ List choices = chatCompletionResponse.getChoices();
+ if(choices == null || choices.isEmpty()){
+ return;
+ }
+ ChatMessage responseMessage = choices.get(0).getDelta();
- finishReason = chatCompletionResponse.getChoices().get(0).getFinishReason();
+ finishReason = choices.get(0).getFinishReason();
// tool_calls回答已经结束
- if("tool_calls".equals(chatCompletionResponse.getChoices().get(0).getFinishReason())){
+ if("tool_calls".equals(choices.get(0).getFinishReason())){
if(toolCall == null && responseMessage.getToolCalls()!=null) {
toolCalls = responseMessage.getToolCalls();
if(showToolArgs){
@@ -171,12 +185,11 @@ public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Null
- log.info("测试结果:{}", chatCompletionResponse);
+ //log.info("测试结果:{}", chatCompletionResponse);
}
@Override
public void onClosed(@NotNull EventSource eventSource) {
- log.info("调用 onClosed ");
countDownLatch.countDown();
countDownLatch = new CountDownLatch(1);
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/DeepSeekChatService.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/DeepSeekChatService.java
new file mode 100644
index 0000000..16f5653
--- /dev/null
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/DeepSeekChatService.java
@@ -0,0 +1,275 @@
+package io.github.lnyocly.ai4j.platform.deepseek.chat;
+
+import com.alibaba.fastjson2.JSON;
+import io.github.lnyocly.ai4j.config.DeepSeekConfig;
+import io.github.lnyocly.ai4j.constant.Constants;
+import io.github.lnyocly.ai4j.convert.ParameterConvert;
+import io.github.lnyocly.ai4j.convert.ResultConvert;
+import io.github.lnyocly.ai4j.listener.SseListener;
+import io.github.lnyocly.ai4j.platform.deepseek.chat.entity.DeepSeekChatCompletion;
+import io.github.lnyocly.ai4j.platform.deepseek.chat.entity.DeepSeekChatCompletionResponse;
+import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletion;
+import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletionResponse;
+import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatMessage;
+import io.github.lnyocly.ai4j.platform.openai.chat.entity.Choice;
+import io.github.lnyocly.ai4j.platform.openai.tool.Tool;
+import io.github.lnyocly.ai4j.platform.openai.tool.ToolCall;
+import io.github.lnyocly.ai4j.platform.openai.usage.Usage;
+import io.github.lnyocly.ai4j.platform.zhipu.chat.entity.ZhipuChatCompletionResponse;
+import io.github.lnyocly.ai4j.service.Configuration;
+import io.github.lnyocly.ai4j.service.IChatService;
+import io.github.lnyocly.ai4j.utils.BearerTokenUtils;
+import io.github.lnyocly.ai4j.utils.ToolUtil;
+import okhttp3.*;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSourceListener;
+import org.jetbrains.annotations.NotNull;
+import org.jetbrains.annotations.Nullable;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * @Author cly
+ * @Description DeepSeek Chat服务
+ * @Date 2024/8/29 10:26
+ */
+public class DeepSeekChatService implements IChatService, ParameterConvert, ResultConvert {
+ private final DeepSeekConfig deepSeekConfig;
+ private final OkHttpClient okHttpClient;
+ private final EventSource.Factory factory;
+
+ public DeepSeekChatService(Configuration configuration) {
+ this.deepSeekConfig = configuration.getDeepSeekConfig();
+ this.okHttpClient = configuration.getOkHttpClient();
+ this.factory = configuration.createRequestFactory();
+ }
+
+
+ @Override
+ public DeepSeekChatCompletion convertChatCompletionObject(ChatCompletion chatCompletion) {
+ DeepSeekChatCompletion deepSeekChatCompletion = new DeepSeekChatCompletion();
+ deepSeekChatCompletion.setModel(chatCompletion.getModel());
+ deepSeekChatCompletion.setMessages(chatCompletion.getMessages());
+ deepSeekChatCompletion.setFrequencyPenalty(chatCompletion.getFrequencyPenalty());
+ deepSeekChatCompletion.setMaxTokens(chatCompletion.getMaxTokens());
+ deepSeekChatCompletion.setPresencePenalty(chatCompletion.getPresencePenalty());
+ deepSeekChatCompletion.setResponseFormat(chatCompletion.getResponseFormat());
+ deepSeekChatCompletion.setStop(chatCompletion.getStop());
+ deepSeekChatCompletion.setStream(chatCompletion.getStream());
+ deepSeekChatCompletion.setStreamOptions(chatCompletion.getStreamOptions());
+ deepSeekChatCompletion.setTemperature(chatCompletion.getTemperature());
+ deepSeekChatCompletion.setTopP(chatCompletion.getTopP());
+ deepSeekChatCompletion.setTools(chatCompletion.getTools());
+ deepSeekChatCompletion.setFunctions(chatCompletion.getFunctions());
+ deepSeekChatCompletion.setToolChoice(chatCompletion.getToolChoice());
+ deepSeekChatCompletion.setLogprobs(chatCompletion.getLogprobs());
+ deepSeekChatCompletion.setTopLogprobs(chatCompletion.getTopLogprobs());
+ return deepSeekChatCompletion;
+ }
+
+ @Override
+ public EventSourceListener convertEventSource(EventSourceListener eventSourceListener) {
+ return new EventSourceListener() {
+ @Override
+ public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
+ eventSourceListener.onOpen(eventSource, response);
+ }
+
+ @Override
+ public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
+ eventSourceListener.onFailure(eventSource, t, response);
+ }
+
+ @Override
+ public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
+ if ("[DONE]".equalsIgnoreCase(data)) {
+ eventSourceListener.onEvent(eventSource, id, type, data);
+ return;
+ }
+
+ DeepSeekChatCompletionResponse chatCompletionResponse = JSON.parseObject(data, DeepSeekChatCompletionResponse.class);
+ ChatCompletionResponse response = convertChatCompletionResponse(chatCompletionResponse);
+
+ eventSourceListener.onEvent(eventSource, id, type, JSON.toJSONString(response));
+ }
+
+ @Override
+ public void onClosed(@NotNull EventSource eventSource) {
+ eventSourceListener.onClosed(eventSource);
+ }
+ };
+ }
+
+ @Override
+ public ChatCompletionResponse convertChatCompletionResponse(DeepSeekChatCompletionResponse deepSeekChatCompletionResponse) {
+ ChatCompletionResponse chatCompletionResponse = new ChatCompletionResponse();
+ chatCompletionResponse.setId(deepSeekChatCompletionResponse.getId());
+ chatCompletionResponse.setObject(deepSeekChatCompletionResponse.getObject());
+ chatCompletionResponse.setCreated(deepSeekChatCompletionResponse.getCreated());
+ chatCompletionResponse.setModel(deepSeekChatCompletionResponse.getModel());
+ chatCompletionResponse.setSystemFingerprint(deepSeekChatCompletionResponse.getSystemFingerprint());
+ chatCompletionResponse.setChoices(deepSeekChatCompletionResponse.getChoices());
+ chatCompletionResponse.setUsage(deepSeekChatCompletionResponse.getUsage());
+ return chatCompletionResponse;
+ }
+
+ @Override
+ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, ChatCompletion chatCompletion) throws Exception {
+ if(baseUrl == null || "".equals(baseUrl)) baseUrl = deepSeekConfig.getApiHost();
+ if(apiKey == null || "".equals(apiKey)) apiKey = deepSeekConfig.getApiKey();
+ chatCompletion.setStream(false);
+ chatCompletion.setStreamOptions(null);
+
+ // 转换 请求参数
+ DeepSeekChatCompletion deepSeekChatCompletion = this.convertChatCompletionObject(chatCompletion);
+
+ // 如含有function,则添加tool
+ if(deepSeekChatCompletion.getFunctions()!=null && !deepSeekChatCompletion.getFunctions().isEmpty()){
+ List tools = ToolUtil.getAllFunctionTools(deepSeekChatCompletion.getFunctions());
+ deepSeekChatCompletion.setTools(tools);
+ }
+
+ // 总token消耗
+ Usage allUsage = new Usage();
+
+ String finishReason = "first";
+
+ while("first".equals(finishReason) || "tool_calls".equals(finishReason)){
+
+ finishReason = null;
+
+ // 构造请求
+ String requestString = JSON.toJSONString(deepSeekChatCompletion);
+
+ Request request = new Request.Builder()
+ .header("Authorization", "Bearer " + apiKey)
+ .url(baseUrl.concat(deepSeekConfig.getChat_completion()))
+ .post(RequestBody.create(requestString, MediaType.parse(Constants.JSON_CONTENT_TYPE)))
+ .build();
+
+ Response execute = okHttpClient.newCall(request).execute();
+ if (execute.isSuccessful() && execute.body() != null){
+ DeepSeekChatCompletionResponse deepSeekChatCompletionResponse = JSON.parseObject(execute.body().string(), DeepSeekChatCompletionResponse.class);
+
+ Choice choice = deepSeekChatCompletionResponse.getChoices().get(0);
+ finishReason = choice.getFinishReason();
+
+ Usage usage = deepSeekChatCompletionResponse.getUsage();
+ allUsage.setCompletionTokens(allUsage.getCompletionTokens() + usage.getCompletionTokens());
+ allUsage.setTotalTokens(allUsage.getTotalTokens() + usage.getTotalTokens());
+ allUsage.setPromptTokens(allUsage.getPromptTokens() + usage.getPromptTokens());
+
+ // 判断是否为函数调用返回
+ if("tool_calls".equals(finishReason)){
+ ChatMessage message = choice.getMessage();
+ List toolCalls = message.getToolCalls();
+
+ List messages = new ArrayList<>(deepSeekChatCompletion.getMessages());
+ messages.add(message);
+
+ // 添加 tool 消息
+ for (ToolCall toolCall : toolCalls) {
+ String functionName = toolCall.getFunction().getName();
+ String arguments = toolCall.getFunction().getArguments();
+ String functionResponse = ToolUtil.invoke(functionName, arguments);
+
+ messages.add(ChatMessage.withTool(functionResponse, toolCall.getId()));
+ }
+ deepSeekChatCompletion.setMessages(messages);
+
+ }else{// 其他情况直接返回
+
+ // 设置包含tool的总token数
+ deepSeekChatCompletionResponse.setUsage(allUsage);
+ //deepSeekChatCompletionResponse.setObject("chat.completion");
+
+ // 恢复原始请求数据
+ chatCompletion.setMessages(deepSeekChatCompletion.getMessages());
+ chatCompletion.setTools(deepSeekChatCompletion.getTools());
+
+ return this.convertChatCompletionResponse(deepSeekChatCompletionResponse);
+
+ }
+
+ }
+
+ }
+
+
+ return null;
+ }
+
+ @Override
+ public ChatCompletionResponse chatCompletion(ChatCompletion chatCompletion) throws Exception {
+ return this.chatCompletion(null, null, chatCompletion);
+ }
+
+ @Override
+ public void chatCompletionStream(String baseUrl, String apiKey, ChatCompletion chatCompletion, SseListener eventSourceListener) throws Exception {
+ if(baseUrl == null || "".equals(baseUrl)) baseUrl = deepSeekConfig.getApiHost();
+ if(apiKey == null || "".equals(apiKey)) apiKey = deepSeekConfig.getApiKey();
+ chatCompletion.setStream(true);
+
+ // 转换 请求参数
+ DeepSeekChatCompletion deepSeekChatCompletion = this.convertChatCompletionObject(chatCompletion);
+
+ // 如含有function,则添加tool
+ if(deepSeekChatCompletion.getFunctions()!=null && !deepSeekChatCompletion.getFunctions().isEmpty()){
+ List tools = ToolUtil.getAllFunctionTools(deepSeekChatCompletion.getFunctions());
+ deepSeekChatCompletion.setTools(tools);
+ }
+
+ String finishReason = "first";
+
+ while("first".equals(finishReason) || "tool_calls".equals(finishReason)){
+
+ finishReason = null;
+ String jsonString = JSON.toJSONString(deepSeekChatCompletion);
+
+ Request request = new Request.Builder()
+ .header("Authorization", "Bearer " + apiKey)
+ .url(baseUrl.concat(deepSeekConfig.getChat_completion()))
+ .post(RequestBody.create(jsonString, MediaType.parse(Constants.APPLICATION_JSON)))
+ .build();
+
+
+ factory.newEventSource(request, convertEventSource(eventSourceListener));
+ eventSourceListener.getCountDownLatch().await();
+
+ finishReason = eventSourceListener.getFinishReason();
+ List toolCalls = eventSourceListener.getToolCalls();
+
+ // 需要调用函数
+ if("tool_calls".equals(finishReason) && !toolCalls.isEmpty()){
+ // 创建tool响应消息
+ ChatMessage responseMessage = ChatMessage.withAssistant(eventSourceListener.getToolCalls());
+
+ List messages = new ArrayList<>(deepSeekChatCompletion.getMessages());
+ messages.add(responseMessage);
+
+ // 封装tool结果消息
+ for (ToolCall toolCall : toolCalls) {
+ String functionName = toolCall.getFunction().getName();
+ String arguments = toolCall.getFunction().getArguments();
+ String functionResponse = ToolUtil.invoke(functionName, arguments);
+
+ messages.add(ChatMessage.withTool(functionResponse, toolCall.getId()));
+ }
+ eventSourceListener.setToolCalls(new ArrayList<>());
+ eventSourceListener.setToolCall(null);
+ deepSeekChatCompletion.setMessages(messages);
+ }
+
+ }
+
+ // 补全原始请求
+ chatCompletion.setMessages(deepSeekChatCompletion.getMessages());
+ chatCompletion.setTools(deepSeekChatCompletion.getTools());
+ }
+
+ @Override
+ public void chatCompletionStream(ChatCompletion chatCompletion, SseListener eventSourceListener) throws Exception {
+ this.chatCompletionStream(null, null, chatCompletion, eventSourceListener);
+ }
+}
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/entity/DeepSeekChatCompletion.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/entity/DeepSeekChatCompletion.java
new file mode 100644
index 0000000..7174f9f
--- /dev/null
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/entity/DeepSeekChatCompletion.java
@@ -0,0 +1,154 @@
+package io.github.lnyocly.ai4j.platform.deepseek.chat.entity;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatMessage;
+import io.github.lnyocly.ai4j.platform.openai.chat.entity.StreamOptions;
+import io.github.lnyocly.ai4j.platform.openai.tool.Tool;
+import io.github.lnyocly.ai4j.platform.zhipu.chat.entity.ZhipuChatCompletion;
+import lombok.*;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * @Author cly
+ * @Description DeepSeek对话请求实体
+ * @Date 2024/8/29 10:27
+ */
+@Data
+@Builder(toBuilder = true)
+@NoArgsConstructor()
+@AllArgsConstructor()
+@JsonIgnoreProperties(ignoreUnknown = true)
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public class DeepSeekChatCompletion {
+
+
+ @NonNull
+ private String model;
+
+ @NonNull
+ private List messages;
+
+ /**
+ * 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,降低模型重复相同内容的可能性。
+ */
+ @Builder.Default
+ @JsonProperty("frequency_penalty")
+ private Float frequencyPenalty = 0f;
+
+ /**
+ * 限制一次请求中模型生成 completion 的最大 token 数。输入 token 和输出 token 的总长度受模型的上下文长度的限制。
+ */
+ @JsonProperty("max_tokens")
+ private Integer maxTokens;
+
+ /**
+ * 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。
+ */
+ @Builder.Default
+ @JsonProperty("presence_penalty")
+ private Float presencePenalty = 0f;
+
+ /**
+ * 一个 object,指定模型必须输出的格式。
+ *
+ * 设置为 { "type": "json_object" } 以启用 JSON 模式,该模式保证模型生成的消息是有效的 JSON。
+ *
+ * 注意: 使用 JSON 模式时,你还必须通过系统或用户消息指示模型生成 JSON。
+ * 否则,模型可能会生成不断的空白字符,直到生成达到令牌限制,从而导致请求长时间运行并显得“卡住”。
+ * 此外,如果 finish_reason="length",这表示生成超过了 max_tokens 或对话超过了最大上下文长度,消息内容可能会被部分截断。
+ */
+ @JsonProperty("response_format")
+ private Object responseFormat;
+
+ /**
+ * 在遇到这些词时,API 将停止生成更多的 token。
+ */
+ private List stop;
+
+ /**
+ * 如果设置为 True,将会以 SSE(server-sent events)的形式以流式发送消息增量。消息流以 data: [DONE] 结尾
+ */
+ private Boolean stream = false;
+
+ /**
+ * 流式输出相关选项。只有在 stream 参数为 true 时,才可设置此参数。
+ */
+ @Builder.Default
+ @JsonProperty("stream_options")
+ private StreamOptions streamOptions = new StreamOptions();
+
+ /**
+ * 采样温度,介于 0 和 2 之间。更高的值,如 0.8,会使输出更随机,而更低的值,如 0.2,会使其更加集中和确定。
+ * 我们通常建议可以更改这个值或者更改 top_p,但不建议同时对两者进行修改。
+ */
+ @Builder.Default
+ private Float temperature = 1f;
+
+ /**
+ * 作为调节采样温度的替代方案,模型会考虑前 top_p 概率的 token 的结果。所以 0.1 就意味着只有包括在最高 10% 概率中的 token 会被考虑。
+ * 我们通常建议修改这个值或者更改 temperature,但不建议同时对两者进行修改。
+ */
+ @Builder.Default
+ @JsonProperty("top_p")
+ private Float topP = 1f;
+
+ /**
+ * 模型可能会调用的 tool 的列表。目前,仅支持 function 作为工具。使用此参数来提供以 JSON 作为输入参数的 function 列表。
+ */
+ private List tools;
+
+ /**
+ * 辅助属性
+ */
+ @JsonIgnore
+ private List functions;
+
+ /**
+ * 控制模型调用 tool 的行为。
+ * none 意味着模型不会调用任何 tool,而是生成一条消息。
+ * auto 意味着模型可以选择生成一条消息或调用一个或多个 tool。
+ * 当没有 tool 时,默认值为 none。如果有 tool 存在,默认值为 auto。
+ */
+ @JsonProperty("tool_choice")
+ private String toolChoice;
+
+ /**
+ * 是否返回所输出 token 的对数概率。如果为 true,则在 message 的 content 中返回每个输出 token 的对数概率。
+ */
+ @Builder.Default
+ private Boolean logprobs = false;
+
+ /**
+ * 一个介于 0 到 20 之间的整数 N,指定每个输出位置返回输出概率 top N 的 token,且返回这些 token 的对数概率。指定此参数时,logprobs 必须为 true。
+ */
+ @JsonProperty("top_logprobs")
+ private Integer topLogprobs;
+
+ public static class DeepSeekChatCompletionBuilder {
+ private List functions;
+
+ public DeepSeekChatCompletion.DeepSeekChatCompletionBuilder functions(String... functions){
+ if (this.functions == null) {
+ this.functions = new ArrayList<>();
+ }
+ this.functions.addAll(Arrays.asList(functions));
+ return this;
+ }
+
+ public DeepSeekChatCompletion.DeepSeekChatCompletionBuilder functions(List functions){
+ if (this.functions == null) {
+ this.functions = new ArrayList<>();
+ }
+ this.functions.addAll(functions);
+ return this;
+ }
+
+
+ }
+}
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/entity/DeepSeekChatCompletionResponse.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/entity/DeepSeekChatCompletionResponse.java
new file mode 100644
index 0000000..7aa31e6
--- /dev/null
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/entity/DeepSeekChatCompletionResponse.java
@@ -0,0 +1,61 @@
+package io.github.lnyocly.ai4j.platform.deepseek.chat.entity;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import io.github.lnyocly.ai4j.platform.openai.chat.entity.Choice;
+import io.github.lnyocly.ai4j.platform.openai.usage.Usage;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.List;
+
+/**
+ * @Author cly
+ * @Description DeepSeek对话响应实体
+ * @Date 2024/8/29 10:28
+ */
+
+@Data
+@NoArgsConstructor()
+@AllArgsConstructor()
+@JsonIgnoreProperties(ignoreUnknown = true)
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public class DeepSeekChatCompletionResponse {
+ /**
+ * 该对话的唯一标识符。
+ */
+ private String id;
+
+ /**
+ * 对象的类型, 其值为 chat.completion 或 chat.completion.chunk
+ */
+ private String object;
+
+ /**
+ * 创建聊天完成时的 Unix 时间戳(以秒为单位)。
+ */
+ private Long created;
+
+ /**
+ * 生成该 completion 的模型名。
+ */
+ private String model;
+
+ /**
+ * 模型生成的 completion 的选择列表。
+ */
+ private List choices;
+
+ /**
+ * 该对话补全请求的用量信息。
+ */
+ private Usage usage;
+
+ /**
+ * 该指纹代表模型运行时使用的后端配置。
+ */
+ @JsonProperty("system_fingerprint")
+ private String systemFingerprint;
+}
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/OpenAiChatService.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/OpenAiChatService.java
index 1b22818..dd0f3cd 100644
--- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/OpenAiChatService.java
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/OpenAiChatService.java
@@ -48,6 +48,8 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat
if(baseUrl == null || "".equals(baseUrl)) baseUrl = openAiConfig.getApiHost();
if(apiKey == null || "".equals(apiKey)) apiKey = openAiConfig.getApiKey();
chatCompletion.setStream(false);
+ chatCompletion.setStreamOptions(null);
+
if(chatCompletion.getFunctions()!=null && !chatCompletion.getFunctions().isEmpty()){
List tools = ToolUtil.getAllFunctionTools(chatCompletion.getFunctions());
chatCompletion.setTools(tools);
@@ -56,9 +58,12 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat
// 总token消耗
Usage allUsage = new Usage();
- String finishReason = null;
+ String finishReason = "first";
+
+ while("first".equals(finishReason) || "tool_calls".equals(finishReason)){
+
+ finishReason = null;
- while(finishReason == null || "tool_calls".equals(finishReason)){
// 构造请求
String requestString = JSON.toJSONString(chatCompletion);
@@ -74,8 +79,6 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat
Choice choice = chatCompletionResponse.getChoices().get(0);
finishReason = choice.getFinishReason();
- System.out.println("finishReason: " + finishReason);
- System.out.println(JSON.toJSONString(chatCompletionResponse));
Usage usage = chatCompletionResponse.getUsage();
allUsage.setCompletionTokens(allUsage.getCompletionTokens() + usage.getCompletionTokens());
@@ -136,10 +139,11 @@ public void chatCompletionStream(String baseUrl, String apiKey, ChatCompletion c
chatCompletion.setTools(tools);
}
- String finishReason = null;
+ String finishReason = "first";
- while(finishReason == null || "tool_calls".equals(finishReason)){
+ while("first".equals(finishReason) || "tool_calls".equals(finishReason)){
+ finishReason = null;
String jsonString = JSON.toJSONString(chatCompletion);
Request request = new Request.Builder()
@@ -148,6 +152,7 @@ public void chatCompletionStream(String baseUrl, String apiKey, ChatCompletion c
.post(RequestBody.create(jsonString, MediaType.parse(Constants.APPLICATION_JSON)))
.build();
+
factory.newEventSource(request, eventSourceListener);
eventSourceListener.getCountDownLatch().await();
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletion.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletion.java
index 7cbe23d..20944a1 100644
--- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletion.java
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletion.java
@@ -40,28 +40,49 @@ public class ChatCompletion {
private List messages;
/**
- * 流式输出
+ * 如果设置为 True,将会以 SSE(server-sent events)的形式以流式发送消息增量。消息流以 data: [DONE] 结尾
*/
@Builder.Default
private Boolean stream = false;
+ /**
+ * 流式输出相关选项。只有在 stream 参数为 true 时,才可设置此参数。
+ */
+ @Builder.Default
+ @JsonProperty("stream_options")
+ private StreamOptions streamOptions = new StreamOptions();
+
+ /**
+ * 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,降低模型重复相同内容的可能性。
+ */
+ @Builder.Default
@JsonProperty("frequency_penalty")
private Float frequencyPenalty = 0f;
/**
- * [0.0, 2.0]
+ * 采样温度,介于 0 和 2 之间。更高的值,如 0.8,会使输出更随机,而更低的值,如 0.2,会使其更加集中和确定。
+ * 我们通常建议可以更改这个值或者更改 top_p,但不建议同时对两者进行修改。
*/
+ @Builder.Default
private Float temperature = 1f;
/**
- * [0.0, 1.0]
+ * 作为调节采样温度的替代方案,模型会考虑前 top_p 概率的 token 的结果。所以 0.1 就意味着只有包括在最高 10% 概率中的 token 会被考虑。
+ * 我们通常建议修改这个值或者更改 temperature,但不建议同时对两者进行修改。
*/
+ @Builder.Default
@JsonProperty("top_p")
private Float topP = 1f;
+ /**
+ * 限制一次请求中模型生成 completion 的最大 token 数。输入 token 和输出 token 的总长度受模型的上下文长度的限制。
+ */
@JsonProperty("max_tokens")
private Integer maxTokens;
+ /**
+ * 模型可能会调用的 tool 的列表。目前,仅支持 function 作为工具。使用此参数来提供以 JSON 作为输入参数的 function 列表。
+ */
private List tools;
/**
@@ -70,26 +91,60 @@ public class ChatCompletion {
@JsonIgnore
private List functions;
+ /**
+ * 控制模型调用 tool 的行为。
+ * none 意味着模型不会调用任何 tool,而是生成一条消息。
+ * auto 意味着模型可以选择生成一条消息或调用一个或多个 tool。
+ * 当没有 tool 时,默认值为 none。如果有 tool 存在,默认值为 auto。
+ */
@JsonProperty("tool_choice")
private String toolChoice;
+ @Builder.Default
@JsonProperty("parallel_tool_calls")
private Boolean parallelToolCalls = true;
+ /**
+ * 一个 object,指定模型必须输出的格式。
+ *
+ * 设置为 { "type": "json_object" } 以启用 JSON 模式,该模式保证模型生成的消息是有效的 JSON。
+ *
+ * 注意: 使用 JSON 模式时,你还必须通过系统或用户消息指示模型生成 JSON。
+ * 否则,模型可能会不断生成空白字符,直到生成达到令牌限制,从而导致请求长时间运行并显得“卡住”。
+ * 此外,如果 finish_reason="length",这表示生成超过了 max_tokens 或对话超过了最大上下文长度,消息内容可能会被部分截断。
+ */
@JsonProperty("response_format")
private Object responseFormat;
private String user;
+ @Builder.Default
private Integer n = 1;
+ /**
+ * 在遇到这些词时,API 将停止生成更多的 token。
+ */
private List stop;
+ /**
+ * 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。
+ */
+ @Builder.Default
+ @JsonProperty("presence_penalty")
+ private Float presencePenalty = 0f;
+
@JsonProperty("logit_bias")
private Map logitBias;
+ /**
+ * 是否返回所输出 token 的对数概率。如果为 true,则在 message 的 content 中返回每个输出 token 的对数概率。
+ */
+ @Builder.Default
private Boolean logprobs = false;
+ /**
+ * 一个介于 0 到 20 之间的整数 N,指定每个输出位置返回输出概率 top N 的 token,且返回这些 token 的对数概率。指定此参数时,logprobs 必须为 true。
+ */
@JsonProperty("top_logprobs")
private Integer topLogprobs;
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletionResponse.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletionResponse.java
index 2b5d591..03d6158 100644
--- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletionResponse.java
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletionResponse.java
@@ -23,14 +23,39 @@
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ChatCompletionResponse {
+ /**
+ * 该对话的唯一标识符。
+ */
private String id;
+
+ /**
+ * 对象的类型, 其值为 chat.completion 或 chat.completion.chunk
+ */
private String object;
+
+ /**
+ * 创建聊天完成时的 Unix 时间戳(以秒为单位)。
+ */
private Long created;
+
+ /**
+ * 生成该 completion 的模型名。
+ */
private String model;
+
+ /**
+ * 该指纹代表模型运行时使用的后端配置。
+ */
@JsonProperty("system_fingerprint")
private String systemFingerprint;
+ /**
+ * 模型生成的 completion 的选择列表。
+ */
private List choices;
+ /**
+ * 该对话补全请求的用量信息。
+ */
private Usage usage;
}
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/Choice.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/Choice.java
index 83748fc..864c8ce 100644
--- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/Choice.java
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/Choice.java
@@ -7,7 +7,7 @@
/**
* @Author cly
- * @Description TODO
+ * @Description 模型生成的 completion
* @Date 2024/8/11 20:01
*/
@Data
@@ -21,6 +21,18 @@ public class Choice {
private Object logprobs;
+ /**
+ * 模型停止生成 token 的原因。
+ *
+ * [stop, length, content_filter, tool_calls, insufficient_system_resource]
+ *
+ * stop:模型自然停止生成,或遇到 stop 序列中列出的字符串。
+ * length:输出长度达到了模型上下文长度限制,或达到了 max_tokens 的限制。
+ * content_filter:输出内容因触发过滤策略而被过滤。
+ * tool_calls:函数调用。
+ * insufficient_system_resource:系统推理资源不足,生成被打断。
+ *
+ */
@JsonProperty("finish_reason")
private String finishReason;
}
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/StreamOptions.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/StreamOptions.java
new file mode 100644
index 0000000..20fbbbb
--- /dev/null
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/StreamOptions.java
@@ -0,0 +1,24 @@
+package io.github.lnyocly.ai4j.platform.openai.chat.entity;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+/**
+ * @Author cly
+ * @Description 流式输出相关选项
+ * @Date 2024/8/29 13:00
+ */
+@Data
+@NoArgsConstructor()
+@AllArgsConstructor()
+@JsonIgnoreProperties(ignoreUnknown = true)
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public class StreamOptions {
+ @JsonProperty("include_usage")
+ private Boolean includeUsage = true;
+}
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/ZhipuChatService.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/ZhipuChatService.java
index c99acee..6ca9931 100644
--- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/ZhipuChatService.java
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/ZhipuChatService.java
@@ -32,11 +32,11 @@
/**
* @Author cly
- * @Description TODO
+ * @Description 智谱chat服务
* @Date 2024/8/27 17:29
*/
@Slf4j
-public class ZhipuChatService implements IChatService, ParameterConvert, ResultConvert {
+public class ZhipuChatService implements IChatService, ParameterConvert, ResultConvert {
private final ZhipuConfig zhipuConfig;
private final OkHttpClient okHttpClient;
@@ -54,8 +54,8 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat
if(baseUrl == null || "".equals(baseUrl)) baseUrl = zhipuConfig.getApiHost();
if(apiKey == null || "".equals(apiKey)) apiKey = zhipuConfig.getApiKey();
chatCompletion.setStream(false);
+ chatCompletion.setStreamOptions(null);
- String finishReason = null;
// 根据key获取token
String token = BearerTokenUtils.getToken(apiKey);
@@ -72,8 +72,11 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat
// 总token消耗
Usage allUsage = new Usage();
+ String finishReason = "first";
- while(finishReason == null || "tool_calls".equals(finishReason)){
+ while("first".equals(finishReason) || "tool_calls".equals(finishReason)){
+
+ finishReason = null;
// 构造请求
String requestString = JSON.toJSONString(zhipuChatCompletion);
@@ -89,8 +92,6 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat
Choice choice = zhipuChatCompletionResponse.getChoices().get(0);
finishReason = choice.getFinishReason();
- System.out.println("finishReason: " + finishReason);
- System.out.println(JSON.toJSONString(zhipuChatCompletionResponse));
Usage usage = zhipuChatCompletionResponse.getUsage();
allUsage.setCompletionTokens(allUsage.getCompletionTokens() + usage.getCompletionTokens());
@@ -118,7 +119,7 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat
}else{
// 其他情况直接返回
zhipuChatCompletionResponse.setUsage(allUsage);
-
+ zhipuChatCompletionResponse.setObject("chat.completion");
// 恢复原始请求数据
chatCompletion.setMessages(zhipuChatCompletion.getMessages());
chatCompletion.setTools(zhipuChatCompletion.getTools());
@@ -146,7 +147,6 @@ public void chatCompletionStream(String baseUrl, String apiKey, ChatCompletion c
if(apiKey == null || "".equals(apiKey)) apiKey = zhipuConfig.getApiKey();
chatCompletion.setStream(true);
- String finishReason = null;
// 根据key获取token
String token = BearerTokenUtils.getToken(apiKey);
@@ -160,8 +160,11 @@ public void chatCompletionStream(String baseUrl, String apiKey, ChatCompletion c
zhipuChatCompletion.setTools(tools);
}
- while(finishReason == null || "tool_calls".equals(finishReason)){
+ String finishReason = "first";
+
+ while("first".equals(finishReason) || "tool_calls".equals(finishReason)){
+ finishReason = null;
String jsonString = JSON.toJSONString(zhipuChatCompletion);
Request request = new Request.Builder()
@@ -216,7 +219,7 @@ public ZhipuChatCompletion convertChatCompletionObject(ChatCompletion chatComple
zhipuChatCompletion.setModel(chatCompletion.getModel());
zhipuChatCompletion.setMessages(chatCompletion.getMessages());
zhipuChatCompletion.setStream(chatCompletion.getStream());
- zhipuChatCompletion.setTemperature(chatCompletion.getTemperature());
+ zhipuChatCompletion.setTemperature(chatCompletion.getTemperature() / 2);
zhipuChatCompletion.setTopP(chatCompletion.getTopP());
zhipuChatCompletion.setMaxTokens(chatCompletion.getMaxTokens());
zhipuChatCompletion.setStop(chatCompletion.getStop());
@@ -247,11 +250,9 @@ public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Null
}
ZhipuChatCompletionResponse chatCompletionResponse = JSON.parseObject(data, ZhipuChatCompletionResponse.class);
+ chatCompletionResponse.setObject("chat.completion.chunk");
ChatCompletionResponse response = convertChatCompletionResponse(chatCompletionResponse);
- // 把智谱的格式,转为OpenAi的格式传输给onEvent
-
-
eventSourceListener.onEvent(eventSource, id, type, JSON.toJSONString(response));
}
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletion.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletion.java
index 8631049..95e0078 100644
--- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletion.java
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletion.java
@@ -27,7 +27,9 @@
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ZhipuChatCompletion {
+ @NonNull
private String model;
+ @NonNull
private List messages;
@JsonProperty("request_id")
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletionResponse.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletionResponse.java
index 1cbc04c..3e46be5 100644
--- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletionResponse.java
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletionResponse.java
@@ -24,9 +24,9 @@
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ZhipuChatCompletionResponse {
private String id;
+ private String object;
private Long created;
private String model;
private List choices;
-
private Usage usage;
}
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/service/Configuration.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/service/Configuration.java
index 66d5c95..960242a 100644
--- a/ai4j/src/main/java/io/github/lnyocly/ai4j/service/Configuration.java
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/service/Configuration.java
@@ -1,5 +1,6 @@
package io.github.lnyocly.ai4j.service;
+import io.github.lnyocly.ai4j.config.DeepSeekConfig;
import io.github.lnyocly.ai4j.config.OpenAiConfig;
import io.github.lnyocly.ai4j.config.PineconeConfig;
import io.github.lnyocly.ai4j.config.ZhipuConfig;
@@ -28,6 +29,7 @@ public EventSource.Factory createRequestFactory() {
private OpenAiConfig openAiConfig;
private ZhipuConfig zhipuConfig;
+ private DeepSeekConfig deepSeekConfig;
private PineconeConfig pineconeConfig;
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/service/PlatformType.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/service/PlatformType.java
index 3f9ced0..7d0d05f 100644
--- a/ai4j/src/main/java/io/github/lnyocly/ai4j/service/PlatformType.java
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/service/PlatformType.java
@@ -13,6 +13,7 @@
public enum PlatformType {
OPENAI("openai"),
ZHIPU("zhipu"),
+ DEEPSEEK("deepseek"),
;
private final String platform;
diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/service/factor/AiService.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/service/factor/AiService.java
index a13368b..50d8fc0 100644
--- a/ai4j/src/main/java/io/github/lnyocly/ai4j/service/factor/AiService.java
+++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/service/factor/AiService.java
@@ -1,5 +1,6 @@
package io.github.lnyocly.ai4j.service.factor;
+import io.github.lnyocly.ai4j.platform.deepseek.chat.DeepSeekChatService;
import io.github.lnyocly.ai4j.platform.openai.chat.OpenAiChatService;
import io.github.lnyocly.ai4j.platform.zhipu.chat.ZhipuChatService;
import io.github.lnyocly.ai4j.service.Configuration;
@@ -38,6 +39,8 @@ private IChatService createChatService(PlatformType platform) {
return new OpenAiChatService(configuration);
case ZHIPU:
return new ZhipuChatService(configuration);
+ case DEEPSEEK:
+ return new DeepSeekChatService(configuration);
default:
throw new IllegalArgumentException("Unknown platform: " + platform);
}
diff --git a/ai4j/src/test/java/io/github/lnyocly/OpenAiTest.java b/ai4j/src/test/java/io/github/lnyocly/OpenAiTest.java
index df86b67..de415a1 100644
--- a/ai4j/src/test/java/io/github/lnyocly/OpenAiTest.java
+++ b/ai4j/src/test/java/io/github/lnyocly/OpenAiTest.java
@@ -1,9 +1,11 @@
package io.github.lnyocly;
import com.alibaba.fastjson2.JSON;
+import io.github.lnyocly.ai4j.config.DeepSeekConfig;
import io.github.lnyocly.ai4j.config.OpenAiConfig;
import io.github.lnyocly.ai4j.config.ZhipuConfig;
+import io.github.lnyocly.ai4j.interceptor.ErrorInterceptor;
import io.github.lnyocly.ai4j.listener.SseListener;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletion;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletionResponse;
@@ -67,10 +69,12 @@ public class OpenAiTest {
public void test_init(){
OpenAiConfig openAiConfig = new OpenAiConfig();
ZhipuConfig zhipuConfig = new ZhipuConfig();
+ DeepSeekConfig deepSeekConfig = new DeepSeekConfig();
Configuration configuration = new Configuration();
configuration.setOpenAiConfig(openAiConfig);
configuration.setZhipuConfig(zhipuConfig);
+ configuration.setDeepSeekConfig(deepSeekConfig);
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor();
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
@@ -78,6 +82,7 @@ public void test_init(){
OkHttpClient okHttpClient = new OkHttpClient
.Builder()
.addInterceptor(httpLoggingInterceptor)
+ .addInterceptor(new ErrorInterceptor())
.connectTimeout(300, TimeUnit.SECONDS)
.writeTimeout(300, TimeUnit.SECONDS)
.readTimeout(300, TimeUnit.SECONDS)
@@ -89,7 +94,7 @@ public void test_init(){
embeddingService = aiService.getEmbeddingService(PlatformType.OPENAI);
//chatService = aiService.getChatService(PlatformType.getPlatform("OPENAI"));
- chatService = aiService.getChatService(PlatformType.ZHIPU);
+ chatService = aiService.getChatService(PlatformType.DEEPSEEK);
}
@@ -138,7 +143,7 @@ public void test_embed() throws Exception {
@Test
public void test_chatCompletions_common() throws Exception {
ChatCompletion chatCompletion = ChatCompletion.builder()
- .model("glm-4-flash")
+ .model("deepseek-chat")
.message(ChatMessage.withUser("鲁迅为什么打周树人"))
.build();
@@ -174,32 +179,33 @@ public void test_chatCompletions_multimodal() throws Exception {
@Test
public void test_chatCompletions_stream() throws Exception {
ChatCompletion chatCompletion = ChatCompletion.builder()
- .model("gpt-4o-mini")
+ .model("deepseek-chat")
.message(ChatMessage.withUser("鲁迅为什么打周树人"))
.build();
System.out.println("请求参数");
System.out.println(chatCompletion);
- CountDownLatch countDownLatch = new CountDownLatch(1);
- chatService.chatCompletionStream(chatCompletion, new SseListener() {
+ // 构造监听器
+ SseListener sseListener = new SseListener() {
@Override
protected void send() {
-
+ System.out.println(this.getCurrStr());
}
- });
+ };
- countDownLatch.await();
+ chatService.chatCompletionStream(chatCompletion, sseListener);
System.out.println("请求成功");
+ System.out.println(sseListener.getOutput());
}
@Test
public void test_chatCompletions_function() throws Exception {
ChatCompletion chatCompletion = ChatCompletion.builder()
- .model("gpt-4o-mini")
+ .model("deepseek-chat")
.message(ChatMessage.withUser("查询洛阳明天的天气,并告诉我火车是否发车"))
.functions("queryWeather", "queryTrainInfo")
.build();
@@ -221,7 +227,7 @@ public void test_chatCompletions_stream_function() throws Exception {
// 构造请求参数
ChatCompletion chatCompletion = ChatCompletion.builder()
- .model("glm-4-flash")
+ .model("deepseek-chat")
.message(ChatMessage.withUser("查询洛阳明天的天气"))
.functions("queryWeather", "queryTrainInfo")
.build();
@@ -241,6 +247,8 @@ protected void send() {
chatService.chatCompletionStream(chatCompletion, sseListener);
System.out.println("完整内容: ");
System.out.println(sseListener.getOutput());
+ System.out.println("内容花费: ");
+ System.out.println(sseListener.getUsage());
}
@Test
diff --git a/pom.xml b/pom.xml
index 85da238..0a156db 100644
--- a/pom.xml
+++ b/pom.xml
@@ -7,7 +7,7 @@
io.github.lnyo-cly
ai4j-sdk
- 0.2.0
+ 0.3.0
pom