From f6f06c08b458f25470e454c0ad20a56b5df2aeb3 Mon Sep 17 00:00:00 2001 From: Gerard Clos Date: Fri, 27 Sep 2024 12:20:49 +0200 Subject: [PATCH] feature: synthetic datasets --- .cursorignore | 3 + apps/infra/Pulumi.core.yaml | 2 + apps/infra/src/app/production/shared.ts | 13 + apps/infra/src/core/secrets.ts | 6 + apps/web/appspec.yml | 14 +- .../src/actions/datasets/generateDataset.ts | 98 ++++++ .../sdk/generateDatasetPreviewAction.ts | 64 ++++ apps/web/src/actions/sdk/runDocumentAction.ts | 2 +- apps/web/src/app/(private)/_lib/createSdk.ts | 20 +- .../_components/PreviewDatasetModal/index.tsx | 4 +- .../datasets/generate/CsvPreviewTable.tsx | 48 +++ .../generate/GenerateDatasetContent.tsx | 317 ++++++++++++++++++ .../datasets/generate/LoadingText.tsx | 44 +++ .../app/(private)/datasets/generate/page.tsx | 26 ++ .../web/src/app/(private)/datasets/layout.tsx | 13 +- .../editor/import-logs/page.tsx | 46 +-- .../components/ProjectDocumentSelector.tsx | 55 +++ apps/web/src/env.ts | 8 + apps/web/src/hooks/useStreamableAction.ts | 75 +++++ apps/web/src/services/routes.ts | 3 + apps/web/src/stores/documentVersions.ts | 6 +- .../web-ui/src/ds/atoms/FormField/index.tsx | 38 ++- .../web-ui/src/ds/atoms/Skeleton/index.tsx | 5 +- packages/web-ui/src/ds/atoms/Table/index.tsx | 7 +- packages/web-ui/src/ds/atoms/Text/index.tsx | 4 + .../web-ui/src/ds/atoms/Tooltip/index.tsx | 44 ++- .../src/ds/molecules/TableSkeleton/index.tsx | 20 +- packages/web-ui/tailwind.config.js | 7 +- turbo.json | 5 +- 29 files changed, 909 insertions(+), 88 deletions(-) create mode 100644 .cursorignore create mode 100644 apps/web/src/actions/datasets/generateDataset.ts create mode 100644 apps/web/src/actions/sdk/generateDatasetPreviewAction.ts create mode 100644 apps/web/src/app/(private)/datasets/generate/CsvPreviewTable.tsx create mode 100644 apps/web/src/app/(private)/datasets/generate/GenerateDatasetContent.tsx create mode 100644 apps/web/src/app/(private)/datasets/generate/LoadingText.tsx create mode 100644 apps/web/src/app/(private)/datasets/generate/page.tsx create mode 100644 apps/web/src/components/ProjectDocumentSelector.tsx create mode 100644 apps/web/src/hooks/useStreamableAction.ts diff --git a/.cursorignore b/.cursorignore new file mode 100644 index 000000000..ce90f5140 --- /dev/null +++ b/.cursorignore @@ -0,0 +1,3 @@ +# Add directories or file patterns to ignore during indexing (e.g. foo/ or *.csv) +node_modules +drizzle diff --git a/apps/infra/Pulumi.core.yaml b/apps/infra/Pulumi.core.yaml index 64e822af2..4e2b4ddb2 100644 --- a/apps/infra/Pulumi.core.yaml +++ b/apps/infra/Pulumi.core.yaml @@ -2,6 +2,8 @@ encryptionsalt: v1:K+dTqOgU40c=:v1:xzsAOOiJEEAsdCQ4:Oe7NKHjXdYdZrcrrBa1/0yolY5M3 config: infra:DATABASE_PASSWORD: secure: v1:WABt/tJjfsAKMplU:RRLtj5mTu301x4sn2Tr91u4O0Msd3mWVJutcDg== + infra:DATASET_GENERATOR_WORKSPACE_APIKEY: + secure: v1:l9AShyNWMRBbgHre:cKM7YpzKRoAxChL40q2RixG0mm081qGrkpKSAitb9E5TMhXR7AVOtDaG9TiwYG9NXQNsCw== infra:DEFAULT_PROJECT_ID: secure: v1:M0giTdD2+Mjre0ps:YIaUlExLI41EFdcpf+6Fxec= infra:DEFAULT_PROVIDER_API_KEY: diff --git a/apps/infra/src/app/production/shared.ts b/apps/infra/src/app/production/shared.ts index a76072cfe..63818d464 100644 --- a/apps/infra/src/app/production/shared.ts +++ b/apps/infra/src/app/production/shared.ts @@ -30,6 +30,9 @@ const defaultProviderApiKeyArn = coreStack.requireOutput( 'defaultProviderApiKeyArn', ) const postHogApiKeyArn = coreStack.requireOutput('postHogApiKeyArn') +const datasetGeneratorWorkspaceApiKeyArn = coreStack.requireOutput( + 'datasetGeneratorWorkspaceApiKeyArn', +) const getSecretString = (arn: pulumi.Output) => { return arn.apply((secretId) => @@ -59,6 +62,9 @@ export const sentryOrg = getSecretString(sentryOrgArn) export const sentryProject = getSecretString(sentryProjectArn) export const defaultProviderApiKey = getSecretString(defaultProviderApiKeyArn) export const postHogApiKey = getSecretString(postHogApiKeyArn) +export const datasetGeneratorWorkspaceApiKey = getSecretString( + datasetGeneratorWorkspaceApiKeyArn, +) export const dbUrl = pulumi.interpolate`postgresql://${dbUsername}:${dbPassword}@${dbEndpoint}/${dbName}?sslmode=verify-full&sslrootcert=/app/packages/core/src/assets/eu-central-1-bundle.pem` export const environment = pulumi @@ -75,6 +81,7 @@ export const environment = pulumi defaultProjectId, defaultProviderApiKey, postHogApiKey, + datasetGeneratorWorkspaceApiKey, ]) .apply(() => { return [ @@ -115,5 +122,11 @@ export const environment = pulumi { name: 'DEFAULT_PROVIDER_API_KEY', value: defaultProviderApiKey }, { name: 'NEXT_PUBLIC_POSTHOG_KEY', value: postHogApiKey }, { name: 'NEXT_PUBLIC_POSTHOG_HOST', value: 'https://eu.i.posthog.com' }, + { + name: 'DATASET_GENERATOR_WORKSPACE_APIKEY', + value: datasetGeneratorWorkspaceApiKey, + }, + { name: 'DATASET_GENERATOR_PROJECT_ID', value: '74' }, + { name: 'DATASET_GENERATOR_DOCUMENT_PATH', value: 'generator' }, ] }) diff --git a/apps/infra/src/core/secrets.ts b/apps/infra/src/core/secrets.ts index 905b5a4ca..1df4e5896 100644 --- a/apps/infra/src/core/secrets.ts +++ b/apps/infra/src/core/secrets.ts @@ -70,6 +70,10 @@ const postHogApiKey = createSecretWithVersion( 'NEXT_PUBLIC_POSTHOG_KEY', 'Posthog API Key for product analytics', ) +const datasetGeneratorWorkspaceApiKey = createSecretWithVersion( + 'DATASET_GENERATOR_WORKSPACE_APIKEY', + 'API key for the dataset generator', +) export const mailerApiKeyArn = mailerApiKey.arn export const sentryDsnArn = sentryDsn.arn @@ -83,3 +87,5 @@ export const workersWebsocketsSecretTokenArn = workersWebsocketsSecretToken.arn export const defaultProjectIdArn = defaultProjectId.arn export const defaultProviderApiKeyArn = defaultProviderApiKey.arn export const postHogApiKeyArn = postHogApiKey.arn +export const datasetGeneratorWorkspaceApiKeyArn = + datasetGeneratorWorkspaceApiKey.arn diff --git a/apps/web/appspec.yml b/apps/web/appspec.yml index 39290fd86..020a73590 100644 --- a/apps/web/appspec.yml +++ b/apps/web/appspec.yml @@ -1,9 +1,9 @@ version: 0.0 Resources: - - TargetService: - Type: AWS::ECS::Service - Properties: - TaskDefinition: arn:aws:ecs:eu-central-1:442420265876:task-definition/LatitudeLLMTaskFamily:91 - LoadBalancerInfo: - ContainerName: "LatitudeLLMAppContainer" - ContainerPort: 8080 + - TargetService: + Type: AWS::ECS::Service + Properties: + TaskDefinition: arn:aws:ecs:eu-central-1:442420265876:task-definition/LatitudeLLMTaskFamily:92 + LoadBalancerInfo: + ContainerName: 'LatitudeLLMAppContainer' + ContainerPort: 8080 diff --git a/apps/web/src/actions/datasets/generateDataset.ts b/apps/web/src/actions/datasets/generateDataset.ts new file mode 100644 index 000000000..f703331be --- /dev/null +++ b/apps/web/src/actions/datasets/generateDataset.ts @@ -0,0 +1,98 @@ +'use server' + +import { + ChainObjectResponse, + Dataset, + StreamEventTypes, +} from '@latitude-data/core/browser' +import { BadRequestError } from '@latitude-data/core/lib/errors' +import { createDataset } from '@latitude-data/core/services/datasets/create' +import { ChainEventDto } from '@latitude-data/sdk' +import { createSdk } from '$/app/(private)/_lib/createSdk' +import env from '$/env' +import { getCurrentUser } from '$/services/auth/getCurrentUser' +import { createStreamableValue } from 'ai/rsc' + +type GenerateDatasetActionProps = { + parameters: Record + description: string + rowCount: number + name: string +} + +export async function generateDatasetAction({ + parameters, + description, + rowCount, + name, +}: GenerateDatasetActionProps) { + if (!env.DATASET_GENERATOR_PROJECT_ID) { + throw new BadRequestError('PROJECT_ID_DATASET_GENERATION is not set') + } + if (!env.DATASET_GENERATOR_DOCUMENT_PATH) { + throw new BadRequestError('DATASET_GENERATOR_DOCUMENT_PATH is not set') + } + if (!env.DATASET_GENERATOR_WORKSPACE_APIKEY) { + throw new BadRequestError('DATASET_GENERATOR_WORKSPACE_APIKEY is not set') + } + + let response: Dataset | undefined + const { user, workspace } = await getCurrentUser() + const stream = createStreamableValue< + { event: StreamEventTypes; data: ChainEventDto }, + Error + >() + const sdk = await createSdk({ + apiKey: env.DATASET_GENERATOR_WORKSPACE_APIKEY, + projectId: env.DATASET_GENERATOR_PROJECT_ID, + }).then((r) => r.unwrap()) + const sdkResponse = await sdk.run(env.DATASET_GENERATOR_DOCUMENT_PATH, { + parameters: { + row_count: rowCount, + parameters, + user_message: description, + }, + onError: (error) => { + stream.error({ + name: error.name, + message: error.message, + stack: error.stack, + }) + }, + }) + + try { + const sdkResult = await sdkResponse + const csv = (sdkResult?.response! as ChainObjectResponse).object.csv + const result = await createDataset({ + author: user, + workspace, + data: { + name, + file: new File([csv], `${name}.csv`, { type: 'text/csv' }), + csvDelimiter: ',', + }, + }) + if (result.error) { + stream.error({ + name: result.error.name, + message: result.error.message, + stack: result.error.stack, + }) + } else { + response = result.value + stream.done() + } + } catch (error) { + stream.error({ + name: (error as Error).name, + message: (error as Error).message, + stack: (error as Error).stack, + }) + } + + return { + output: stream.value, + response, + } +} diff --git a/apps/web/src/actions/sdk/generateDatasetPreviewAction.ts b/apps/web/src/actions/sdk/generateDatasetPreviewAction.ts new file mode 100644 index 000000000..16ec33cf6 --- /dev/null +++ b/apps/web/src/actions/sdk/generateDatasetPreviewAction.ts @@ -0,0 +1,64 @@ +'use server' + +import { StreamEventTypes } from '@latitude-data/core/browser' +import { BadRequestError } from '@latitude-data/core/lib/errors' +import { ChainEventDto } from '@latitude-data/sdk' +import { createSdk } from '$/app/(private)/_lib/createSdk' +import env from '$/env' +import { createStreamableValue } from 'ai/rsc' + +type RunDocumentActionProps = { + projectId: number + documentUuid: string + parameters: Record + description: string +} + +export async function generateDatasetPreviewAction({ + parameters, + description, +}: RunDocumentActionProps) { + const stream = createStreamableValue< + { event: StreamEventTypes; data: ChainEventDto }, + Error + >() + if (!env.DATASET_GENERATOR_PROJECT_ID) { + throw new BadRequestError('PROJECT_ID_DATASET_GENERATION is not set') + } + if (!env.DATASET_GENERATOR_DOCUMENT_PATH) { + throw new BadRequestError('DATASET_GENERATOR_DOCUMENT_PATH is not set') + } + if (!env.DATASET_GENERATOR_WORKSPACE_APIKEY) { + throw new BadRequestError('DATASET_GENERATOR_WORKSPACE_APIKEY is not set') + } + + const sdk = await createSdk({ + apiKey: env.DATASET_GENERATOR_WORKSPACE_APIKEY, + projectId: env.DATASET_GENERATOR_PROJECT_ID, + }).then((r) => r.unwrap()) + const response = await sdk.run(env.DATASET_GENERATOR_DOCUMENT_PATH, { + parameters: { + row_count: 10, + parameters, + user_message: description, + }, + onEvent: (event) => { + stream.update(event) + }, + onError: (error) => { + stream.error({ + name: error.name, + message: error.message, + stack: error.stack, + }) + }, + onFinished: () => { + stream.done() + }, + }) + + return { + output: stream.value, + response, + } +} diff --git a/apps/web/src/actions/sdk/runDocumentAction.ts b/apps/web/src/actions/sdk/runDocumentAction.ts index c0ba9a2a2..988cf00e9 100644 --- a/apps/web/src/actions/sdk/runDocumentAction.ts +++ b/apps/web/src/actions/sdk/runDocumentAction.ts @@ -25,7 +25,7 @@ export async function runDocumentAction({ commitUuid, parameters, }: RunDocumentActionProps) { - const sdk = await createSdk(projectId).then((r) => r.unwrap()) + const sdk = await createSdk({ projectId }).then((r) => r.unwrap()) const stream = createStreamableValue< { event: StreamEventTypes; data: ChainEventDto }, Error diff --git a/apps/web/src/app/(private)/_lib/createSdk.ts b/apps/web/src/app/(private)/_lib/createSdk.ts index 131a8eb9f..73246483d 100644 --- a/apps/web/src/app/(private)/_lib/createSdk.ts +++ b/apps/web/src/app/(private)/_lib/createSdk.ts @@ -21,18 +21,24 @@ async function getLatitudeApiKey() { return Result.ok(firstApiKey) } -export async function createSdk(projectId?: number) { - const result = await getLatitudeApiKey() - if (result.error) return result +export async function createSdk({ + projectId, + apiKey, +}: { + projectId?: number + apiKey?: string +} = {}) { + if (!apiKey) { + const result = await getLatitudeApiKey() + if (result.error) return result - const latitudeApiKey = result.value.token + apiKey = result.value.token + } const gateway = { host: env.GATEWAY_HOSTNAME, port: env.GATEWAY_PORT, ssl: env.GATEWAY_SSL, } - return Result.ok( - new Latitude(latitudeApiKey, compactObject({ gateway, projectId })), - ) + return Result.ok(new Latitude(apiKey, compactObject({ gateway, projectId }))) } diff --git a/apps/web/src/app/(private)/datasets/_components/PreviewDatasetModal/index.tsx b/apps/web/src/app/(private)/datasets/_components/PreviewDatasetModal/index.tsx index e1b59e8b0..67c3df079 100644 --- a/apps/web/src/app/(private)/datasets/_components/PreviewDatasetModal/index.tsx +++ b/apps/web/src/app/(private)/datasets/_components/PreviewDatasetModal/index.tsx @@ -54,7 +54,7 @@ function PreviewModal({ ) : ( - + # @@ -68,7 +68,7 @@ function PreviewModal({ {rows.map((row, rowIndex) => { return ( - + {row.map((cell, cellIndex) => ( {cell} ))} diff --git a/apps/web/src/app/(private)/datasets/generate/CsvPreviewTable.tsx b/apps/web/src/app/(private)/datasets/generate/CsvPreviewTable.tsx new file mode 100644 index 000000000..7458dcbca --- /dev/null +++ b/apps/web/src/app/(private)/datasets/generate/CsvPreviewTable.tsx @@ -0,0 +1,48 @@ +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, + Text, +} from '@latitude-data/web-ui' + +interface CsvPreviewTableProps { + csvData: { + headers: string[] + data: { + record: Record + info: { columns: { name: string }[] } + }[] + } +} + +export function CsvPreviewTable({ csvData }: CsvPreviewTableProps) { + return ( +
+ + + {csvData.headers.map((header, index) => ( + + {header} + + ))} + + + + {csvData.data.map(({ record }, rowIndex) => ( + + {csvData.headers.map((header, cellIndex) => ( + +
+ {record[header]} +
+
+ ))} +
+ ))} +
+
+ ) +} diff --git a/apps/web/src/app/(private)/datasets/generate/GenerateDatasetContent.tsx b/apps/web/src/app/(private)/datasets/generate/GenerateDatasetContent.tsx new file mode 100644 index 000000000..47945a0c2 --- /dev/null +++ b/apps/web/src/app/(private)/datasets/generate/GenerateDatasetContent.tsx @@ -0,0 +1,317 @@ +'use client' + +import { FormEvent, useEffect, useState } from 'react' + +import { ConversationMetadata, readMetadata } from '@latitude-data/compiler' +import { + ChainEventTypes, + DocumentVersion, + HEAD_COMMIT, + StreamEventTypes, +} from '@latitude-data/core/browser' +import { syncReadCsv } from '@latitude-data/core/lib/readCsv' +import { + Alert, + Badge, + Button, + CloseTrigger, + FormField, + FormWrapper, + Icon, + Input, + Modal, + TableSkeleton, + Text, + TextArea, + Tooltip, +} from '@latitude-data/web-ui' +import { generateDatasetAction } from '$/actions/datasets/generateDataset' +import { generateDatasetPreviewAction } from '$/actions/sdk/generateDatasetPreviewAction' +import { ProjectDocumentSelector } from '$/components/ProjectDocumentSelector' +import { useNavigate } from '$/hooks/useNavigate' +import { useStreamableAction } from '$/hooks/useStreamableAction' +import { ROUTES } from '$/services/routes' +import useDatasets from '$/stores/datasets' +import useDocumentVersions from '$/stores/documentVersions' + +import { CsvPreviewTable } from './CsvPreviewTable' +import { LoadingText } from './LoadingText' + +interface GenerateDatasetContentProps { + projectId: number + fallbackDocuments: DocumentVersion[] +} + +export function GenerateDatasetContent({ + projectId: fallbackProjectId, + fallbackDocuments, +}: GenerateDatasetContentProps) { + const navigate = useNavigate() + const [documentUuid, setDocumentUuid] = useState() + const [metadata, setMetadata] = useState() + const [projectId, setProjectId] = useState( + fallbackProjectId, + ) + const [previewCsv, setPreviewCsv] = useState<{ + data: { + record: Record + info: { columns: { name: string }[] } + }[] + headers: string[] + }>() + const { data: datasets, mutate } = useDatasets() + const [explanation, setExplanation] = useState() + const { data: documents } = useDocumentVersions( + { + commitUuid: HEAD_COMMIT, + projectId, + }, + { fallbackData: fallbackDocuments }, + ) + const document = documents?.find( + (document) => document.documentUuid === documentUuid, + ) + + const { + runAction: runPreviewAction, + done: previewDone, + isLoading: previewIsLoading, + error: previewError, + } = useStreamableAction( + generateDatasetPreviewAction, + async (event, data) => { + if ( + event === StreamEventTypes.Latitude && + data.type === ChainEventTypes.Complete + ) { + const parsedCsv = await syncReadCsv(data.response.object.csv, { + delimiter: ',', + }).then((r) => r.unwrap()) + setPreviewCsv(parsedCsv) + setExplanation(data.response.object.explanation) + } + }, + ) + + const { + runAction: runGenerateAction, + isLoading: generateIsLoading, + done: generateIsDone, + error: generateError, + } = useStreamableAction( + generateDatasetAction, + async (event, data) => { + if ( + event === StreamEventTypes.Latitude && + data.type === ChainEventTypes.Complete + ) { + const parsedCsv = await syncReadCsv(data.response.object.csv, { + delimiter: ',', + }).then((r) => r.unwrap()) + setPreviewCsv(parsedCsv) + } + }, + ) + + const handleProjectChange = (projectId: number) => { + setProjectId(projectId) + } + + const handleDocumentChange = (newDocumentUuid: string) => { + setDocumentUuid(newDocumentUuid) + } + + const handleSubmit = async (e: FormEvent) => { + e.preventDefault() + const formData = new FormData(e.target as HTMLFormElement) + + if (!previewCsv) { + await runPreviewAction({ + parameters: formData.get('parameters') as string, + description: formData.get('description') as string, + }) + } else { + const response = await runGenerateAction({ + parameters: formData.get('parameters') as string, + description: formData.get('description') as string, + rowCount: parseInt(formData.get('rows') as string, 10), + name: formData.get('name') as string, + }) + + try { + const dataset = await response + if (!dataset) return + + mutate([...datasets, dataset]) + navigate.push(ROUTES.datasets.root) + } catch (error) { + console.error(error) + } + } + } + + const handleRegeneratePreview = async () => { + const form = window.document.getElementById( + 'generateDatasetForm', + ) as HTMLFormElement + const formData = new FormData(form) + + await runPreviewAction({ + parameters: formData.get('parameters') as string, + description: formData.get('description') as string, + }) + } + + useEffect(() => { + const fetchMetadata = async () => { + if (document) { + const metadata = await readMetadata({ + prompt: document.resolvedContent ?? '', + fullPath: document.path, + }) + + setMetadata(metadata) + } + } + + fetchMetadata() + }, [document]) + + return ( + !open && navigate.push(ROUTES.datasets.root)} + title='Generate new dataset' + description='Generate a dataset of parameters using AI from one of your prompts. Datasets can be used to run batch evaluations over prompts.' + footer={ + <> + + {previewCsv && ( + + )} + + + } + > +
+
+ + + + + + +