diff --git a/README.md b/README.md index f4e50f4..5f45140 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ ## 支持的平台 + OpenAi + Zhipu ++ DeepSeek + 待添加 ## 支持的服务 @@ -19,11 +20,13 @@ + 支持流式输出。支持函数调用参数输出 + 轻松使用Tool Calls + 支持多个函数同时调用(智谱不支持) ++ 支持stream_options,流式输出直接获取token usage + 内置向量数据库支持: Pinecone + 使用Tika读取文件 + Token统计`TikTokensUtil.java` ## 更新日志 ++ [2024-08-29] 新增对DeepSeek平台的支持、新增stream_options可以直接统计usage、新增错误拦截器`ErrorInterceptor.java`、发布0.3.0版本. + [2024-08-29] 修改SseListener以兼容智谱函数调用 + [2024-08-28] 添加token统计、添加智谱AI的Chat服务、优化函数调用可以支持多轮多函数。 + [2024-08-17] 增强SseListener监听器功能。发布0.2.0版本。 @@ -32,11 +35,11 @@ ## 导入 ### Gradle ```groovy -implementation group: 'io.github.lnyo-cly', name: 'ai4j', version: '0.1.0' +implementation group: 'io.github.lnyo-cly', name: 'ai4j', version: '0.3.0' ``` ```groovy -implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', version: '0.1.0' +implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', version: '0.3.0' ``` @@ -46,7 +49,7 @@ implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', ver io.github.lnyo-cly ai4j - 0.2.0 + 0.3.0 ``` @@ -55,7 +58,7 @@ implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', ver io.github.lnyo-cly ai4j-spring-boot-stater - 0.2.0 + 0.3.0 ``` @@ -89,7 +92,7 @@ implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', ver } ``` -#### Spring获取 +### Spring获取 ```yml # 国内访问默认需要代理 ai: @@ -98,11 +101,20 @@ ai: okhttp: proxy-port: 10809 proxy-url: "127.0.0.1" + zhipu: + api-key: "xxx" + #other... ``` ```java +// 注入Ai服务 @Autowired private AiService aiService; + +// 获取需要的服务实例 +IChatService chatService = aiService.getChatService(PlatformType.OPENAI); +IEmbeddingService embeddingService = aiService.getEmbeddingService(PlatformType.OPENAI); +// ...... ``` ## Chat服务 diff --git a/ai4j-spring-boot-stater/pom.xml b/ai4j-spring-boot-stater/pom.xml index f171a50..a1edb8e 100644 --- a/ai4j-spring-boot-stater/pom.xml +++ b/ai4j-spring-boot-stater/pom.xml @@ -6,10 +6,10 @@ io.github.lnyo-cly ai4j-spring-boot-stater jar - 0.2.0 + 0.3.0 ai4j-spring-boot-stater - ai4j-spring-boot-stater + 为aj4j所提供的spring-stater,便于接入spring项目。关于ai4j: 整合多平台大模型,如OpenAi、Zhipu(ChatGLM)、DeepSeek等等,提供统一的输入输出(对齐OpenAi),优化函数调用(Tool Call),优化RAG调用、支持向量数据库(Pinecone),并且支持JDK1.8,为用户提供快速整合AI的能力。 diff --git a/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/AiConfigAutoConfiguration.java b/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/AiConfigAutoConfiguration.java index 1b4fbf1..d128b08 100644 --- a/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/AiConfigAutoConfiguration.java +++ b/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/AiConfigAutoConfiguration.java @@ -1,8 +1,10 @@ package io.github.lnyocly.ai4j; +import io.github.lnyocly.ai4j.config.DeepSeekConfig; import io.github.lnyocly.ai4j.config.OpenAiConfig; import io.github.lnyocly.ai4j.config.PineconeConfig; import io.github.lnyocly.ai4j.config.ZhipuConfig; +import io.github.lnyocly.ai4j.interceptor.ErrorInterceptor; import io.github.lnyocly.ai4j.service.PlatformType; import io.github.lnyocly.ai4j.service.factor.AiService; import io.github.lnyocly.ai4j.vector.service.PineconeService; @@ -27,21 +29,24 @@ OpenAiConfigProperties.class, OkHttpConfigProperties.class, PineconeConfigProperties.class, - ZhipuConfigProperties.class}) + ZhipuConfigProperties.class, + DeepSeekConfigProperties.class}) public class AiConfigAutoConfiguration { private final OkHttpConfigProperties okHttpConfigProperties; private final OpenAiConfigProperties openAiConfigProperties; private final PineconeConfigProperties pineconeConfigProperties; private final ZhipuConfigProperties zhipuConfigProperties; + private final DeepSeekConfigProperties deepSeekConfigProperties; private io.github.lnyocly.ai4j.service.Configuration configuration = new io.github.lnyocly.ai4j.service.Configuration(); - public AiConfigAutoConfiguration(OkHttpConfigProperties okHttpConfigProperties, OpenAiConfigProperties openAiConfigProperties, PineconeConfigProperties pineconeConfigProperties, ZhipuConfigProperties zhipuConfigProperties) { + public AiConfigAutoConfiguration(OkHttpConfigProperties okHttpConfigProperties, OpenAiConfigProperties openAiConfigProperties, PineconeConfigProperties pineconeConfigProperties, ZhipuConfigProperties zhipuConfigProperties, DeepSeekConfigProperties deepSeekConfigProperties) { this.okHttpConfigProperties = okHttpConfigProperties; this.openAiConfigProperties = openAiConfigProperties; this.pineconeConfigProperties = pineconeConfigProperties; this.zhipuConfigProperties = zhipuConfigProperties; + this.deepSeekConfigProperties = deepSeekConfigProperties; } @Bean @@ -60,6 +65,7 @@ private void init() { initOpenAiConfig(); initPineconeConfig(); initZhipuConfig(); + initDeepSeekConfig(); } private void initOkHttp() { @@ -75,6 +81,7 @@ private void initOkHttp() { OkHttpClient okHttpClient = new OkHttpClient .Builder() .addInterceptor(httpLoggingInterceptor) + .addInterceptor(new ErrorInterceptor()) .connectTimeout(okHttpConfigProperties.getConnectTimeout(), okHttpConfigProperties.getTimeUnit()) .writeTimeout(okHttpConfigProperties.getWriteTimeout(), okHttpConfigProperties.getTimeUnit()) .readTimeout(okHttpConfigProperties.getReadTimeout(), okHttpConfigProperties.getTimeUnit()) @@ -115,6 +122,13 @@ private void initPineconeConfig() { configuration.setPineconeConfig(pineconeConfig); } + private void initDeepSeekConfig(){ + DeepSeekConfig deepSeekConfig = new DeepSeekConfig(); + deepSeekConfig.setApiHost(deepSeekConfigProperties.getApiHost()); + deepSeekConfig.setApiKey(deepSeekConfigProperties.getApiKey()); + deepSeekConfig.setChat_completion(deepSeekConfigProperties.getChat_completion()); + configuration.setDeepSeekConfig(deepSeekConfig); + } } diff --git a/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/DeepSeekConfigProperties.java b/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/DeepSeekConfigProperties.java new file mode 100644 index 0000000..3f3f96b --- /dev/null +++ b/ai4j-spring-boot-stater/src/main/java/io/github/lnyocly/ai4j/DeepSeekConfigProperties.java @@ -0,0 +1,40 @@ +package io.github.lnyocly.ai4j; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * @Author cly + * @Description TODO + * @Date 2024/8/29 15:01 + */ +@ConfigurationProperties(prefix = "ai.deepseek") +public class DeepSeekConfigProperties { + + private String apiHost = "https://api.deepseek.com/"; + private String apiKey = ""; + private String chat_completion = "chat/completions"; + + public String getApiHost() { + return apiHost; + } + + public void setApiHost(String apiHost) { + this.apiHost = apiHost; + } + + public String getChat_completion() { + return chat_completion; + } + + public void setChat_completion(String chat_completion) { + this.chat_completion = chat_completion; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } +} diff --git a/ai4j/pom.xml b/ai4j/pom.xml index 79c88bc..068ac9c 100644 --- a/ai4j/pom.xml +++ b/ai4j/pom.xml @@ -7,10 +7,10 @@ io.github.lnyo-cly ai4j jar - 0.2.0 + 0.3.0 ai4j - ai4j基础组件项目 + 整合多平台大模型,如OpenAi、Zhipu(ChatGLM)、DeepSeek等等,提供统一的输入输出(对齐OpenAi),优化函数调用(Tool Call),优化RAG调用、支持向量数据库(Pinecone),并且支持JDK1.8,为用户提供快速整合AI的能力。 @@ -78,6 +78,20 @@ logback-classic 1.2.3 + + + org.slf4j + slf4j-api + 1.7.30 + + + + org.slf4j + slf4j-log4j12 + 1.7.30 + + + org.reflections reflections diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/config/DeepSeekConfig.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/config/DeepSeekConfig.java new file mode 100644 index 0000000..002c00d --- /dev/null +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/config/DeepSeekConfig.java @@ -0,0 +1,21 @@ +package io.github.lnyocly.ai4j.config; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * @Author cly + * @Description DeepSeek 配置文件 + * @Date 2024/8/29 10:31 + */ + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class DeepSeekConfig { + + private String apiHost = "https://api.deepseek.com/"; + private String apiKey = ""; + private String chat_completion = "chat/completions"; +} diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/interceptor/ErrorInterceptor.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/interceptor/ErrorInterceptor.java new file mode 100644 index 0000000..eda7e63 --- /dev/null +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/interceptor/ErrorInterceptor.java @@ -0,0 +1,39 @@ +package io.github.lnyocly.ai4j.interceptor; + +import io.github.lnyocly.ai4j.exception.CommonException; +import lombok.extern.slf4j.Slf4j; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import org.jetbrains.annotations.NotNull; + +import java.io.IOException; + +/** + * @Author cly + * @Description 错误处理器 + * @Date 2024/8/29 14:55 + */ +@Slf4j +public class ErrorInterceptor implements Interceptor { + @NotNull + @Override + public Response intercept(@NotNull Chain chain) throws IOException { + Request original = chain.request(); + + Response response = chain.proceed(original); + + if(!response.isSuccessful()){ + //response.close(); + String errorMsg = response.body().string(); + + log.error("AI服务请求异常:{}", errorMsg); + throw new CommonException(errorMsg); + + + } + + + return response; + } +} diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/listener/SseListener.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/listener/SseListener.java index 13f735b..a1d52bb 100644 --- a/ai4j/src/main/java/io/github/lnyocly/ai4j/listener/SseListener.java +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/listener/SseListener.java @@ -1,8 +1,10 @@ package io.github.lnyocly.ai4j.listener; import com.alibaba.fastjson2.JSON; +import io.github.lnyocly.ai4j.exception.CommonException; import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletionResponse; import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatMessage; +import io.github.lnyocly.ai4j.platform.openai.chat.entity.Choice; import io.github.lnyocly.ai4j.platform.openai.chat.enums.ChatMessageType; import io.github.lnyocly.ai4j.platform.openai.tool.ToolCall; import io.github.lnyocly.ai4j.platform.openai.usage.Usage; @@ -60,7 +62,7 @@ public abstract class SseListener extends EventSourceListener { * 花费token */ @Getter - private Usage usage = null; + private final Usage usage = new Usage(); @Setter @Getter @@ -82,26 +84,38 @@ public abstract class SseListener extends EventSourceListener { @Override public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) { - log.error("流式输出异常 onFailure "); + countDownLatch.countDown(); } @Override public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) { if ("[DONE]".equalsIgnoreCase(data)) { - log.info("模型会话 [DONE]"); + //log.info("模型会话 [DONE]"); return; } ChatCompletionResponse chatCompletionResponse = JSON.parseObject(data, ChatCompletionResponse.class); - ChatMessage responseMessage = chatCompletionResponse.getChoices().get(0).getDelta(); + // 统计token,当设置include_usage = true时,最后一条消息会携带usage, 其他消息中usage为null + Usage currUsage = chatCompletionResponse.getUsage(); + if(currUsage != null){ + usage.setPromptTokens(usage.getPromptTokens() + currUsage.getPromptTokens()); + usage.setCompletionTokens(usage.getCompletionTokens() + currUsage.getCompletionTokens()); + usage.setTotalTokens(usage.getTotalTokens() + currUsage.getTotalTokens()); + } + + List choices = chatCompletionResponse.getChoices(); + if(choices == null || choices.isEmpty()){ + return; + } + ChatMessage responseMessage = choices.get(0).getDelta(); - finishReason = chatCompletionResponse.getChoices().get(0).getFinishReason(); + finishReason = choices.get(0).getFinishReason(); // tool_calls回答已经结束 - if("tool_calls".equals(chatCompletionResponse.getChoices().get(0).getFinishReason())){ + if("tool_calls".equals(choices.get(0).getFinishReason())){ if(toolCall == null && responseMessage.getToolCalls()!=null) { toolCalls = responseMessage.getToolCalls(); if(showToolArgs){ @@ -171,12 +185,11 @@ public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Null - log.info("测试结果:{}", chatCompletionResponse); + //log.info("测试结果:{}", chatCompletionResponse); } @Override public void onClosed(@NotNull EventSource eventSource) { - log.info("调用 onClosed "); countDownLatch.countDown(); countDownLatch = new CountDownLatch(1); diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/DeepSeekChatService.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/DeepSeekChatService.java new file mode 100644 index 0000000..16f5653 --- /dev/null +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/DeepSeekChatService.java @@ -0,0 +1,275 @@ +package io.github.lnyocly.ai4j.platform.deepseek.chat; + +import com.alibaba.fastjson2.JSON; +import io.github.lnyocly.ai4j.config.DeepSeekConfig; +import io.github.lnyocly.ai4j.constant.Constants; +import io.github.lnyocly.ai4j.convert.ParameterConvert; +import io.github.lnyocly.ai4j.convert.ResultConvert; +import io.github.lnyocly.ai4j.listener.SseListener; +import io.github.lnyocly.ai4j.platform.deepseek.chat.entity.DeepSeekChatCompletion; +import io.github.lnyocly.ai4j.platform.deepseek.chat.entity.DeepSeekChatCompletionResponse; +import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletion; +import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletionResponse; +import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatMessage; +import io.github.lnyocly.ai4j.platform.openai.chat.entity.Choice; +import io.github.lnyocly.ai4j.platform.openai.tool.Tool; +import io.github.lnyocly.ai4j.platform.openai.tool.ToolCall; +import io.github.lnyocly.ai4j.platform.openai.usage.Usage; +import io.github.lnyocly.ai4j.platform.zhipu.chat.entity.ZhipuChatCompletionResponse; +import io.github.lnyocly.ai4j.service.Configuration; +import io.github.lnyocly.ai4j.service.IChatService; +import io.github.lnyocly.ai4j.utils.BearerTokenUtils; +import io.github.lnyocly.ai4j.utils.ToolUtil; +import okhttp3.*; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.util.ArrayList; +import java.util.List; + +/** + * @Author cly + * @Description DeepSeek Chat服务 + * @Date 2024/8/29 10:26 + */ +public class DeepSeekChatService implements IChatService, ParameterConvert, ResultConvert { + private final DeepSeekConfig deepSeekConfig; + private final OkHttpClient okHttpClient; + private final EventSource.Factory factory; + + public DeepSeekChatService(Configuration configuration) { + this.deepSeekConfig = configuration.getDeepSeekConfig(); + this.okHttpClient = configuration.getOkHttpClient(); + this.factory = configuration.createRequestFactory(); + } + + + @Override + public DeepSeekChatCompletion convertChatCompletionObject(ChatCompletion chatCompletion) { + DeepSeekChatCompletion deepSeekChatCompletion = new DeepSeekChatCompletion(); + deepSeekChatCompletion.setModel(chatCompletion.getModel()); + deepSeekChatCompletion.setMessages(chatCompletion.getMessages()); + deepSeekChatCompletion.setFrequencyPenalty(chatCompletion.getFrequencyPenalty()); + deepSeekChatCompletion.setMaxTokens(chatCompletion.getMaxTokens()); + deepSeekChatCompletion.setPresencePenalty(chatCompletion.getPresencePenalty()); + deepSeekChatCompletion.setResponseFormat(chatCompletion.getResponseFormat()); + deepSeekChatCompletion.setStop(chatCompletion.getStop()); + deepSeekChatCompletion.setStream(chatCompletion.getStream()); + deepSeekChatCompletion.setStreamOptions(chatCompletion.getStreamOptions()); + deepSeekChatCompletion.setTemperature(chatCompletion.getTemperature()); + deepSeekChatCompletion.setTopP(chatCompletion.getTopP()); + deepSeekChatCompletion.setTools(chatCompletion.getTools()); + deepSeekChatCompletion.setFunctions(chatCompletion.getFunctions()); + deepSeekChatCompletion.setToolChoice(chatCompletion.getToolChoice()); + deepSeekChatCompletion.setLogprobs(chatCompletion.getLogprobs()); + deepSeekChatCompletion.setTopLogprobs(chatCompletion.getTopLogprobs()); + return deepSeekChatCompletion; + } + + @Override + public EventSourceListener convertEventSource(EventSourceListener eventSourceListener) { + return new EventSourceListener() { + @Override + public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) { + eventSourceListener.onOpen(eventSource, response); + } + + @Override + public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) { + eventSourceListener.onFailure(eventSource, t, response); + } + + @Override + public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) { + if ("[DONE]".equalsIgnoreCase(data)) { + eventSourceListener.onEvent(eventSource, id, type, data); + return; + } + + DeepSeekChatCompletionResponse chatCompletionResponse = JSON.parseObject(data, DeepSeekChatCompletionResponse.class); + ChatCompletionResponse response = convertChatCompletionResponse(chatCompletionResponse); + + eventSourceListener.onEvent(eventSource, id, type, JSON.toJSONString(response)); + } + + @Override + public void onClosed(@NotNull EventSource eventSource) { + eventSourceListener.onClosed(eventSource); + } + }; + } + + @Override + public ChatCompletionResponse convertChatCompletionResponse(DeepSeekChatCompletionResponse deepSeekChatCompletionResponse) { + ChatCompletionResponse chatCompletionResponse = new ChatCompletionResponse(); + chatCompletionResponse.setId(deepSeekChatCompletionResponse.getId()); + chatCompletionResponse.setObject(deepSeekChatCompletionResponse.getObject()); + chatCompletionResponse.setCreated(deepSeekChatCompletionResponse.getCreated()); + chatCompletionResponse.setModel(deepSeekChatCompletionResponse.getModel()); + chatCompletionResponse.setSystemFingerprint(deepSeekChatCompletionResponse.getSystemFingerprint()); + chatCompletionResponse.setChoices(deepSeekChatCompletionResponse.getChoices()); + chatCompletionResponse.setUsage(deepSeekChatCompletionResponse.getUsage()); + return chatCompletionResponse; + } + + @Override + public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, ChatCompletion chatCompletion) throws Exception { + if(baseUrl == null || "".equals(baseUrl)) baseUrl = deepSeekConfig.getApiHost(); + if(apiKey == null || "".equals(apiKey)) apiKey = deepSeekConfig.getApiKey(); + chatCompletion.setStream(false); + chatCompletion.setStreamOptions(null); + + // 转换 请求参数 + DeepSeekChatCompletion deepSeekChatCompletion = this.convertChatCompletionObject(chatCompletion); + + // 如含有function,则添加tool + if(deepSeekChatCompletion.getFunctions()!=null && !deepSeekChatCompletion.getFunctions().isEmpty()){ + List tools = ToolUtil.getAllFunctionTools(deepSeekChatCompletion.getFunctions()); + deepSeekChatCompletion.setTools(tools); + } + + // 总token消耗 + Usage allUsage = new Usage(); + + String finishReason = "first"; + + while("first".equals(finishReason) || "tool_calls".equals(finishReason)){ + + finishReason = null; + + // 构造请求 + String requestString = JSON.toJSONString(deepSeekChatCompletion); + + Request request = new Request.Builder() + .header("Authorization", "Bearer " + apiKey) + .url(baseUrl.concat(deepSeekConfig.getChat_completion())) + .post(RequestBody.create(requestString, MediaType.parse(Constants.JSON_CONTENT_TYPE))) + .build(); + + Response execute = okHttpClient.newCall(request).execute(); + if (execute.isSuccessful() && execute.body() != null){ + DeepSeekChatCompletionResponse deepSeekChatCompletionResponse = JSON.parseObject(execute.body().string(), DeepSeekChatCompletionResponse.class); + + Choice choice = deepSeekChatCompletionResponse.getChoices().get(0); + finishReason = choice.getFinishReason(); + + Usage usage = deepSeekChatCompletionResponse.getUsage(); + allUsage.setCompletionTokens(allUsage.getCompletionTokens() + usage.getCompletionTokens()); + allUsage.setTotalTokens(allUsage.getTotalTokens() + usage.getTotalTokens()); + allUsage.setPromptTokens(allUsage.getPromptTokens() + usage.getPromptTokens()); + + // 判断是否为函数调用返回 + if("tool_calls".equals(finishReason)){ + ChatMessage message = choice.getMessage(); + List toolCalls = message.getToolCalls(); + + List messages = new ArrayList<>(deepSeekChatCompletion.getMessages()); + messages.add(message); + + // 添加 tool 消息 + for (ToolCall toolCall : toolCalls) { + String functionName = toolCall.getFunction().getName(); + String arguments = toolCall.getFunction().getArguments(); + String functionResponse = ToolUtil.invoke(functionName, arguments); + + messages.add(ChatMessage.withTool(functionResponse, toolCall.getId())); + } + deepSeekChatCompletion.setMessages(messages); + + }else{// 其他情况直接返回 + + // 设置包含tool的总token数 + deepSeekChatCompletionResponse.setUsage(allUsage); + //deepSeekChatCompletionResponse.setObject("chat.completion"); + + // 恢复原始请求数据 + chatCompletion.setMessages(deepSeekChatCompletion.getMessages()); + chatCompletion.setTools(deepSeekChatCompletion.getTools()); + + return this.convertChatCompletionResponse(deepSeekChatCompletionResponse); + + } + + } + + } + + + return null; + } + + @Override + public ChatCompletionResponse chatCompletion(ChatCompletion chatCompletion) throws Exception { + return this.chatCompletion(null, null, chatCompletion); + } + + @Override + public void chatCompletionStream(String baseUrl, String apiKey, ChatCompletion chatCompletion, SseListener eventSourceListener) throws Exception { + if(baseUrl == null || "".equals(baseUrl)) baseUrl = deepSeekConfig.getApiHost(); + if(apiKey == null || "".equals(apiKey)) apiKey = deepSeekConfig.getApiKey(); + chatCompletion.setStream(true); + + // 转换 请求参数 + DeepSeekChatCompletion deepSeekChatCompletion = this.convertChatCompletionObject(chatCompletion); + + // 如含有function,则添加tool + if(deepSeekChatCompletion.getFunctions()!=null && !deepSeekChatCompletion.getFunctions().isEmpty()){ + List tools = ToolUtil.getAllFunctionTools(deepSeekChatCompletion.getFunctions()); + deepSeekChatCompletion.setTools(tools); + } + + String finishReason = "first"; + + while("first".equals(finishReason) || "tool_calls".equals(finishReason)){ + + finishReason = null; + String jsonString = JSON.toJSONString(deepSeekChatCompletion); + + Request request = new Request.Builder() + .header("Authorization", "Bearer " + apiKey) + .url(baseUrl.concat(deepSeekConfig.getChat_completion())) + .post(RequestBody.create(jsonString, MediaType.parse(Constants.APPLICATION_JSON))) + .build(); + + + factory.newEventSource(request, convertEventSource(eventSourceListener)); + eventSourceListener.getCountDownLatch().await(); + + finishReason = eventSourceListener.getFinishReason(); + List toolCalls = eventSourceListener.getToolCalls(); + + // 需要调用函数 + if("tool_calls".equals(finishReason) && !toolCalls.isEmpty()){ + // 创建tool响应消息 + ChatMessage responseMessage = ChatMessage.withAssistant(eventSourceListener.getToolCalls()); + + List messages = new ArrayList<>(deepSeekChatCompletion.getMessages()); + messages.add(responseMessage); + + // 封装tool结果消息 + for (ToolCall toolCall : toolCalls) { + String functionName = toolCall.getFunction().getName(); + String arguments = toolCall.getFunction().getArguments(); + String functionResponse = ToolUtil.invoke(functionName, arguments); + + messages.add(ChatMessage.withTool(functionResponse, toolCall.getId())); + } + eventSourceListener.setToolCalls(new ArrayList<>()); + eventSourceListener.setToolCall(null); + deepSeekChatCompletion.setMessages(messages); + } + + } + + // 补全原始请求 + chatCompletion.setMessages(deepSeekChatCompletion.getMessages()); + chatCompletion.setTools(deepSeekChatCompletion.getTools()); + } + + @Override + public void chatCompletionStream(ChatCompletion chatCompletion, SseListener eventSourceListener) throws Exception { + this.chatCompletionStream(null, null, chatCompletion, eventSourceListener); + } +} diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/entity/DeepSeekChatCompletion.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/entity/DeepSeekChatCompletion.java new file mode 100644 index 0000000..7174f9f --- /dev/null +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/entity/DeepSeekChatCompletion.java @@ -0,0 +1,154 @@ +package io.github.lnyocly.ai4j.platform.deepseek.chat.entity; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatMessage; +import io.github.lnyocly.ai4j.platform.openai.chat.entity.StreamOptions; +import io.github.lnyocly.ai4j.platform.openai.tool.Tool; +import io.github.lnyocly.ai4j.platform.zhipu.chat.entity.ZhipuChatCompletion; +import lombok.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * @Author cly + * @Description DeepSeek对话请求实体 + * @Date 2024/8/29 10:27 + */ +@Data +@Builder(toBuilder = true) +@NoArgsConstructor() +@AllArgsConstructor() +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(JsonInclude.Include.NON_NULL) +public class DeepSeekChatCompletion { + + + @NonNull + private String model; + + @NonNull + private List messages; + + /** + * 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,降低模型重复相同内容的可能性。 + */ + @Builder.Default + @JsonProperty("frequency_penalty") + private Float frequencyPenalty = 0f; + + /** + * 限制一次请求中模型生成 completion 的最大 token 数。输入 token 和输出 token 的总长度受模型的上下文长度的限制。 + */ + @JsonProperty("max_tokens") + private Integer maxTokens; + + /** + * 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。 + */ + @Builder.Default + @JsonProperty("presence_penalty") + private Float presencePenalty = 0f; + + /** + * 一个 object,指定模型必须输出的格式。 + * + * 设置为 { "type": "json_object" } 以启用 JSON 模式,该模式保证模型生成的消息是有效的 JSON。 + * + * 注意: 使用 JSON 模式时,你还必须通过系统或用户消息指示模型生成 JSON。 + * 否则,模型可能会生成不断的空白字符,直到生成达到令牌限制,从而导致请求长时间运行并显得“卡住”。 + * 此外,如果 finish_reason="length",这表示生成超过了 max_tokens 或对话超过了最大上下文长度,消息内容可能会被部分截断。 + */ + @JsonProperty("response_format") + private Object responseFormat; + + /** + * 在遇到这些词时,API 将停止生成更多的 token。 + */ + private List stop; + + /** + * 如果设置为 True,将会以 SSE(server-sent events)的形式以流式发送消息增量。消息流以 data: [DONE] 结尾 + */ + private Boolean stream = false; + + /** + * 流式输出相关选项。只有在 stream 参数为 true 时,才可设置此参数。 + */ + @Builder.Default + @JsonProperty("stream_options") + private StreamOptions streamOptions = new StreamOptions(); + + /** + * 采样温度,介于 0 和 2 之间。更高的值,如 0.8,会使输出更随机,而更低的值,如 0.2,会使其更加集中和确定。 + * 我们通常建议可以更改这个值或者更改 top_p,但不建议同时对两者进行修改。 + */ + @Builder.Default + private Float temperature = 1f; + + /** + * 作为调节采样温度的替代方案,模型会考虑前 top_p 概率的 token 的结果。所以 0.1 就意味着只有包括在最高 10% 概率中的 token 会被考虑。 + * 我们通常建议修改这个值或者更改 temperature,但不建议同时对两者进行修改。 + */ + @Builder.Default + @JsonProperty("top_p") + private Float topP = 1f; + + /** + * 模型可能会调用的 tool 的列表。目前,仅支持 function 作为工具。使用此参数来提供以 JSON 作为输入参数的 function 列表。 + */ + private List tools; + + /** + * 辅助属性 + */ + @JsonIgnore + private List functions; + + /** + * 控制模型调用 tool 的行为。 + * none 意味着模型不会调用任何 tool,而是生成一条消息。 + * auto 意味着模型可以选择生成一条消息或调用一个或多个 tool。 + * 当没有 tool 时,默认值为 none。如果有 tool 存在,默认值为 auto。 + */ + @JsonProperty("tool_choice") + private String toolChoice; + + /** + * 是否返回所输出 token 的对数概率。如果为 true,则在 message 的 content 中返回每个输出 token 的对数概率。 + */ + @Builder.Default + private Boolean logprobs = false; + + /** + * 一个介于 0 到 20 之间的整数 N,指定每个输出位置返回输出概率 top N 的 token,且返回这些 token 的对数概率。指定此参数时,logprobs 必须为 true。 + */ + @JsonProperty("top_logprobs") + private Integer topLogprobs; + + public static class DeepSeekChatCompletionBuilder { + private List functions; + + public DeepSeekChatCompletion.DeepSeekChatCompletionBuilder functions(String... functions){ + if (this.functions == null) { + this.functions = new ArrayList<>(); + } + this.functions.addAll(Arrays.asList(functions)); + return this; + } + + public DeepSeekChatCompletion.DeepSeekChatCompletionBuilder functions(List functions){ + if (this.functions == null) { + this.functions = new ArrayList<>(); + } + this.functions.addAll(functions); + return this; + } + + + } +} diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/entity/DeepSeekChatCompletionResponse.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/entity/DeepSeekChatCompletionResponse.java new file mode 100644 index 0000000..7aa31e6 --- /dev/null +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/deepseek/chat/entity/DeepSeekChatCompletionResponse.java @@ -0,0 +1,61 @@ +package io.github.lnyocly.ai4j.platform.deepseek.chat.entity; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.github.lnyocly.ai4j.platform.openai.chat.entity.Choice; +import io.github.lnyocly.ai4j.platform.openai.usage.Usage; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +/** + * @Author cly + * @Description DeepSeek对话响应实体 + * @Date 2024/8/29 10:28 + */ + +@Data +@NoArgsConstructor() +@AllArgsConstructor() +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(JsonInclude.Include.NON_NULL) +public class DeepSeekChatCompletionResponse { + /** + * 该对话的唯一标识符。 + */ + private String id; + + /** + * 对象的类型, 其值为 chat.completion 或 chat.completion.chunk + */ + private String object; + + /** + * 创建聊天完成时的 Unix 时间戳(以秒为单位)。 + */ + private Long created; + + /** + * 生成该 completion 的模型名。 + */ + private String model; + + /** + * 模型生成的 completion 的选择列表。 + */ + private List choices; + + /** + * 该对话补全请求的用量信息。 + */ + private Usage usage; + + /** + * 该指纹代表模型运行时使用的后端配置。 + */ + @JsonProperty("system_fingerprint") + private String systemFingerprint; +} diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/OpenAiChatService.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/OpenAiChatService.java index 1b22818..dd0f3cd 100644 --- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/OpenAiChatService.java +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/OpenAiChatService.java @@ -48,6 +48,8 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat if(baseUrl == null || "".equals(baseUrl)) baseUrl = openAiConfig.getApiHost(); if(apiKey == null || "".equals(apiKey)) apiKey = openAiConfig.getApiKey(); chatCompletion.setStream(false); + chatCompletion.setStreamOptions(null); + if(chatCompletion.getFunctions()!=null && !chatCompletion.getFunctions().isEmpty()){ List tools = ToolUtil.getAllFunctionTools(chatCompletion.getFunctions()); chatCompletion.setTools(tools); @@ -56,9 +58,12 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat // 总token消耗 Usage allUsage = new Usage(); - String finishReason = null; + String finishReason = "first"; + + while("first".equals(finishReason) || "tool_calls".equals(finishReason)){ + + finishReason = null; - while(finishReason == null || "tool_calls".equals(finishReason)){ // 构造请求 String requestString = JSON.toJSONString(chatCompletion); @@ -74,8 +79,6 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat Choice choice = chatCompletionResponse.getChoices().get(0); finishReason = choice.getFinishReason(); - System.out.println("finishReason: " + finishReason); - System.out.println(JSON.toJSONString(chatCompletionResponse)); Usage usage = chatCompletionResponse.getUsage(); allUsage.setCompletionTokens(allUsage.getCompletionTokens() + usage.getCompletionTokens()); @@ -136,10 +139,11 @@ public void chatCompletionStream(String baseUrl, String apiKey, ChatCompletion c chatCompletion.setTools(tools); } - String finishReason = null; + String finishReason = "first"; - while(finishReason == null || "tool_calls".equals(finishReason)){ + while("first".equals(finishReason) || "tool_calls".equals(finishReason)){ + finishReason = null; String jsonString = JSON.toJSONString(chatCompletion); Request request = new Request.Builder() @@ -148,6 +152,7 @@ public void chatCompletionStream(String baseUrl, String apiKey, ChatCompletion c .post(RequestBody.create(jsonString, MediaType.parse(Constants.APPLICATION_JSON))) .build(); + factory.newEventSource(request, eventSourceListener); eventSourceListener.getCountDownLatch().await(); diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletion.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletion.java index 7cbe23d..20944a1 100644 --- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletion.java +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletion.java @@ -40,28 +40,49 @@ public class ChatCompletion { private List messages; /** - * 流式输出 + * 如果设置为 True,将会以 SSE(server-sent events)的形式以流式发送消息增量。消息流以 data: [DONE] 结尾 */ @Builder.Default private Boolean stream = false; + /** + * 流式输出相关选项。只有在 stream 参数为 true 时,才可设置此参数。 + */ + @Builder.Default + @JsonProperty("stream_options") + private StreamOptions streamOptions = new StreamOptions(); + + /** + * 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,降低模型重复相同内容的可能性。 + */ + @Builder.Default @JsonProperty("frequency_penalty") private Float frequencyPenalty = 0f; /** - * [0.0, 2.0] + * 采样温度,介于 0 和 2 之间。更高的值,如 0.8,会使输出更随机,而更低的值,如 0.2,会使其更加集中和确定。 + * 我们通常建议可以更改这个值或者更改 top_p,但不建议同时对两者进行修改。 */ + @Builder.Default private Float temperature = 1f; /** - * [0.0, 1.0] + * 作为调节采样温度的替代方案,模型会考虑前 top_p 概率的 token 的结果。所以 0.1 就意味着只有包括在最高 10% 概率中的 token 会被考虑。 + * 我们通常建议修改这个值或者更改 temperature,但不建议同时对两者进行修改。 */ + @Builder.Default @JsonProperty("top_p") private Float topP = 1f; + /** + * 限制一次请求中模型生成 completion 的最大 token 数。输入 token 和输出 token 的总长度受模型的上下文长度的限制。 + */ @JsonProperty("max_tokens") private Integer maxTokens; + /** + * 模型可能会调用的 tool 的列表。目前,仅支持 function 作为工具。使用此参数来提供以 JSON 作为输入参数的 function 列表。 + */ private List tools; /** @@ -70,26 +91,60 @@ public class ChatCompletion { @JsonIgnore private List functions; + /** + * 控制模型调用 tool 的行为。 + * none 意味着模型不会调用任何 tool,而是生成一条消息。 + * auto 意味着模型可以选择生成一条消息或调用一个或多个 tool。 + * 当没有 tool 时,默认值为 none。如果有 tool 存在,默认值为 auto。 + */ @JsonProperty("tool_choice") private String toolChoice; + @Builder.Default @JsonProperty("parallel_tool_calls") private Boolean parallelToolCalls = true; + /** + * 一个 object,指定模型必须输出的格式。 + * + * 设置为 { "type": "json_object" } 以启用 JSON 模式,该模式保证模型生成的消息是有效的 JSON。 + * + * 注意: 使用 JSON 模式时,你还必须通过系统或用户消息指示模型生成 JSON。 + * 否则,模型可能会生成不断的空白字符,直到生成达到令牌限制,从而导致请求长时间运行并显得“卡住”。 + * 此外,如果 finish_reason="length",这表示生成超过了 max_tokens 或对话超过了最大上下文长度,消息内容可能会被部分截断。 + */ @JsonProperty("response_format") private Object responseFormat; private String user; + @Builder.Default private Integer n = 1; + /** + * 在遇到这些词时,API 将停止生成更多的 token。 + */ private List stop; + /** + * 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其是否已在已有文本中出现受到相应的惩罚,从而增加模型谈论新主题的可能性。 + */ + @Builder.Default + @JsonProperty("presence_penalty") + private Float presencePenalty = 0f; + @JsonProperty("logit_bias") private Map logitBias; + /** + * 是否返回所输出 token 的对数概率。如果为 true,则在 message 的 content 中返回每个输出 token 的对数概率。 + */ + @Builder.Default private Boolean logprobs = false; + /** + * 一个介于 0 到 20 之间的整数 N,指定每个输出位置返回输出概率 top N 的 token,且返回这些 token 的对数概率。指定此参数时,logprobs 必须为 true。 + */ @JsonProperty("top_logprobs") private Integer topLogprobs; diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletionResponse.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletionResponse.java index 2b5d591..03d6158 100644 --- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletionResponse.java +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/ChatCompletionResponse.java @@ -23,14 +23,39 @@ @JsonIgnoreProperties(ignoreUnknown = true) @JsonInclude(JsonInclude.Include.NON_NULL) public class ChatCompletionResponse { + /** + * 该对话的唯一标识符。 + */ private String id; + + /** + * 对象的类型, 其值为 chat.completion 或 chat.completion.chunk + */ private String object; + + /** + * 创建聊天完成时的 Unix 时间戳(以秒为单位)。 + */ private Long created; + + /** + * 生成该 completion 的模型名。 + */ private String model; + + /** + * 该指纹代表模型运行时使用的后端配置。 + */ @JsonProperty("system_fingerprint") private String systemFingerprint; + /** + * 模型生成的 completion 的选择列表。 + */ private List choices; + /** + * 该对话补全请求的用量信息。 + */ private Usage usage; } diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/Choice.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/Choice.java index 83748fc..864c8ce 100644 --- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/Choice.java +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/Choice.java @@ -7,7 +7,7 @@ /** * @Author cly - * @Description TODO + * @Description 模型生成的 completion * @Date 2024/8/11 20:01 */ @Data @@ -21,6 +21,18 @@ public class Choice { private Object logprobs; + /** + * 模型停止生成 token 的原因。 + * + * [stop, length, content_filter, tool_calls, insufficient_system_resource] + * + * stop:模型自然停止生成,或遇到 stop 序列中列出的字符串。 + * length:输出长度达到了模型上下文长度限制,或达到了 max_tokens 的限制。 + * content_filter:输出内容因触发过滤策略而被过滤。 + * tool_calls:函数调用。 + * insufficient_system_resource:系统推理资源不足,生成被打断。 + * + */ @JsonProperty("finish_reason") private String finishReason; } diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/StreamOptions.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/StreamOptions.java new file mode 100644 index 0000000..20fbbbb --- /dev/null +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/openai/chat/entity/StreamOptions.java @@ -0,0 +1,24 @@ +package io.github.lnyocly.ai4j.platform.openai.chat.entity; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * @Author cly + * @Description 流式输出相关选项 + * @Date 2024/8/29 13:00 + */ +@Data +@NoArgsConstructor() +@AllArgsConstructor() +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(JsonInclude.Include.NON_NULL) +public class StreamOptions { + @JsonProperty("include_usage") + private Boolean includeUsage = true; +} diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/ZhipuChatService.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/ZhipuChatService.java index c99acee..6ca9931 100644 --- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/ZhipuChatService.java +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/ZhipuChatService.java @@ -32,11 +32,11 @@ /** * @Author cly - * @Description TODO + * @Description 智谱chat服务 * @Date 2024/8/27 17:29 */ @Slf4j -public class ZhipuChatService implements IChatService, ParameterConvert, ResultConvert { +public class ZhipuChatService implements IChatService, ParameterConvert, ResultConvert { private final ZhipuConfig zhipuConfig; private final OkHttpClient okHttpClient; @@ -54,8 +54,8 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat if(baseUrl == null || "".equals(baseUrl)) baseUrl = zhipuConfig.getApiHost(); if(apiKey == null || "".equals(apiKey)) apiKey = zhipuConfig.getApiKey(); chatCompletion.setStream(false); + chatCompletion.setStreamOptions(null); - String finishReason = null; // 根据key获取token String token = BearerTokenUtils.getToken(apiKey); @@ -72,8 +72,11 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat // 总token消耗 Usage allUsage = new Usage(); + String finishReason = "first"; - while(finishReason == null || "tool_calls".equals(finishReason)){ + while("first".equals(finishReason) || "tool_calls".equals(finishReason)){ + + finishReason = null; // 构造请求 String requestString = JSON.toJSONString(zhipuChatCompletion); @@ -89,8 +92,6 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat Choice choice = zhipuChatCompletionResponse.getChoices().get(0); finishReason = choice.getFinishReason(); - System.out.println("finishReason: " + finishReason); - System.out.println(JSON.toJSONString(zhipuChatCompletionResponse)); Usage usage = zhipuChatCompletionResponse.getUsage(); allUsage.setCompletionTokens(allUsage.getCompletionTokens() + usage.getCompletionTokens()); @@ -118,7 +119,7 @@ public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, Chat }else{ // 其他情况直接返回 zhipuChatCompletionResponse.setUsage(allUsage); - + zhipuChatCompletionResponse.setObject("chat.completion"); // 恢复原始请求数据 chatCompletion.setMessages(zhipuChatCompletion.getMessages()); chatCompletion.setTools(zhipuChatCompletion.getTools()); @@ -146,7 +147,6 @@ public void chatCompletionStream(String baseUrl, String apiKey, ChatCompletion c if(apiKey == null || "".equals(apiKey)) apiKey = zhipuConfig.getApiKey(); chatCompletion.setStream(true); - String finishReason = null; // 根据key获取token String token = BearerTokenUtils.getToken(apiKey); @@ -160,8 +160,11 @@ public void chatCompletionStream(String baseUrl, String apiKey, ChatCompletion c zhipuChatCompletion.setTools(tools); } - while(finishReason == null || "tool_calls".equals(finishReason)){ + String finishReason = "first"; + + while("first".equals(finishReason) || "tool_calls".equals(finishReason)){ + finishReason = null; String jsonString = JSON.toJSONString(zhipuChatCompletion); Request request = new Request.Builder() @@ -216,7 +219,7 @@ public ZhipuChatCompletion convertChatCompletionObject(ChatCompletion chatComple zhipuChatCompletion.setModel(chatCompletion.getModel()); zhipuChatCompletion.setMessages(chatCompletion.getMessages()); zhipuChatCompletion.setStream(chatCompletion.getStream()); - zhipuChatCompletion.setTemperature(chatCompletion.getTemperature()); + zhipuChatCompletion.setTemperature(chatCompletion.getTemperature() / 2); zhipuChatCompletion.setTopP(chatCompletion.getTopP()); zhipuChatCompletion.setMaxTokens(chatCompletion.getMaxTokens()); zhipuChatCompletion.setStop(chatCompletion.getStop()); @@ -247,11 +250,9 @@ public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Null } ZhipuChatCompletionResponse chatCompletionResponse = JSON.parseObject(data, ZhipuChatCompletionResponse.class); + chatCompletionResponse.setObject("chat.completion.chunk"); ChatCompletionResponse response = convertChatCompletionResponse(chatCompletionResponse); - // 把智谱的格式,转为OpenAi的格式传输给onEvent - - eventSourceListener.onEvent(eventSource, id, type, JSON.toJSONString(response)); } diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletion.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletion.java index 8631049..95e0078 100644 --- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletion.java +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletion.java @@ -27,7 +27,9 @@ @JsonInclude(JsonInclude.Include.NON_NULL) public class ZhipuChatCompletion { + @NonNull private String model; + @NonNull private List messages; @JsonProperty("request_id") diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletionResponse.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletionResponse.java index 1cbc04c..3e46be5 100644 --- a/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletionResponse.java +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/platform/zhipu/chat/entity/ZhipuChatCompletionResponse.java @@ -24,9 +24,9 @@ @JsonInclude(JsonInclude.Include.NON_NULL) public class ZhipuChatCompletionResponse { private String id; + private String object; private Long created; private String model; private List choices; - private Usage usage; } diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/service/Configuration.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/service/Configuration.java index 66d5c95..960242a 100644 --- a/ai4j/src/main/java/io/github/lnyocly/ai4j/service/Configuration.java +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/service/Configuration.java @@ -1,5 +1,6 @@ package io.github.lnyocly.ai4j.service; +import io.github.lnyocly.ai4j.config.DeepSeekConfig; import io.github.lnyocly.ai4j.config.OpenAiConfig; import io.github.lnyocly.ai4j.config.PineconeConfig; import io.github.lnyocly.ai4j.config.ZhipuConfig; @@ -28,6 +29,7 @@ public EventSource.Factory createRequestFactory() { private OpenAiConfig openAiConfig; private ZhipuConfig zhipuConfig; + private DeepSeekConfig deepSeekConfig; private PineconeConfig pineconeConfig; diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/service/PlatformType.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/service/PlatformType.java index 3f9ced0..7d0d05f 100644 --- a/ai4j/src/main/java/io/github/lnyocly/ai4j/service/PlatformType.java +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/service/PlatformType.java @@ -13,6 +13,7 @@ public enum PlatformType { OPENAI("openai"), ZHIPU("zhipu"), + DEEPSEEK("deepseek"), ; private final String platform; diff --git a/ai4j/src/main/java/io/github/lnyocly/ai4j/service/factor/AiService.java b/ai4j/src/main/java/io/github/lnyocly/ai4j/service/factor/AiService.java index a13368b..50d8fc0 100644 --- a/ai4j/src/main/java/io/github/lnyocly/ai4j/service/factor/AiService.java +++ b/ai4j/src/main/java/io/github/lnyocly/ai4j/service/factor/AiService.java @@ -1,5 +1,6 @@ package io.github.lnyocly.ai4j.service.factor; +import io.github.lnyocly.ai4j.platform.deepseek.chat.DeepSeekChatService; import io.github.lnyocly.ai4j.platform.openai.chat.OpenAiChatService; import io.github.lnyocly.ai4j.platform.zhipu.chat.ZhipuChatService; import io.github.lnyocly.ai4j.service.Configuration; @@ -38,6 +39,8 @@ private IChatService createChatService(PlatformType platform) { return new OpenAiChatService(configuration); case ZHIPU: return new ZhipuChatService(configuration); + case DEEPSEEK: + return new DeepSeekChatService(configuration); default: throw new IllegalArgumentException("Unknown platform: " + platform); } diff --git a/ai4j/src/test/java/io/github/lnyocly/OpenAiTest.java b/ai4j/src/test/java/io/github/lnyocly/OpenAiTest.java index df86b67..de415a1 100644 --- a/ai4j/src/test/java/io/github/lnyocly/OpenAiTest.java +++ b/ai4j/src/test/java/io/github/lnyocly/OpenAiTest.java @@ -1,9 +1,11 @@ package io.github.lnyocly; import com.alibaba.fastjson2.JSON; +import io.github.lnyocly.ai4j.config.DeepSeekConfig; import io.github.lnyocly.ai4j.config.OpenAiConfig; import io.github.lnyocly.ai4j.config.ZhipuConfig; +import io.github.lnyocly.ai4j.interceptor.ErrorInterceptor; import io.github.lnyocly.ai4j.listener.SseListener; import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletion; import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletionResponse; @@ -67,10 +69,12 @@ public class OpenAiTest { public void test_init(){ OpenAiConfig openAiConfig = new OpenAiConfig(); ZhipuConfig zhipuConfig = new ZhipuConfig(); + DeepSeekConfig deepSeekConfig = new DeepSeekConfig(); Configuration configuration = new Configuration(); configuration.setOpenAiConfig(openAiConfig); configuration.setZhipuConfig(zhipuConfig); + configuration.setDeepSeekConfig(deepSeekConfig); HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(); httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS); @@ -78,6 +82,7 @@ public void test_init(){ OkHttpClient okHttpClient = new OkHttpClient .Builder() .addInterceptor(httpLoggingInterceptor) + .addInterceptor(new ErrorInterceptor()) .connectTimeout(300, TimeUnit.SECONDS) .writeTimeout(300, TimeUnit.SECONDS) .readTimeout(300, TimeUnit.SECONDS) @@ -89,7 +94,7 @@ public void test_init(){ embeddingService = aiService.getEmbeddingService(PlatformType.OPENAI); //chatService = aiService.getChatService(PlatformType.getPlatform("OPENAI")); - chatService = aiService.getChatService(PlatformType.ZHIPU); + chatService = aiService.getChatService(PlatformType.DEEPSEEK); } @@ -138,7 +143,7 @@ public void test_embed() throws Exception { @Test public void test_chatCompletions_common() throws Exception { ChatCompletion chatCompletion = ChatCompletion.builder() - .model("glm-4-flash") + .model("deepseek-chat") .message(ChatMessage.withUser("鲁迅为什么打周树人")) .build(); @@ -174,32 +179,33 @@ public void test_chatCompletions_multimodal() throws Exception { @Test public void test_chatCompletions_stream() throws Exception { ChatCompletion chatCompletion = ChatCompletion.builder() - .model("gpt-4o-mini") + .model("deepseek-chat") .message(ChatMessage.withUser("鲁迅为什么打周树人")) .build(); System.out.println("请求参数"); System.out.println(chatCompletion); - CountDownLatch countDownLatch = new CountDownLatch(1); - chatService.chatCompletionStream(chatCompletion, new SseListener() { + // 构造监听器 + SseListener sseListener = new SseListener() { @Override protected void send() { - + System.out.println(this.getCurrStr()); } - }); + }; - countDownLatch.await(); + chatService.chatCompletionStream(chatCompletion, sseListener); System.out.println("请求成功"); + System.out.println(sseListener.getOutput()); } @Test public void test_chatCompletions_function() throws Exception { ChatCompletion chatCompletion = ChatCompletion.builder() - .model("gpt-4o-mini") + .model("deepseek-chat") .message(ChatMessage.withUser("查询洛阳明天的天气,并告诉我火车是否发车")) .functions("queryWeather", "queryTrainInfo") .build(); @@ -221,7 +227,7 @@ public void test_chatCompletions_stream_function() throws Exception { // 构造请求参数 ChatCompletion chatCompletion = ChatCompletion.builder() - .model("glm-4-flash") + .model("deepseekaa-chat") .message(ChatMessage.withUser("查询洛阳明天的天气")) .functions("queryWeather", "queryTrainInfo") .build(); @@ -241,6 +247,8 @@ protected void send() { chatService.chatCompletionStream(chatCompletion, sseListener); System.out.println("完整内容: "); System.out.println(sseListener.getOutput()); + System.out.println("内容花费: "); + System.out.println(sseListener.getUsage()); } @Test diff --git a/pom.xml b/pom.xml index 85da238..0a156db 100644 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ io.github.lnyo-cly ai4j-sdk - 0.2.0 + 0.3.0 pom