From 9b66c67445b4afa3c877b930cae704724e52155f Mon Sep 17 00:00:00 2001 From: Jake Luciani Date: Sun, 24 Nov 2024 15:56:21 -0500 Subject: [PATCH] Fix json parsing of >1 tool calls --- .../tjake/jlama/model/AbstractModel.java | 3 +- .../jlama/safetensors/prompt/ToolCall.java | 5 +-- .../github/tjake/jlama/util/JsonSupport.java | 2 +- .../tjake/jlama/model/TestCorrectness.java | 35 +++++++++++++++++++ .../github/tjake/jlama/model/TestModels.java | 2 +- pom.xml | 2 +- 6 files changed, 41 insertions(+), 8 deletions(-) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java index 45754bb6..aecfbd19 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java @@ -18,6 +18,7 @@ import static com.github.tjake.jlama.util.DebugSupport.debug; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; import com.github.tjake.jlama.math.ActivationFunction; import com.github.tjake.jlama.math.VectorMath; import com.github.tjake.jlama.model.functions.*; @@ -656,7 +657,7 @@ protected Generator.Response postProcessResponse(PromptContext promptContext, Ge List toolCalls = new ArrayList<>(jsonCalls.size()); for (String jsonCall : jsonCalls) { if (jsonCall.startsWith("[")) { - List toolCallList = JsonSupport.om.readValue(jsonCall, ToolCall.toolCallListTypeReference); + List toolCallList = JsonSupport.om.readValue(jsonCall, new TypeReference<>() {}); toolCalls.addAll(toolCallList); } else { ToolCall toolCall = JsonSupport.om.readValue(jsonCall, ToolCall.class); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/ToolCall.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/ToolCall.java index d50a8d70..f961136e 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/ToolCall.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/ToolCall.java @@ -15,17 +15,14 @@ */ package com.github.tjake.jlama.safetensors.prompt; -import static com.github.tjake.jlama.util.JsonSupport.om; import com.fasterxml.jackson.annotation.*; -import com.fasterxml.jackson.databind.type.ArrayType; + import java.util.Map; import java.util.Objects; @JsonPropertyOrder({ ToolResult.JSON_PROPERTY_TOOL_NAME, ToolResult.JSON_PROPERTY_TOOL_ID }) public class ToolCall { - public static final ArrayType toolCallListTypeReference = om.getTypeFactory().constructArrayType(ToolCall.class); - @JsonProperty("name") private final String name; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java index 5995f503..9a4585dd 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/util/JsonSupport.java @@ -65,7 +65,7 @@ public static List extractJsonFromString(String s) { while (endIndex + extra < text.length() && (text.charAt(endIndex + extra) == '}' || text.charAt(endIndex + extra) == ']')) { extra++; } - String jsonString = s.substring(i, endIndex + extra); + String jsonString = text.substring(i, endIndex + extra); jsons.add(jsonString); found = true; text = text.substring(endIndex + extra); diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java index 87da9caf..f63454e4 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java @@ -17,6 +17,7 @@ import static com.github.tjake.jlama.util.JsonSupport.om; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.github.tjake.jlama.math.FloatConversions; import com.github.tjake.jlama.math.VectorMath; @@ -26,13 +27,17 @@ import com.github.tjake.jlama.safetensors.prompt.*; import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; import com.github.tjake.jlama.safetensors.tokenizer.WordPieceTokenizer; +import com.github.tjake.jlama.util.JsonSupport; import com.google.common.io.Resources; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; import java.util.ArrayList; +import java.util.Comparator; import java.util.List; import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + import org.junit.Assert; import org.junit.Assume; import org.junit.Test; @@ -385,4 +390,34 @@ public void testMistralTools() { Assert.assertEquals(expected, prompt); } + + @Test + public void testToolParse() throws JsonProcessingException { + String input = "To get the payment status and date of transaction T1005, I will use the retrievePaymentStatus and retrievePaymentDate functions.\n" + + "\n" + + "[{\"name\": \"retrievePaymentStatus\", \"arguments\": {\"arg0\": \"T1005\"}}]\n" + + "[{\"name\": \"retrievePaymentDate\", \"arguments\": {\"arg0\": \"T1005\"}}]\n" + + "\n" + + "Please wait while I retrieve the data..."; + + List jsonCalls = JsonSupport.extractJsonFromString(input); + + Assert.assertEquals(2, jsonCalls.size()); + + List toolCalls = new ArrayList<>(jsonCalls.size()); + for (String jsonCall : jsonCalls) { + if (jsonCall.startsWith("[")) { + List toolCallList = JsonSupport.om.readValue(jsonCall, new TypeReference<>() {}); + toolCalls.addAll(toolCallList); + } else { + ToolCall toolCall = JsonSupport.om.readValue(jsonCall, ToolCall.class); + toolCalls.add(toolCall); + } + } + + // Remove duplicates + toolCalls = toolCalls.stream().sorted(Comparator.comparing(ToolCall::getName)).distinct().collect(Collectors.toList()); + + Assert.assertEquals(2, toolCalls.size()); + } } diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java index ab5dedad..ae98fe61 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java @@ -171,7 +171,7 @@ public void DeepCoderRun() throws Exception { @Test public void MistralRun() throws Exception { - String modelPrefix = "../models/Mistral-7B-Instruct-v0.3-jlama-Q4"; + String modelPrefix = "../models/tjake_Mistral-7B-Instruct-v0.3-JQ4"; Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); try (WeightLoader weights = SafeTensorSupport.loadWeights(Path.of(modelPrefix).toFile())) { BPETokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); diff --git a/pom.xml b/pom.xml index 51732e72..6edcd79a 100644 --- a/pom.xml +++ b/pom.xml @@ -42,7 +42,7 @@ UTF-8 - 0.8.2 + 0.8.3 2.0.7 1.5.6