This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

function call + stream: arguments lost from the returned ChatFunctionCall #505

Open
I-am-DJ opened this issue May 20, 2024 · 3 comments

Comments

@I-am-DJ

I-am-DJ commented May 20, 2024

Here is my sample code:

    public static void main(String[] args) throws UnknownHostException, InterruptedException {
        try {
            ObjectMapper mapper = defaultObjectMapper();
            OkHttpClient client = defaultClient("", Duration.of(10000L,ChronoUnit.SECONDS))
                    .newBuilder()
                    .build();
            Retrofit retrofit = defaultRetrofit(client, mapper);
            Class<Retrofit> clazz = Retrofit.class;
            Field baseUrl = clazz.getDeclaredField("baseUrl");
            baseUrl.setAccessible(true);
            baseUrl.set(retrofit, HttpUrl.get(BASE_URL));
            OpenAiApi api = retrofit.create(OpenAiApi.class);
            OpenAiService service = new OpenAiService(api);
            List<ChatMessage> messages = Lists.newArrayList();
            messages.add(new ChatMessage("system", "Please use the functions provided below to determine what function needs to be called for the user's problem. " +
                    "If the necessary parameters are missing when calling the function, please return to the user in this format and prompt the user to pass the necessary parameters:\n" +
                    "We also need the following information to complete your request: Required Parameter 1, Required Parameter 2\n" +
                    "Make sure your prompts are accurate, polite, and the directly relevant information is obvious and understandable to users"));
            Scanner scanner = new Scanner(System.in);
            //"Tell me the weather"
            messages.add(new ChatMessage("user", scanner.nextLine()));
            while (true) {
                ChatFunctionDynamic chatFunctionDynamic = getChatFunctionDynamic();
                ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
                        .builder()
                        .model("qwen15-110b.credit-llm")
                        .messages(messages)
                        .n(1)
                        .maxTokens(256)
                        .functions(Lists.newArrayList(chatFunctionDynamic))
                        .functionCall(ChatCompletionRequest.ChatCompletionRequestFunctionCall.of("auto"))
                        .build();

                Flowable<ChatCompletionChunk> flowable = service.streamChatCompletion(chatCompletionRequest);
                AtomicBoolean isFirst = new AtomicBoolean(true);
                ChatMessage responseMessage = service.mapStreamToAccumulator(flowable).doOnNext(accumulator -> {
                            if (accumulator.isFunctionCall()) {
                                ChatFunctionCall functionCall = accumulator.getAccumulatedChatFunctionCall();
                                if (isFirst.getAndSet(false)) {
                                    System.out.println("Executing function " + functionCall.getName() + "...");
                                }
                            } else {
                                if (isFirst.getAndSet(false)) {

                                    System.out.print("Response: ");
                                }
                                if (accumulator.getMessageChunk().getContent() != null) {
                                    System.out.print(accumulator.getMessageChunk().getContent());
                                }
                            }
                        })
                        .doOnComplete(System.out::println)
                        .lastElement()
                        .blockingGet()
                        .getAccumulatedMessage();
                messages.add(responseMessage);
                ChatFunctionCall functionCall = responseMessage.getFunctionCall();
                if (functionCall != null) {
                    if (functionCall.getName().equals("get_weather")) {
                        // NullPointerException is thrown on this line
                        String location = functionCall.getArguments().get("location").asText();
                        String unit = functionCall.getArguments().get("unit").asText();
                        WeatherResponse weather = getWeather(location, unit);
                        ChatMessage weatherMessage = new ChatMessage(ChatMessageRole.FUNCTION.value(), JSON.toJSONString(weather), "get_weather");
                        messages.add(weatherMessage);
                        continue;
                    }
                }
                System.out.print("Next Query: ");

                String nextLine = scanner.nextLine();
                if (nextLine.equalsIgnoreCase("exit")) {
                    System.exit(0);
                }

                messages.add(new ChatMessage(ChatMessageRole.USER.value(), nextLine));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

    }


    private static WeatherResponse getWeather(String location, String unit) {
        return new WeatherResponse(location, WeatherUnit.valueOf(unit), new Random().nextInt(40), "sunny");
    }

    public static ChatFunctionDynamic getChatFunctionDynamic() {
        return ChatFunctionDynamic.builder()
                .name("get_weather")
                .description("Get the current weather of a location")
                .addProperty(ChatFunctionProperty.builder()
                        .name("location")
                        .type("string")
                        .description("City and state, for example: León, Guanajuato")
                        .build())
                .addProperty(ChatFunctionProperty.builder()
                        .name("unit")
                        .type("string")
                        .description("The temperature unit, can be 'CELSIUS' or 'FAHRENHEIT'")
                        .enumValues(new HashSet<>(Arrays.asList("CELSIUS", "FAHRENHEIT")))
                        .required(true)
                        .build())
                .build();
    }

The corresponding error is raised on the line String location = functionCall.getArguments().get("location").asText(); with this stack trace:

java.lang.NullPointerException
	at com.mybank.bkinfocenter.common.recognition.web.Test.main(Test.java:96)

Stepping through with a debugger shows that in com.theokanning.openai.service.OpenAiService#mapStreamToAccumulator, the arguments in messageChunk are an ObjectNode, so asText() returns an empty string ("").

Is there a simple way to solve this without modifying the library source code? Many thanks!
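One way to avoid patching the library might be to skip mapStreamToAccumulator and accumulate the function-call fragments yourself from the chunks returned by service.streamChatCompletion(...). The sketch below is stdlib-only and rests on assumptions: the Fragment record is a hypothetical stand-in for the name and arguments parts of chunk.getChoices().get(0).getMessage().getFunctionCall(), with the arguments already extracted as text. With Jackson, a fragment that arrives as an ObjectNode would need toString() rather than asText(), because asText() on a container node returns "". Parsing the final string (for example with ObjectMapper.readTree) is left out.

```java
import java.util.List;

// Hypothetical client-side accumulator: concatenates streamed fragments as raw text.
public class FunctionCallAccumulator {

    // Hypothetical stand-in for the function_call part of one streamed chunk.
    // Any field may be null, as in a real stream.
    public record Fragment(String namePart, String argumentsPart, String finishReason) {}

    private final StringBuilder name = new StringBuilder();
    private final StringBuilder arguments = new StringBuilder();

    // Append whatever this chunk carries; fields that are null are skipped.
    public void accept(Fragment f) {
        if (f.namePart() != null) {
            name.append(f.namePart());
        }
        if (f.argumentsPart() != null) {
            arguments.append(f.argumentsPart());
        }
        // finishReason is not used here: a second finishing chunk carries no
        // fragments, so it cannot wipe out what was already collected.
    }

    public String name() {
        return name.toString();
    }

    // Raw JSON text of the arguments; parse it once, after the stream completes.
    public String arguments() {
        return arguments.toString();
    }

    public static void main(String[] args) {
        List<Fragment> stream = List.of(
                new Fragment("get_", null, null),
                new Fragment("weather", "{\"location\": \"Bei", null),
                new Fragment(null, "jing\", \"unit\": \"CELSIUS\"}", "function_call"),
                new Fragment(null, null, "stop")); // second chunk with a finishReason
        FunctionCallAccumulator acc = new FunctionCallAccumulator();
        stream.forEach(acc::accept);
        System.out.println(acc.name());      // prints: get_weather
        System.out.println(acc.arguments()); // prints: {"location": "Beijing", "unit": "CELSIUS"}
    }
}
```

Keeping the arguments as a plain string until the very end sidesteps the TextNode/ObjectNode switching entirely, at the cost of doing the JSON parse in your own code.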

@I-am-DJ
Author

I-am-DJ commented May 20, 2024

Similar issue: #463

@I-am-DJ
Author

I-am-DJ commented May 21, 2024

I tried modifying com.theokanning.openai.service.OpenAiService#mapStreamToAccumulator and com.theokanning.openai.service.ChatFunctionCallArgumentsSerializerAndDeserializer.Deserializer#deserialize as follows:

public static class Deserializer extends JsonDeserializer<JsonNode> {

        private Deserializer() {
        }

        @Override
        public JsonNode deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
            String json = p.getValueAsString();

            if (json == null || p.currentToken() == JsonToken.VALUE_NULL) {
                return null;
            }
            // ADDED: re-encode the raw text as a JSON string literal so readTree() returns a TextNode
            json = MAPPER.writeValueAsString(json);
            // END ADDED
            try {
                JsonNode node = null;
                try {
                    node = MAPPER.readTree(json);
                } catch (JsonParseException ignored) {
                }
                if (node == null || node.getNodeType() == JsonNodeType.MISSING) {
                    node = MAPPER.readTree(p);
                }
                return node;
            } catch (Exception ex) {
                ex.printStackTrace();
                return null;
            }
        }
    }
    public Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ChatCompletionChunk> flowable) {
        ChatFunctionCall functionCall = new ChatFunctionCall(null, null);
        ChatMessage accumulatedMessage = new ChatMessage(ChatMessageRole.ASSISTANT.value(), null);

        return flowable.map(chunk -> {
            ChatMessage messageChunk = chunk.getChoices().get(0).getMessage();
            ChatFunctionCall chunkFunctionCall = new ChatFunctionCall(null, null);
            if (messageChunk.getFunctionCall() != null) {
                if (messageChunk.getFunctionCall().getName() != null) {
                    String namePart = messageChunk.getFunctionCall().getName();
                    chunkFunctionCall.setName((functionCall.getName() == null ? "" : functionCall.getName()) + namePart);
                }
                if (messageChunk.getFunctionCall().getArguments() != null) {
                    String argumentsPart = messageChunk.getFunctionCall().getArguments().asText();
                    chunkFunctionCall.setArguments(new TextNode((functionCall.getArguments() == null ? "" : functionCall.getArguments().asText()) + argumentsPart));
                }
                accumulatedMessage.setFunctionCall(functionCall);
            } else {
                accumulatedMessage.setContent((accumulatedMessage.getContent() == null ? "" : accumulatedMessage.getContent()) + (messageChunk.getContent() == null ? "" : messageChunk.getContent()));
            }

            if (chunk.getChoices().get(0).getFinishReason() != null) { // last
                if (chunkFunctionCall.getArguments() != null) {
                    functionCall.setName(chunkFunctionCall.getName());
                    functionCall.setArguments(mapper.readTree(chunkFunctionCall.getArguments().asText()));
                    accumulatedMessage.setFunctionCall(functionCall);
                }
            }

            return new ChatMessageAccumulator(messageChunk, accumulatedMessage);
        });
    }
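In the Deserializer change, the ADDED line works because ObjectMapper.writeValueAsString applied to a String produces a quoted, escaped JSON string literal; readTree on that literal then yields a TextNode holding the original text instead of an ObjectNode, so the later asText() calls see the real content. The stdlib sketch below approximates only that quoting step. quoteJson is a hypothetical helper that escapes quotes, backslashes and control characters; it does not reproduce every Jackson output option.

```java
public class JsonQuote {

    // Hypothetical helper approximating what writeValueAsString does to a String:
    // wrap it in double quotes and escape '"', '\' and control characters.
    public static String quoteJson(String raw) {
        StringBuilder sb = new StringBuilder(raw.length() + 2).append('"');
        for (int i = 0; i < raw.length(); i++) {
            char c = raw.charAt(i);
            switch (c) {
                case '"' -> sb.append("\\\"");
                case '\\' -> sb.append("\\\\");
                case '\n' -> sb.append("\\n");
                case '\r' -> sb.append("\\r");
                case '\t' -> sb.append("\\t");
                default -> {
                    if (c < 0x20) {
                        sb.append(String.format("\\u%04x", (int) c)); // other control chars
                    } else {
                        sb.append(c);
                    }
                }
            }
        }
        return sb.append('"').toString();
    }

    public static void main(String[] args) {
        String fragment = "{\"location\": \"Beijing\"}";
        // The whole object becomes one JSON string, which readTree would map to a TextNode.
        System.out.println(quoteJson(fragment));
        // prints: "{\"location\": \"Beijing\"}"
    }
}
```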

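The main reason for modifying mapStreamToAccumulator is that the stream returns two chunks with a finishReason: the first contains the function call, the second does not. As a result, the original code enters the block under the // last comment twice. On the first pass, readTree() replaces the arguments with an ObjectNode; on the second pass, calling asText() on that ObjectNode yields an empty string, so the arguments are lost. I fixed this by building the result in a local chunkFunctionCall variable instead.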
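Reading the patch as written, chunkFunctionCall is recreated for every chunk and functionCall is only filled in inside the // last block, so each chunkFunctionCall holds just the current chunk's fragment. The patch therefore appears to rely on the finishing chunk carrying the complete name and arguments, which may hold for this backend but not for providers that split arguments across chunks. An alternative that keeps accumulating across chunks and also tolerates a repeated finishReason is to keep the raw text in a StringBuilder and convert it exactly once. The sketch below is stdlib-only; FinalizeOnce is a hypothetical helper, and its Function<String, T> parser stands in for a call such as mapper.readTree(...).

```java
import java.util.function.Function;

// Hypothetical helper: collects raw argument text and parses it exactly once.
public class FinalizeOnce<T> {

    private final StringBuilder raw = new StringBuilder();
    private final Function<String, T> parser; // stand-in for a call such as mapper.readTree(...)
    private T parsed;                          // stays null until the first finishing chunk
    private boolean finished;

    public FinalizeOnce(Function<String, T> parser) {
        this.parser = parser;
    }

    // Called once per streamed chunk; fragment and finishReason may each be null.
    public void onChunk(String fragment, String finishReason) {
        if (finished) {
            return; // a repeated finishing chunk (or anything after it) is ignored
        }
        if (fragment != null) {
            raw.append(fragment); // keep every fragment, not just the current chunk's
        }
        if (finishReason != null) {
            finished = true;
            if (raw.length() > 0) {
                parsed = parser.apply(raw.toString()); // parse the full text once
            }
        }
    }

    public T result() {
        return parsed;
    }

    public static void main(String[] args) {
        int[] calls = {0};
        FinalizeOnce<String> acc = new FinalizeOnce<>(s -> {
            calls[0]++;
            return "parsed:" + s;
        });
        acc.onChunk("{\"location\":", null);
        acc.onChunk("\"Beijing\"}", "function_call");
        acc.onChunk(null, "stop"); // second chunk with a finishReason
        System.out.println(acc.result() + " " + calls[0]);
        // prints: parsed:{"location":"Beijing"} 1
    }
}
```

The finished flag is what makes the second finishing chunk harmless: the parser runs once on the complete text, and nothing ever calls asText() on an already-parsed node.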

@Lambdua
Copy link

Lambdua commented May 21, 2024

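The root cause here is actually serialization. This library has a problem serializing the <","> field, which is why you run into all these TextNode/ObjectNode conversion issues. My fork of the library fixes this; feel free to use openai4j.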
