Skip to content

Commit

Permalink
Merge pull request #124 from tjake/fix-many-tool-calls
Browse files Browse the repository at this point in the history
Fix json parsing of >1 tool calls
  • Loading branch information
tjake authored Nov 24, 2024
2 parents 88a930b + 9b66c67 commit 190464c
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static com.github.tjake.jlama.util.DebugSupport.debug;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.model.functions.*;
Expand Down Expand Up @@ -656,7 +657,7 @@ protected Generator.Response postProcessResponse(PromptContext promptContext, Ge
List<ToolCall> toolCalls = new ArrayList<>(jsonCalls.size());
for (String jsonCall : jsonCalls) {
if (jsonCall.startsWith("[")) {
List<ToolCall> toolCallList = JsonSupport.om.readValue(jsonCall, ToolCall.toolCallListTypeReference);
List<ToolCall> toolCallList = JsonSupport.om.readValue(jsonCall, new TypeReference<>() {});
toolCalls.addAll(toolCallList);
} else {
ToolCall toolCall = JsonSupport.om.readValue(jsonCall, ToolCall.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,14 @@
*/
package com.github.tjake.jlama.safetensors.prompt;

import static com.github.tjake.jlama.util.JsonSupport.om;

import com.fasterxml.jackson.annotation.*;
import com.fasterxml.jackson.databind.type.ArrayType;

import java.util.Map;
import java.util.Objects;

@JsonPropertyOrder({ ToolResult.JSON_PROPERTY_TOOL_NAME, ToolResult.JSON_PROPERTY_TOOL_ID })
public class ToolCall {
public static final ArrayType toolCallListTypeReference = om.getTypeFactory().constructArrayType(ToolCall.class);

@JsonProperty("name")
private final String name;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public static List<String> extractJsonFromString(String s) {
while (endIndex + extra < text.length() && (text.charAt(endIndex + extra) == '}' || text.charAt(endIndex + extra) == ']')) {
extra++;
}
String jsonString = s.substring(i, endIndex + extra);
String jsonString = text.substring(i, endIndex + extra);
jsons.add(jsonString);
found = true;
text = text.substring(endIndex + extra);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static com.github.tjake.jlama.util.JsonSupport.om;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.github.tjake.jlama.math.FloatConversions;
import com.github.tjake.jlama.math.VectorMath;
Expand All @@ -26,13 +27,17 @@
import com.github.tjake.jlama.safetensors.prompt.*;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.safetensors.tokenizer.WordPieceTokenizer;
import com.github.tjake.jlama.util.JsonSupport;
import com.google.common.io.Resources;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;

import org.junit.Assert;
import org.junit.Assume;
import org.junit.Test;
Expand Down Expand Up @@ -385,4 +390,34 @@ public void testMistralTools() {

Assert.assertEquals(expected, prompt);
}

@Test
public void testToolParse() throws JsonProcessingException {
String input = "To get the payment status and date of transaction T1005, I will use the retrievePaymentStatus and retrievePaymentDate functions.\n" +
"\n" +
"[{\"name\": \"retrievePaymentStatus\", \"arguments\": {\"arg0\": \"T1005\"}}]\n" +
"[{\"name\": \"retrievePaymentDate\", \"arguments\": {\"arg0\": \"T1005\"}}]\n" +
"\n" +
"Please wait while I retrieve the data...";

List<String> jsonCalls = JsonSupport.extractJsonFromString(input);

Assert.assertEquals(2, jsonCalls.size());

List<ToolCall> toolCalls = new ArrayList<>(jsonCalls.size());
for (String jsonCall : jsonCalls) {
if (jsonCall.startsWith("[")) {
List<ToolCall> toolCallList = JsonSupport.om.readValue(jsonCall, new TypeReference<>() {});
toolCalls.addAll(toolCallList);
} else {
ToolCall toolCall = JsonSupport.om.readValue(jsonCall, ToolCall.class);
toolCalls.add(toolCall);
}
}

// Remove duplicates
toolCalls = toolCalls.stream().sorted(Comparator.comparing(ToolCall::getName)).distinct().collect(Collectors.toList());

Assert.assertEquals(2, toolCalls.size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ public void DeepCoderRun() throws Exception {

@Test
public void MistralRun() throws Exception {
String modelPrefix = "../models/Mistral-7B-Instruct-v0.3-jlama-Q4";
String modelPrefix = "../models/tjake_Mistral-7B-Instruct-v0.3-JQ4";
Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));
try (WeightLoader weights = SafeTensorSupport.loadWeights(Path.of(modelPrefix).toFile())) {
BPETokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
<!-- Build property abstractions: versions, etc -->
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<revision>0.8.2</revision>
<revision>0.8.3</revision>

<slf4j-api.version>2.0.7</slf4j-api.version>
<logback.version>1.5.6</logback.version>
Expand Down

0 comments on commit 190464c

Please sign in to comment.