Skip to content

Commit

Permalink
Preserve direct tool outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Dec 10, 2024
1 parent 99eb5d2 commit 7f9a8f8
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 34 deletions.
17 changes: 16 additions & 1 deletion langchain-core/src/messages/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,23 @@ export interface ToolMessageFieldsWithToolCallId extends BaseMessageFields {
status?: "success" | "error";
}

export interface DirectToolOutput {
readonly lc_tool_output: boolean;
}

export function isDirectToolOutput(x: unknown): x is DirectToolOutput {
return (
x != null &&
typeof x === "object" &&
"lc_tool_output" in x &&
x.lc_tool_output === true
);
}

/**
* Represents a tool message in a conversation.
*/
export class ToolMessage extends BaseMessage {
export class ToolMessage extends BaseMessage implements DirectToolOutput {
static lc_name() {
return "ToolMessage";
}
Expand All @@ -40,6 +53,8 @@ export class ToolMessage extends BaseMessage {
return { tool_call_id: "tool_call_id" };
}

lc_tool_output = true;

/**
* Status of the tool invocation.
* @version 0.2.19
Expand Down
61 changes: 40 additions & 21 deletions langchain-core/src/tools/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {
type RunnableConfig,
} from "../runnables/config.js";
import type { RunnableFunc, RunnableInterface } from "../runnables/base.js";
import { ToolCall, ToolMessage } from "../messages/tool.js";
import { isDirectToolOutput, ToolCall, ToolMessage } from "../messages/tool.js";
import { MessageContent } from "../messages/base.js";
import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js";
import { _isToolCall, ToolInputParsingException } from "./utils.js";
Expand Down Expand Up @@ -159,7 +159,7 @@ export abstract class StructuredTool<
protected abstract _call(
arg: z.output<T>,
runManager?: CallbackManagerForToolRun,
parentConfig?: RunnableConfig
parentConfig?: RunnableConfig & { toolCall?: ToolCall }
): Promise<ToolReturnType>;

/**
Expand All @@ -182,21 +182,24 @@ export abstract class StructuredTool<
| ToolCall
| undefined;

let enrichedConfig: RunnableConfig & { toolCall?: ToolCall } =
ensureConfig(config);
if (_isToolCall(input)) {
tool_call_id = input.id;
toolInput = input.args;
enrichedConfig = {
...enrichedConfig,
toolCall: input,
configurable: {
...enrichedConfig.configurable,
tool_call_id,
},
};
} else {
toolInput = input;
}

const ensuredConfig = ensureConfig(config);
return this.call(toolInput, {
...ensuredConfig,
configurable: {
...ensuredConfig.configurable,
tool_call_id,
},
});
return this.call(toolInput, enrichedConfig);
}

/**
Expand All @@ -211,8 +214,8 @@ export abstract class StructuredTool<
* @returns A Promise that resolves with a string.
*/
async call(
arg: (z.output<T> extends string ? string : never) | z.input<T> | ToolCall,
configArg?: Callbacks | RunnableConfig,
arg: (z.output<T> extends string ? string : never) | z.input<T>,
configArg?: Callbacks | (RunnableConfig & { toolCall?: ToolCall }),
/** @deprecated */
tags?: string[]
): Promise<ToolReturnType> {
Expand All @@ -229,7 +232,7 @@ export abstract class StructuredTool<
}

const config = parseCallbackConfigArg(configArg);
const callbackManager_ = await CallbackManager.configure(
const callbackManager_ = CallbackManager.configure(
config.callbacks,
this.callbacks,
config.tags || tags,
Expand Down Expand Up @@ -350,7 +353,7 @@ export interface DynamicToolInput extends BaseDynamicToolInput {
func: (
input: string,
runManager?: CallbackManagerForToolRun,
config?: RunnableConfig
config?: RunnableConfig & { toolCall?: ToolCall }
) => Promise<ToolReturnType>;
}

Expand Down Expand Up @@ -400,7 +403,7 @@ export class DynamicTool extends Tool {
*/
async call(
arg: string | undefined | z.input<this["schema"]> | ToolCall,
configArg?: RunnableConfig | Callbacks
configArg?: (RunnableConfig & { toolCall?: ToolCall }) | Callbacks
): Promise<ToolReturnType> {
const config = parseCallbackConfigArg(configArg);
if (config.runName === undefined) {
Expand All @@ -413,7 +416,7 @@ export class DynamicTool extends Tool {
async _call(
input: string,
runManager?: CallbackManagerForToolRun,
parentConfig?: RunnableConfig
parentConfig?: RunnableConfig & { toolCall?: ToolCall }
): Promise<ToolReturnType> {
return this.func(input, runManager, parentConfig);
}
Expand Down Expand Up @@ -553,26 +556,42 @@ interface ToolWrapperParams<
* @returns {DynamicStructuredTool<T>} A new StructuredTool instance.
*/
export function tool<T extends z.ZodString>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
func: RunnableFunc<
z.output<T>,
ToolReturnType,
RunnableConfig & { toolCall?: ToolCall }
>,
fields: ToolWrapperParams<T>
): DynamicTool;

export function tool<T extends ZodObjectAny>(
func: RunnableFunc<z.output<T>, ToolReturnType>,
func: RunnableFunc<
z.output<T>,
ToolReturnType,
RunnableConfig & { toolCall?: ToolCall }
>,
fields: ToolWrapperParams<T>
): DynamicStructuredTool<T>;

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function tool<T extends Record<string, any>>(
func: RunnableFunc<T, ToolReturnType>,
func: RunnableFunc<
T,
ToolReturnType,
RunnableConfig & { toolCall?: ToolCall }
>,
fields: ToolWrapperParams<T>
): DynamicStructuredTool<T>;

export function tool<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends ZodObjectAny | z.ZodString | Record<string, any> = ZodObjectAny
>(
func: RunnableFunc<T extends ZodObjectAny ? z.output<T> : T, ToolReturnType>,
func: RunnableFunc<
T extends ZodObjectAny ? z.output<T> : T,
ToolReturnType,
RunnableConfig & { toolCall?: ToolCall }
>,
fields: ToolWrapperParams<T>
):
| DynamicStructuredTool<T extends ZodObjectAny ? T : ZodObjectAny>
Expand Down Expand Up @@ -649,7 +668,7 @@ function _formatToolOutput(params: {
toolCallId?: string;
}): ToolReturnType {
const { content, artifact, toolCallId } = params;
if (toolCallId) {
if (toolCallId && !isDirectToolOutput(content)) {
if (
typeof content === "string" ||
(Array.isArray(content) &&
Expand Down
81 changes: 69 additions & 12 deletions langchain-core/src/tools/tests/tools.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,15 @@ test("Does not return tool message if responseFormat is content_and_artifact and
const weatherSchema = z.object({
location: z.string(),
});
const toolCall = {
args: { location: "San Francisco" },
name: "weather",
type: "tool_call",
} as const;

const weatherTool = tool(
(input) => {
(input, config) => {
expect(config.toolCall).toEqual(toolCall);
return ["msg_content", input];
},
{
Expand All @@ -63,11 +69,7 @@ test("Does not return tool message if responseFormat is content_and_artifact and
}
);

const toolResult = await weatherTool.invoke({
args: { location: "San Francisco" },
name: "weather",
type: "tool_call",
});
const toolResult = await weatherTool.invoke(toolCall);

expect(toolResult).toBe("msg_content");
});
Expand All @@ -77,8 +79,16 @@ test("Returns tool message if responseFormat is content_and_artifact and returns
location: z.string(),
});

const toolCall = {
id: "testid",
args: { location: "San Francisco" },
name: "weather",
type: "tool_call",
} as const;

const weatherTool = tool(
(input) => {
(input, config) => {
expect(config.toolCall).toEqual(toolCall);
return ["msg_content", input];
},
{
Expand All @@ -88,23 +98,63 @@ test("Returns tool message if responseFormat is content_and_artifact and returns
}
);

const toolResult = await weatherTool.invoke({
const toolResult = await weatherTool.invoke(toolCall);

expect(toolResult).toBeInstanceOf(ToolMessage);
expect(toolResult.content).toBe("msg_content");
expect(toolResult.artifact).toEqual({ location: "San Francisco" });
expect(toolResult.name).toBe("weather");
});

test("Does not double wrap a returned tool message even if a tool call with id is passed in", async () => {
const weatherSchema = z.object({
location: z.string(),
});

const toolCall = {
id: "testid",
args: { location: "San Francisco" },
name: "weather",
type: "tool_call",
});
} as const;

const weatherTool = tool(
(_, config) => {
expect(config.toolCall).toEqual(toolCall);
return new ToolMessage({
tool_call_id: "not_original",
content: "bar",
name: "baz",
});
},
{
name: "weather",
schema: weatherSchema,
}
);

const toolResult = await weatherTool.invoke(toolCall);

expect(toolResult).toBeInstanceOf(ToolMessage);
expect(toolResult.content).toBe("msg_content");
expect(toolResult.artifact).toEqual({ location: "San Francisco" });
expect(toolResult.name).toBe("weather");
expect(toolResult.tool_call_id).toBe("not_original");
expect(toolResult.content).toBe("bar");
expect(toolResult.name).toBe("baz");
});

test("Tool can accept single string input", async () => {
const toolCall = {
id: "testid",
args: { input: "b" },
name: "string_tool",
type: "tool_call",
} as const;

const stringTool = tool<z.ZodString>(
(input: string, config): string => {
expect(config).toMatchObject({ configurable: { foo: "bar" } });
if (config.configurable.usesToolCall) {
expect(config.toolCall).toEqual(toolCall);
}
return `${input}a`;
},
{
Expand All @@ -116,6 +166,13 @@ test("Tool can accept single string input", async () => {

const result = await stringTool.invoke("b", { configurable: { foo: "bar" } });
expect(result).toBe("ba");

const result2 = await stringTool.invoke(toolCall, {
configurable: { foo: "bar", usesToolCall: true },
});
expect(result2).toBeInstanceOf(ToolMessage);
expect(result2.content).toBe("ba");
expect(result2.name).toBe("string_tool");
});

test("Tool declared with JSON schema", async () => {
Expand Down

0 comments on commit 7f9a8f8

Please sign in to comment.