diff --git a/assistant-ui/README.md b/assistant-ui/README.md index 401ec45d..487122a1 100644 --- a/assistant-ui/README.md +++ b/assistant-ui/README.md @@ -34,7 +34,7 @@ pnpm add @inferable/assistant-ui ```typescript import { useInferableRuntime } from '@inferable/assistant-ui'; -import { Thread } from "@assistant-ui/react"; +import { AssistantRuntimeProvider, Thread } from "@assistant-ui/react"; const { runtime, run } = useInferableRuntime({ clusterId: '', @@ -46,7 +46,9 @@ const { runtime, run } = useInferableRuntime({ return (
- + + +
); ``` @@ -68,6 +70,49 @@ You can handle errors by providing an `onError` callback: }) ``` +### Rendering function UI + +You can provide assistant-ui with [custom UI components](https://www.assistant-ui.com/docs/guides/ToolUI) for rendering Inferable function calls / results. + +#### Example + +```typescript +// Fallback UI +const FallbackToolUI = ({args, result, toolName}) => +
+

Tool: {toolName}

+

Input:

+
{JSON.stringify(args, null, 2)}
+

Output:

+ {result &&
{JSON.stringify(result, null, 2)}
} + {!result &&

No output

} +
+ +// Custom UI example +const SearchToolUI = makeAssistantToolUI({ + toolName: "default_webSearch", + render: ({ args }) => { + return

webSearch({args.query})

; + }, +}); + +return ( +
+ + + +
+); +``` + ## Local Development diff --git a/assistant-ui/demo/TestPage.jsx b/assistant-ui/demo/TestPage.jsx index cf254fbc..225f9211 100644 --- a/assistant-ui/demo/TestPage.jsx +++ b/assistant-ui/demo/TestPage.jsx @@ -1,6 +1,15 @@ import { useInferableRuntime } from '../src' -import { Thread } from "@assistant-ui/react"; -import toast from "react-hot-toast"; +import { AssistantRuntimeProvider, Thread } from "@assistant-ui/react"; + +const FallbackToolUI = ({args, result, toolName}) => +
+

Tool: {toolName}

+

Input:

+
{JSON.stringify(args, null, 2)}
+

Output:

+ {result &&
{JSON.stringify(result, null, 2)}
} + {!result &&

No output

} +
const TestPage = () => { const existingRunId = localStorage.getItem("runID") @@ -24,7 +33,13 @@ const TestPage = () => { return (
- + + +
); }; diff --git a/assistant-ui/src/inferable-provider-runtime.ts b/assistant-ui/src/inferable-provider-runtime.ts index dc90a1e2..0d67cd43 100644 --- a/assistant-ui/src/inferable-provider-runtime.ts +++ b/assistant-ui/src/inferable-provider-runtime.ts @@ -81,7 +81,7 @@ export function useInferableRuntime({ runtime: useExternalStoreRuntime({ isRunning, messages, - convertMessage, + convertMessage: (message) => convertMessage(message, messages), onNew, }), run, @@ -89,7 +89,7 @@ export function useInferableRuntime({ } -const convertMessage = (message: any): ThreadMessageLike => { +const convertMessage = (message: any, allMessages: any): ThreadMessageLike => { switch (message.type) { case "human": { const parsedData = genericMessageDataSchema.parse(message.data); @@ -114,31 +114,49 @@ const convertMessage = (message: any): ThreadMessageLike => { }); } - return { - id: message.id, - role: "assistant", - content: content + if (parsedData.invocations) { + + parsedData.invocations.forEach((invocation) => { + + // Attempt to find corresponding `invocation-result` message + let result = null; + allMessages.forEach((message: any) => { + if ('type' in message && message.type !== "invocation-result") { + return false + } + + const parsedResult = resultDataSchema.parse(message.data); + + if (parsedResult.id === invocation.id) { + result = parsedResult.result; + return true; + } + }); + + content.push({ + type: "tool-call", + toolName: invocation.toolName, + args: invocation.input, + toolCallId: invocation.id, + result + }); + }) } - } - case "invocation-result": { - const parsedData = resultDataSchema.parse(message.data); - // TODO: Search chat history for the corresponding invocation mesasge (With args) + if (content.length === 0) { + return { + id: message.id, + role: "system", + content: "MESSAGE HAS NO CONTENT" + } + } return { id: message.id, role: "assistant", - content: [{ - type: "tool-call", - toolName: "inferable", - args: {}, - toolCallId: parsedData.id, - result: parsedData.result, - }] + content: content } - } - } return { @@ -146,7 +164,7 @@ const convertMessage = (message: any): ThreadMessageLike => { role: "system", content: [{ type: "text", - text: "" + text: `UNKNON MESSAGE TYPE: ${message.type}` }], }; }