
Commit

feat: add support for the DeepSeek platform, add stream_options to collect usage statistics directly, add the error interceptor `ErrorInterceptor.java`, release version 0.3.0
LnYo-Cly committed Aug 29, 2024
1 parent 9fbff06 commit 8121b36
Showing 24 changed files with 835 additions and 54 deletions.
22 changes: 17 additions & 5 deletions README.md
@@ -6,6 +6,7 @@
## Supported Platforms
+ OpenAi
+ Zhipu
+ DeepSeek
+ More to come

## Supported Services
@@ -19,11 +20,13 @@
+ Streaming output supported, including streaming of function-call arguments
+ Easy-to-use Tool Calls
+ Multiple functions can be called in one turn (not supported by Zhipu)
+ stream_options supported, so token usage can be obtained directly from streaming output
+ Built-in vector database support: Pinecone
+ File reading via Tika
+ Token counting with `TikTokensUtil.java`

## Changelog
+ [2024-08-29] Added support for the DeepSeek platform, added stream_options to collect usage statistics directly, added the error interceptor `ErrorInterceptor.java`, released version 0.3.0.
+ [2024-08-29] Adjusted SseListener for compatibility with Zhipu function calling.
+ [2024-08-28] Added token counting, added the Zhipu AI Chat service, improved function calling to support multiple rounds and multiple functions.
+ [2024-08-17] Enhanced the SseListener listener. Released version 0.2.0.
@@ -32,11 +35,11 @@
## Import
### Gradle
```groovy
implementation group: 'io.github.lnyo-cly', name: 'ai4j', version: '0.1.0'
implementation group: 'io.github.lnyo-cly', name: 'ai4j', version: '0.3.0'
```

```groovy
implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', version: '0.1.0'
implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', version: '0.3.0'
```


@@ -46,7 +49,7 @@ implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', ver
<dependency>
<groupId>io.github.lnyo-cly</groupId>
<artifactId>ai4j</artifactId>
<version>0.2.0</version>
<version>0.3.0</version>
</dependency>

```
@@ -55,7 +58,7 @@ implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', ver
<dependency>
<groupId>io.github.lnyo-cly</groupId>
<artifactId>ai4j-spring-boot-stater</artifactId>
<version>0.2.0</version>
<version>0.3.0</version>
</dependency>
```

Expand Down Expand Up @@ -89,7 +92,7 @@ implementation group: 'io.github.lnyo-cly', name: 'ai4j-spring-boot-stater', ver

}
```
#### Obtaining services via Spring
### Obtaining services via Spring
```yml
# A proxy is required by default when accessing from mainland China
ai:
@@ -98,11 +101,20 @@ ai:
okhttp:
proxy-port: 10809
proxy-url: "127.0.0.1"
zhipu:
api-key: "xxx"
#other...
```

```java
// Inject the AI service
@Autowired
private AiService aiService;

// Obtain the required service instances
IChatService chatService = aiService.getChatService(PlatformType.OPENAI);
IEmbeddingService embeddingService = aiService.getEmbeddingService(PlatformType.OPENAI);
// ......
```
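
With this commit the same pattern extends to DeepSeek. A minimal sketch, assuming `PlatformType` exposes a DeepSeek constant (the exact enum name is not shown in this diff):

```java
// Hedged sketch: PlatformType.DEEPSEEK is assumed; the constant name is not visible in this diff
IChatService deepSeekChatService = aiService.getChatService(PlatformType.DEEPSEEK);
```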

## Chat Service
4 changes: 2 additions & 2 deletions ai4j-spring-boot-stater/pom.xml
@@ -6,10 +6,10 @@
<groupId>io.github.lnyo-cly</groupId>
<artifactId>ai4j-spring-boot-stater</artifactId>
<packaging>jar</packaging>
<version>0.2.0</version>
<version>0.3.0</version>

<name>ai4j-spring-boot-stater</name>
<description>ai4j-spring-boot-stater</description>
<description>The Spring Boot starter provided for ai4j, making it easy to integrate into Spring projects. About ai4j: it integrates multiple LLM platforms such as OpenAi, Zhipu (ChatGLM), and DeepSeek, provides unified input and output aligned with OpenAi, improves function calling (Tool Call) and RAG usage, supports the Pinecone vector database, supports JDK 1.8, and lets users integrate AI quickly.</description>

<licenses>
<license>
@@ -1,8 +1,10 @@
package io.github.lnyocly.ai4j;

import io.github.lnyocly.ai4j.config.DeepSeekConfig;
import io.github.lnyocly.ai4j.config.OpenAiConfig;
import io.github.lnyocly.ai4j.config.PineconeConfig;
import io.github.lnyocly.ai4j.config.ZhipuConfig;
import io.github.lnyocly.ai4j.interceptor.ErrorInterceptor;
import io.github.lnyocly.ai4j.service.PlatformType;
import io.github.lnyocly.ai4j.service.factor.AiService;
import io.github.lnyocly.ai4j.vector.service.PineconeService;
@@ -27,21 +29,24 @@
OpenAiConfigProperties.class,
OkHttpConfigProperties.class,
PineconeConfigProperties.class,
ZhipuConfigProperties.class})
ZhipuConfigProperties.class,
DeepSeekConfigProperties.class})
public class AiConfigAutoConfiguration {

private final OkHttpConfigProperties okHttpConfigProperties;
private final OpenAiConfigProperties openAiConfigProperties;
private final PineconeConfigProperties pineconeConfigProperties;
private final ZhipuConfigProperties zhipuConfigProperties;
private final DeepSeekConfigProperties deepSeekConfigProperties;

private io.github.lnyocly.ai4j.service.Configuration configuration = new io.github.lnyocly.ai4j.service.Configuration();

public AiConfigAutoConfiguration(OkHttpConfigProperties okHttpConfigProperties, OpenAiConfigProperties openAiConfigProperties, PineconeConfigProperties pineconeConfigProperties, ZhipuConfigProperties zhipuConfigProperties) {
public AiConfigAutoConfiguration(OkHttpConfigProperties okHttpConfigProperties, OpenAiConfigProperties openAiConfigProperties, PineconeConfigProperties pineconeConfigProperties, ZhipuConfigProperties zhipuConfigProperties, DeepSeekConfigProperties deepSeekConfigProperties) {
this.okHttpConfigProperties = okHttpConfigProperties;
this.openAiConfigProperties = openAiConfigProperties;
this.pineconeConfigProperties = pineconeConfigProperties;
this.zhipuConfigProperties = zhipuConfigProperties;
this.deepSeekConfigProperties = deepSeekConfigProperties;
}

@Bean
@@ -60,6 +65,7 @@ private void init() {
initOpenAiConfig();
initPineconeConfig();
initZhipuConfig();
initDeepSeekConfig();
}

private void initOkHttp() {
@@ -75,6 +81,7 @@ private void initOkHttp() {
OkHttpClient okHttpClient = new OkHttpClient
.Builder()
.addInterceptor(httpLoggingInterceptor)
.addInterceptor(new ErrorInterceptor())
.connectTimeout(okHttpConfigProperties.getConnectTimeout(), okHttpConfigProperties.getTimeUnit())
.writeTimeout(okHttpConfigProperties.getWriteTimeout(), okHttpConfigProperties.getTimeUnit())
.readTimeout(okHttpConfigProperties.getReadTimeout(), okHttpConfigProperties.getTimeUnit())
@@ -115,6 +122,13 @@ private void initPineconeConfig() {
configuration.setPineconeConfig(pineconeConfig);
}

private void initDeepSeekConfig(){
DeepSeekConfig deepSeekConfig = new DeepSeekConfig();
deepSeekConfig.setApiHost(deepSeekConfigProperties.getApiHost());
deepSeekConfig.setApiKey(deepSeekConfigProperties.getApiKey());
deepSeekConfig.setChat_completion(deepSeekConfigProperties.getChat_completion());

configuration.setDeepSeekConfig(deepSeekConfig);
}

}
@@ -0,0 +1,40 @@
package io.github.lnyocly.ai4j;

import org.springframework.boot.context.properties.ConfigurationProperties;

/**
* @Author cly
* @Description TODO
* @Date 2024/8/29 15:01
*/
@ConfigurationProperties(prefix = "ai.deepseek")
public class DeepSeekConfigProperties {

private String apiHost = "https://api.deepseek.com/";
private String apiKey = "";
private String chat_completion = "chat/completions";

public String getApiHost() {
return apiHost;
}

public void setApiHost(String apiHost) {
this.apiHost = apiHost;
}

public String getChat_completion() {
return chat_completion;
}

public void setChat_completion(String chat_completion) {
this.chat_completion = chat_completion;
}

public String getApiKey() {
return apiKey;
}

public void setApiKey(String apiKey) {
this.apiKey = apiKey;
}
}
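
These properties bind to the `ai.deepseek` prefix, so a Spring project can configure DeepSeek alongside the other platforms. A minimal sketch, assuming the placeholder key below is replaced with a real one; `api-host` falls back to the default declared above when omitted:

```yml
ai:
  deepseek:
    api-key: "xxx"                           # placeholder
    api-host: "https://api.deepseek.com/"    # optional, default shown above
```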
18 changes: 16 additions & 2 deletions ai4j/pom.xml
@@ -7,10 +7,10 @@
<groupId>io.github.lnyo-cly</groupId>
<artifactId>ai4j</artifactId>
<packaging>jar</packaging>
<version>0.2.0</version>
<version>0.3.0</version>

<name>ai4j</name>
<description>ai4j core component project</description>
<description>Integrates multiple LLM platforms such as OpenAi, Zhipu (ChatGLM), and DeepSeek, provides unified input and output aligned with OpenAi, improves function calling (Tool Call) and RAG usage, supports the Pinecone vector database, supports JDK 1.8, and lets users integrate AI quickly.</description>


<properties>
@@ -78,6 +78,20 @@
<artifactId>logback-classic</artifactId>
<version>1.2.3</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.slf4j/slf4j-api -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.30</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.slf4j/slf4j-log4j12 -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
<version>1.7.30</version>
<!-- <scope>test</scope>-->
</dependency>

<dependency>
<groupId>org.reflections</groupId>
<artifactId>reflections</artifactId>
@@ -0,0 +1,21 @@
package io.github.lnyocly.ai4j.config;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
* @Author cly
* @Description DeepSeek configuration
* @Date 2024/8/29 10:31
*/

@Data
@NoArgsConstructor
@AllArgsConstructor
public class DeepSeekConfig {

private String apiHost = "https://api.deepseek.com/";
private String apiKey = "";
private String chat_completion = "chat/completions";
}
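
For use without the Spring starter, the same settings can be wired programmatically. A minimal sketch, assuming `io.github.lnyocly.ai4j.service.Configuration` exposes the same `setDeepSeekConfig` setter used by the auto-configuration above, and relying on the Lombok all-args constructor whose parameter order follows the field declaration order:

```java
// Hedged sketch: mirrors what the Spring auto-configuration does for DeepSeek
DeepSeekConfig deepSeekConfig = new DeepSeekConfig(
        "https://api.deepseek.com/",   // apiHost
        "xxx",                         // apiKey (placeholder)
        "chat/completions");           // chat_completion endpoint path

Configuration configuration = new Configuration();
configuration.setDeepSeekConfig(deepSeekConfig);
```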
@@ -0,0 +1,39 @@
package io.github.lnyocly.ai4j.interceptor;

import io.github.lnyocly.ai4j.exception.CommonException;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;

/**
* @Author cly
* @Description Error-handling interceptor
* @Date 2024/8/29 14:55
*/
@Slf4j
public class ErrorInterceptor implements Interceptor {
@NotNull
@Override
public Response intercept(@NotNull Chain chain) throws IOException {
Request original = chain.request();

Response response = chain.proceed(original);

if(!response.isSuccessful()){
//response.close();
String errorMsg = response.body().string();

log.error("AI服务请求异常:{}", errorMsg);
throw new CommonException(errorMsg);


}


return response;
}
}
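
Outside the Spring starter, the interceptor is attached to the OkHttp client exactly as the auto-configuration above does. A minimal sketch using the standard OkHttp builder:

```java
// Hedged sketch: any non-successful HTTP response is surfaced as a CommonException by the interceptor
OkHttpClient okHttpClient = new OkHttpClient.Builder()
        .addInterceptor(new ErrorInterceptor())
        .build();
```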
@@ -1,8 +1,10 @@
package io.github.lnyocly.ai4j.listener;

import com.alibaba.fastjson2.JSON;
import io.github.lnyocly.ai4j.exception.CommonException;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletionResponse;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatMessage;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.Choice;
import io.github.lnyocly.ai4j.platform.openai.chat.enums.ChatMessageType;
import io.github.lnyocly.ai4j.platform.openai.tool.ToolCall;
import io.github.lnyocly.ai4j.platform.openai.usage.Usage;
@@ -60,7 +62,7 @@ public abstract class SseListener extends EventSourceListener {
* Token usage
*/
@Getter
private Usage usage = null;
private final Usage usage = new Usage();

@Setter
@Getter
@@ -82,26 +84,38 @@ public abstract class SseListener extends EventSourceListener {

@Override
public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
log.error("流式输出异常 onFailure ");

countDownLatch.countDown();
}

@Override
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
if ("[DONE]".equalsIgnoreCase(data)) {
log.info("模型会话 [DONE]");
//log.info("模型会话 [DONE]");
return;
}

ChatCompletionResponse chatCompletionResponse = JSON.parseObject(data, ChatCompletionResponse.class);
ChatMessage responseMessage = chatCompletionResponse.getChoices().get(0).getDelta();
// Accumulate tokens: when include_usage = true, only the final message carries usage; in all other messages usage is null
Usage currUsage = chatCompletionResponse.getUsage();
if(currUsage != null){
usage.setPromptTokens(usage.getPromptTokens() + currUsage.getPromptTokens());
usage.setCompletionTokens(usage.getCompletionTokens() + currUsage.getCompletionTokens());
usage.setTotalTokens(usage.getTotalTokens() + currUsage.getTotalTokens());
}


List<Choice> choices = chatCompletionResponse.getChoices();

if(choices == null || choices.isEmpty()){
return;
}
ChatMessage responseMessage = choices.get(0).getDelta();

finishReason = chatCompletionResponse.getChoices().get(0).getFinishReason();
finishReason = choices.get(0).getFinishReason();

// The tool_calls response has finished
if("tool_calls".equals(chatCompletionResponse.getChoices().get(0).getFinishReason())){
if("tool_calls".equals(choices.get(0).getFinishReason())){
if(toolCall == null && responseMessage.getToolCalls()!=null) {
toolCalls = responseMessage.getToolCalls();
if(showToolArgs){
@@ -171,12 +185,11 @@ public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Null



log.info("测试结果:{}", chatCompletionResponse);
//log.info("测试结果:{}", chatCompletionResponse);
}

@Override
public void onClosed(@NotNull EventSource eventSource) {
log.info("调用 onClosed ");
countDownLatch.countDown();
countDownLatch = new CountDownLatch(1);

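
With `stream_options.include_usage` enabled, only the final streamed chunk carries `usage` and its `choices` array is empty, which is why `onEvent` now skips empty-choice chunks and accumulates the counts into a single `Usage`. A minimal sketch of reading the totals after the stream completes (how the streaming request itself is issued is outside this diff and omitted):

```java
// Hedged sketch: sseListener is the SseListener instance handed to a streaming chat call
Usage usage = sseListener.getUsage();
System.out.printf("prompt=%d, completion=%d, total=%d%n",
        usage.getPromptTokens(), usage.getCompletionTokens(), usage.getTotalTokens());
```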
