From 44f73c4319d43a52cd48542574fbd836ed099875 Mon Sep 17 00:00:00 2001 From: Ryan Hopper-Lowe Date: Fri, 22 Nov 2024 15:33:08 -0600 Subject: [PATCH] feat: add credential based authentication to chat --- ui/admin/app/components/chat/Message.tsx | 125 ++++++++++++++++++--- ui/admin/app/lib/model/chatEvents.ts | 19 +++- ui/admin/app/lib/model/messages.ts | 6 +- ui/admin/app/lib/routers/apiRoutes.ts | 4 + ui/admin/app/lib/service/api/PromptApi.tsx | 14 +++ 5 files changed, 146 insertions(+), 22 deletions(-) create mode 100644 ui/admin/app/lib/service/api/PromptApi.tsx diff --git a/ui/admin/app/components/chat/Message.tsx b/ui/admin/app/components/chat/Message.tsx index 3708c513..a4b92587 100644 --- a/ui/admin/app/components/chat/Message.tsx +++ b/ui/admin/app/components/chat/Message.tsx @@ -1,21 +1,34 @@ import "@radix-ui/react-tooltip"; import { WrenchIcon } from "lucide-react"; -import React, { useMemo } from "react"; +import React, { useMemo, useState } from "react"; +import { useForm } from "react-hook-form"; import Markdown, { defaultUrlTransform } from "react-markdown"; import rehypeExternalLinks from "rehype-external-links"; import remarkGfm from "remark-gfm"; -import { OAuthPrompt } from "~/lib/model/chatEvents"; +import { AuthPrompt } from "~/lib/model/chatEvents"; import { Message as MessageType } from "~/lib/model/messages"; +import { PromptApiService } from "~/lib/service/api/PromptApi"; import { cn } from "~/lib/utils"; import { TypographyP } from "~/components/Typography"; import { MessageDebug } from "~/components/chat/MessageDebug"; import { ToolCallInfo } from "~/components/chat/ToolCallInfo"; +import { ControlledInput } from "~/components/form/controlledInputs"; import { CustomMarkdownComponents } from "~/components/react-markdown"; import { ToolIcon } from "~/components/tools/ToolIcon"; import { Button } from "~/components/ui/button"; import { Card } from "~/components/ui/card"; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "~/components/ui/dialog"; +import { Form } from "~/components/ui/form"; +import { Link } from "~/components/ui/link"; +import { useAsync } from "~/hooks/useAsync"; interface MessageProps { message: MessageType; @@ -126,7 +139,9 @@ export const Message = React.memo(({ message }: MessageProps) => { Message.displayName = "Message"; -function PromptMessage({ prompt }: { prompt: OAuthPrompt }) { +function PromptMessage({ prompt }: { prompt: AuthPrompt }) { + const [open, setOpen] = useState(false); + if (!prompt.metadata) return null; return ( @@ -141,23 +156,107 @@ function PromptMessage({ prompt }: { prompt: OAuthPrompt }) { Tool Call requires authentication - + Authenticate with {prompt.metadata.category} + + )} + + {prompt.metadata.authType === "basic" && prompt.fields && ( + + + + + + + + + Authenticate with {prompt.metadata.category} + + + + setOpen(false)} + /> + + + )} ); } + +function PromptAuthForm({ + prompt, + onSuccess, +}: { + prompt: AuthPrompt; + onSuccess: () => void; +}) { + const authenticate = useAsync(PromptApiService.promptResponse, { + onSuccess, + }); + + const form = useForm>({ + defaultValues: prompt.fields?.reduce( + (acc, field) => { + acc[field] = ""; + return acc; + }, + {} as Record + ), + }); + + const handleSubmit = form.handleSubmit(async (values) => + authenticate.execute({ id: prompt.id, response: values }) + ); + + return ( +
+ + {prompt.fields?.map((field) => ( + + ))} + + + + + ); +} diff --git a/ui/admin/app/lib/model/chatEvents.ts b/ui/admin/app/lib/model/chatEvents.ts index 189bad3d..2ac485d7 100644 --- a/ui/admin/app/lib/model/chatEvents.ts +++ b/ui/admin/app/lib/model/chatEvents.ts @@ -13,23 +13,30 @@ export type ToolCall = { }; }; -type PromptOAuthMeta = { - authType: "oauth"; - authURL: string; +type PromptAuthMetaBase = { category: string; icon: string; toolContext: string; toolDisplayName: string; }; -export type OAuthPrompt = { +type PromptOAuthMeta = PromptAuthMetaBase & { + authType: "oauth"; + authURL: string; +}; + +type PromptAuthBasicMeta = PromptAuthMetaBase & { + authType: "basic"; +}; + +export type AuthPrompt = { id?: string; name: string; time?: Date; message: string; fields?: string[]; sensitive?: boolean; - metadata?: PromptOAuthMeta; + metadata?: PromptOAuthMeta | PromptAuthBasicMeta; }; // note(ryanhopperlowe) renaming this to ChatEvent to differentiate itself specifically for a chat with an agent @@ -45,7 +52,7 @@ export type ChatEvent = { waitingOnModel?: boolean; toolInput?: ToolInput; toolCall?: ToolCall; - prompt?: OAuthPrompt; + prompt?: AuthPrompt; }; export function combineChatEvents(events: ChatEvent[]): ChatEvent[] { diff --git a/ui/admin/app/lib/model/messages.ts b/ui/admin/app/lib/model/messages.ts index d7813b86..72c5c688 100644 --- a/ui/admin/app/lib/model/messages.ts +++ b/ui/admin/app/lib/model/messages.ts @@ -1,4 +1,4 @@ -import { ChatEvent, OAuthPrompt, ToolCall } from "~/lib/model/chatEvents"; +import { AuthPrompt, ChatEvent, ToolCall } from "~/lib/model/chatEvents"; import { Run } from "~/lib/model/runs"; export interface Message { @@ -6,7 +6,7 @@ export interface Message { sender: "user" | "agent"; // note(ryanhopperlowe) we only support one tool call per message for now // leaving it as an array case that changes in the future - prompt?: OAuthPrompt; + prompt?: AuthPrompt; tools?: ToolCall[]; runId?: string; isLoading?: boolean; @@ -40,7 +40,7 @@ export const toolCallMessage = (toolCall: ToolCall): Message => ({ tools: [toolCall], }); -export const promptMessage = (prompt: OAuthPrompt, runID: string): Message => ({ +export const promptMessage = (prompt: AuthPrompt, runID: string): Message => ({ sender: "agent", text: prompt.message, prompt, diff --git a/ui/admin/app/lib/routers/apiRoutes.ts b/ui/admin/app/lib/routers/apiRoutes.ts index 4951a4f7..947be4c5 100644 --- a/ui/admin/app/lib/routers/apiRoutes.ts +++ b/ui/admin/app/lib/routers/apiRoutes.ts @@ -127,6 +127,10 @@ export const ApiRoutes = { buildUrl(`/threads/${threadId}/knowledge`), getFiles: (threadId: string) => buildUrl(`/threads/${threadId}/files`), }, + prompt: { + base: () => buildUrl("/prompt"), + promptResponse: () => buildUrl("/prompt"), + }, runs: { base: () => buildUrl("/runs"), getRunById: (runId: string) => buildUrl(`/runs/${runId}`), diff --git a/ui/admin/app/lib/service/api/PromptApi.tsx b/ui/admin/app/lib/service/api/PromptApi.tsx new file mode 100644 index 00000000..d8db2cce --- /dev/null +++ b/ui/admin/app/lib/service/api/PromptApi.tsx @@ -0,0 +1,14 @@ +import { ApiRoutes } from "~/lib/routers/apiRoutes"; +import { request } from "~/lib/service/api/primitives"; + +async function promptResponse(prompt: { + id?: string; + response?: Record; +}) { + await request({ + method: "POST", + url: ApiRoutes.prompt.promptResponse().url, + data: prompt, + }); +} +export const PromptApiService = { promptResponse };