From 14bbf5900b08256d6bb81cad57450c376e8b35ce Mon Sep 17 00:00:00 2001
From: Gerard
Date: Mon, 11 Nov 2024 13:07:17 +0100
Subject: [PATCH] feat(eval): generate simple evaluations (#582)

Changes the copilot evaluation generator to generate simple evaluations.
---
 apps/web/src/actions/evaluations/create.ts         |  56 +------
 .../actions/evaluations/createFromPrompt.ts        |  33 ++--
 .../generateSuggestedEvaluations.ts                |   1 +
 .../CreateEvaluationModal/index.tsx                |  50 +++---
 .../evaluations/dashboard/generate/page.tsx        | 125 ++++++++++++---
 apps/web/src/stores/suggestedEvaluations.ts        |  20 ++-
 .../src/services/evaluations/create.test.ts        |  60 +++++++-
 packages/core/src/services/evaluations/create.ts   | 144 +++++++++++++----
 .../services/evaluations/prompt/index.test.ts      |   3 +-
 .../src/ds/atoms/TypewriterText/index.tsx          |   5 +-
 10 files changed, 338 insertions(+), 159 deletions(-)

diff --git a/apps/web/src/actions/evaluations/create.ts b/apps/web/src/actions/evaluations/create.ts
index 1cd29ba45..01944a025 100644
--- a/apps/web/src/actions/evaluations/create.ts
+++ b/apps/web/src/actions/evaluations/create.ts
@@ -1,16 +1,10 @@
 'use server'
 
 import {
-  EvaluationMetadataLlmAsJudgeAdvanced,
-  EvaluationMetadataLlmAsJudgeSimple,
   EvaluationMetadataType,
-  findFirstModelForProvider,
   resultConfigurationSchema,
-  Workspace,
 } from '@latitude-data/core/browser'
-import { NotFoundError } from '@latitude-data/core/lib/errors'
 import { createEvaluation } from '@latitude-data/core/services/evaluations/create'
-import { findDefaultProvider } from '@latitude-data/core/services/providerApiKeys/findDefaultProvider'
 import { z } from 'zod'
 
 import { authProcedure } from '../procedures'
@@ -42,64 +36,16 @@ export const createEvaluationAction = authProcedure
     { type: 'json' },
   )
   .handler(async ({ input, ctx }) => {
-    const metadata = await enrichWithProvider({
-      metadata: input.metadata,
-      workspace: ctx.workspace,
-    })
-
     const result = await createEvaluation({
       workspace: ctx.workspace,
       user: ctx.user,
       name: input.name,
       description: input.description,
       metadataType: input.metadata.type,
-      metadata,
+      metadata: input.metadata,
       resultType: input.resultConfiguration.type,
       resultConfiguration: input.resultConfiguration,
     })
 
     return result.unwrap()
   })
-
-async function enrichWithProvider({
-  metadata,
-  workspace,
-}: {
-  metadata: z.infer<
-    | typeof advancedEvaluationMetadataSchema
-    | typeof simpleEvaluationMetadataSchema
-  >
-  workspace: Workspace
-}): Promise<
-  EvaluationMetadataLlmAsJudgeSimple | EvaluationMetadataLlmAsJudgeAdvanced
-> {
-  const { type: _, ...rest } = metadata
-
-  if (metadata.type === EvaluationMetadataType.LlmAsJudgeAdvanced)
-    return rest as EvaluationMetadataLlmAsJudgeAdvanced
-
-  if (
-    metadata.type === EvaluationMetadataType.LlmAsJudgeSimple &&
-    metadata.providerApiKeyId &&
-    metadata.model
-  ) {
-    return rest as EvaluationMetadataLlmAsJudgeSimple
-  }
-
-  const provider = await findDefaultProvider(workspace)
-  if (!provider)
-    throw new NotFoundError(
-      `No default provider found for workspace ${workspace.id}`,
-    )
-
-  const model = findFirstModelForProvider(provider.provider)
-  if (!model)
-    throw new NotFoundError(
-      `No default model found for provider ${provider.provider}`,
-    )
-
-  return {
-    ...rest,
-    providerApiKeyId: provider.id,
-    model,
-  } as EvaluationMetadataLlmAsJudgeSimple
-}
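For context, the action above now forwards input.metadata as-is and leaves provider resolution to the service layer. A minimal sketch of the resulting call shape: the field names come from the schemas in this diff, while the concrete values are illustrative assumptions, not part of the patch.

  // Hypothetical payload; values are made up, field names come from this diff.
  await createEvaluationAction({
    name: 'Factuality',
    description: 'Flags answers unsupported by the source material',
    metadata: {
      type: EvaluationMetadataType.LlmAsJudgeSimple,
      objective: 'Check that the response is factually supported',
      // providerApiKeyId and model can be omitted: createEvaluation now
      // falls back to the workspace default provider and its first model.
    },
    resultConfiguration: { type: EvaluationResultableType.Boolean },
  })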
diff --git a/apps/web/src/actions/evaluations/createFromPrompt.ts b/apps/web/src/actions/evaluations/createFromPrompt.ts
index 6fa04b0b0..b516a41c8 100644
--- a/apps/web/src/actions/evaluations/createFromPrompt.ts
+++ b/apps/web/src/actions/evaluations/createFromPrompt.ts
@@ -1,7 +1,10 @@
 'use server'
 
-import { EvaluationResultableType } from '@latitude-data/core/browser'
-import { createAdvancedEvaluation } from '@latitude-data/core/services/evaluations/create'
+import {
+  EvaluationMetadataType,
+  EvaluationResultableType,
+} from '@latitude-data/core/browser'
+import { createEvaluation } from '@latitude-data/core/services/evaluations/create'
 import { z } from 'zod'
 
 import { withDocument } from '../procedures'
@@ -11,21 +14,26 @@ export const createEvaluationFromPromptAction = withDocument
   .input(
     z.object({
       name: z.string(),
-      prompt: z.string(),
+      objective: z.string(),
+      additionalInstructions: z.string().optional(),
       resultType: z.nativeEnum(EvaluationResultableType),
-      metadata: z
-        .object({
+      metadata: z.union([
+        z.object({
           minValue: z.number(),
           maxValue: z.number(),
           minValueDescription: z.string().optional(),
           maxValueDescription: z.string().optional(),
-        })
-        .optional(),
+        }),
+        z.object({
+          falseValueDescription: z.string().optional(),
+          trueValueDescription: z.string().optional(),
+        }),
+      ]),
     }),
     { type: 'json' },
   )
   .handler(async ({ input, ctx }) => {
-    const result = await createAdvancedEvaluation({
+    const result = await createEvaluation({
       workspace: ctx.workspace,
       name: input.name,
       description: 'AI-generated evaluation',
@@ -33,13 +41,12 @@ export const createEvaluationFromPromptAction = withDocument
       resultType:
         input.resultType === EvaluationResultableType.Number
           ? EvaluationResultableType.Number
           : EvaluationResultableType.Boolean,
-      resultConfiguration:
-        input.resultType === EvaluationResultableType.Number && input.metadata
-          ? input.metadata
-          : {},
+      resultConfiguration: input.metadata,
       metadata: {
-        prompt: input.prompt,
+        objective: input.objective,
+        additionalInstructions: input.additionalInstructions ?? null,
       },
+      metadataType: EvaluationMetadataType.LlmAsJudgeSimple,
       user: ctx.user,
       projectId: ctx.project.id,
       documentUuid: ctx.document.documentUuid,

diff --git a/apps/web/src/actions/evaluations/generateSuggestedEvaluations.ts b/apps/web/src/actions/evaluations/generateSuggestedEvaluations.ts
index 5ff41928b..f69d34fcf 100644
--- a/apps/web/src/actions/evaluations/generateSuggestedEvaluations.ts
+++ b/apps/web/src/actions/evaluations/generateSuggestedEvaluations.ts
@@ -51,6 +51,7 @@ export const generateSuggestedEvaluationsAction = authProcedure
       env.COPILOT_EVALUATION_SUGGESTION_PROMPT_PATH,
       {
         stream: false,
+        versionUuid: 'da47b89d-2bde-4c6c-92ee-11a17241eb73', // TODO: remove
         parameters: {
           user_prompt: input.documentContent,
         },

diff --git a/apps/web/src/app/(private)/evaluations/_components/CreateEvaluationModal/index.tsx b/apps/web/src/app/(private)/evaluations/_components/CreateEvaluationModal/index.tsx
index 032de5bed..122805080 100644
--- a/apps/web/src/app/(private)/evaluations/_components/CreateEvaluationModal/index.tsx
+++ b/apps/web/src/app/(private)/evaluations/_components/CreateEvaluationModal/index.tsx
@@ -65,25 +65,37 @@ export default function CreateEvaluationModal({
   })
 
   const onConfirm = useCallback(() => {
-    create({
-      name: title,
-      description,
-      metadata: {
-        type: EvaluationMetadataType.LlmAsJudgeSimple,
-        objective: '',
-        additionalInstructions: '',
-      },
-      resultConfiguration:
-        configuration.type === EvaluationResultableType.Number
-          ? {
-              type: configuration.type,
-              minValue: configuration.detail!.range.from,
-              maxValue: configuration.detail!.range.to,
-            }
-          : {
-              type: configuration.type,
-            },
-    })
+    const resultConfiguration =
+      configuration.type === EvaluationResultableType.Number
+        ? {
+            type: configuration.type,
+            minValue: configuration.detail!.range.from,
+            maxValue: configuration.detail!.range.to,
+          }
+        : { type: configuration.type }
+
+    if (prompt) {
+      create({
+        name: title,
+        description,
+        metadata: {
+          type: EvaluationMetadataType.LlmAsJudgeAdvanced,
+          prompt,
+        },
+        resultConfiguration,
+      })
+    } else {
+      create({
+        name: title,
+        description,
+        metadata: {
+          type: EvaluationMetadataType.LlmAsJudgeSimple,
+          objective: '',
+          additionalInstructions: '',
+        },
+        resultConfiguration,
+      })
+    }
 
     onClose(null)
   }, [create, onClose, title, description, prompt, configuration])

diff --git a/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/evaluations/dashboard/generate/page.tsx b/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/evaluations/dashboard/generate/page.tsx
index af01dbefd..950e91180 100644
--- a/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/evaluations/dashboard/generate/page.tsx
+++ b/apps/web/src/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/evaluations/dashboard/generate/page.tsx
@@ -43,12 +43,14 @@ export default function GenerateEvaluationPage() {
   const { execute: createEvaluation } = useLatitudeAction(
     createEvaluationFromPromptAction,
   )
-  const [generatedSuggestion, setGeneratedSuggestion] = useState(null)
+  const [generatedSuggestion, setGeneratedSuggestion] =
+    useState(null)
 
   const validateSuggestion = (suggestion: SuggestedEvaluation) => {
     if (
       !suggestion.eval_name ||
       !suggestion.eval_description ||
-      !suggestion.eval_prompt
+      !suggestion.eval_objective ||
+      !suggestion.metadata
     ) {
       return false
     }
@@ -100,19 +102,14 @@ export default function GenerateEvaluationPage() {
       projectId: project.id,
       documentUuid: document.documentUuid,
       commitUuid: commit.uuid,
-      prompt: generatedSuggestion.eval_prompt,
+      objective: generatedSuggestion.eval_objective,
+      additionalInstructions: generatedSuggestion.eval_additional_instructions,
       name: generatedSuggestion.eval_name,
       resultType:
         generatedSuggestion.eval_type === 'number'
           ? EvaluationResultableType.Number
           : EvaluationResultableType.Boolean,
-      metadata:
-        generatedSuggestion.eval_type === 'number'
-          ? {
-              minValue: generatedSuggestion.metadata.range.from as number,
-              maxValue: generatedSuggestion.metadata.range.to as number,
-            }
-          : undefined,
+      metadata: generatedSuggestion.metadata,
     })
 
     if (newEvaluation) {
@@ -146,6 +143,69 @@ export default function GenerateEvaluationPage() {
     }
   }
 
+  const renderMetadata = (suggestion: SuggestedEvaluation) => {
+    if (suggestion.eval_type === 'number') {
+      const metadata = suggestion.metadata as {
+        minValue: number
+        maxValue: number
+        minValueDescription?: string
+        maxValueDescription?: string
+      }
+
+      return (
+
+
+          Min Value {metadata.minValue}
+          {metadata.minValueDescription && (
+
+            {metadata.minValueDescription}
+
+          )}
+
+
+          Max Value {metadata.maxValue}
+          {metadata.maxValueDescription && (
+
+            {metadata.maxValueDescription}
+
+          )}
+
+
+      )
+    }
+
+    // Boolean type
+    const metadata = suggestion.metadata as {
+      falseValueDescription?: string
+      trueValueDescription?: string
+    }
+
+    return (
+
+
+          True Value True
+          {metadata.trueValueDescription && (
+
+            {metadata.trueValueDescription}
+
+          )}
+
+
+          False Value False
+          {metadata.falseValueDescription && (
+
+            {metadata.falseValueDescription}
+
+          )}
+
+
+    )
+  }
+
   return (
-            Evaluation Name
+            Name
 
-
-            Evaluation Description
-
+            Description
 
-            Evaluation Prompt
-
+            Objective
+
+
+        {generatedSuggestion.eval_additional_instructions && (
+
+            Additional Instructions
+
+
+
+        )}
+
+          Result Type
+
+
+
+
+          Expected Values
+
+          {renderMetadata(generatedSuggestion)}
 

diff --git a/apps/web/src/stores/suggestedEvaluations.ts b/apps/web/src/stores/suggestedEvaluations.ts
index e81922f39..a844e21b7 100644
--- a/apps/web/src/stores/suggestedEvaluations.ts
+++ b/apps/web/src/stores/suggestedEvaluations.ts
@@ -7,13 +7,19 @@ export interface SuggestedEvaluation {
   eval_name: string
   eval_description: string
   eval_type: 'number' | 'boolean'
-  eval_prompt: string
-  metadata?: {
-    range: {
-      from: number
-      to: number
-    }
-  }
+  eval_objective: string
+  eval_additional_instructions?: string
+  metadata:
+    | {
+        minValue: number
+        maxValue: number
+        minValueDescription?: string
+        maxValueDescription?: string
+      }
+    | {
+        falseValueDescription?: string
+        trueValueDescription?: string
+      }
 }
 
 export default function useSuggestedEvaluation(

diff --git a/packages/core/src/services/evaluations/create.test.ts b/packages/core/src/services/evaluations/create.test.ts
index 5934921e1..1370e9cfb 100644
--- a/packages/core/src/services/evaluations/create.test.ts
+++ b/packages/core/src/services/evaluations/create.test.ts
@@ -173,13 +173,11 @@ describe('createAdvancedEvaluation', () => {
       description,
       metadata: {
         ...evaluation.metadata,
-        prompt: `
----
+        prompt: `---
 provider: ${provider!.name}
 model: gpt-4o-mini
 ---
-${metadata.prompt}
-`.trim(),
+${metadata.prompt}`.trim(),
       },
       workspaceId: workspace.id,
     })
@@ -326,7 +324,7 @@ describe('createEvaluation', () => {
       workspace,
       user,
       name: 'Test Provider',
-      type: Providers.Groq,
+      type: Providers.OpenAI,
     })
 
     repo = new EvaluationsRepository(workspace.id)
@@ -369,6 +367,40 @@ describe('createEvaluation', () => {
     })
   })
 
+  it('adds the provider api key id and model to the metadata if they are not provided', async () => {
+    const evaluationResult = await createEvaluation({
+      workspace,
+      user,
+      name: 'Test Evaluation',
+      description: 'Test Description',
+      metadataType: EvaluationMetadataType.LlmAsJudgeSimple,
+      metadata: {
+        objective: 'Test Objective',
+        additionalInstructions: 'Test Instructions',
+      },
+      resultType: EvaluationResultableType.Text,
+      resultConfiguration: {
+        valueDescription: 'Test Value Description',
+      },
+    })
+
+    expect(evaluationResult.ok).toBe(true)
+    const evaluation = await repo
+      .find(evaluationResult.value!.id)
+      .then((r) => r.unwrap())
+
+    expect(evaluation.metadata).toMatchObject({
+      providerApiKeyId: provider.id,
+      model: 'gpt-4o-mini',
+      objective: 'Test Objective',
+      additionalInstructions: 'Test Instructions',
+    })
+
+    expect(evaluation.resultConfiguration).toMatchObject({
+      valueDescription: 'Test Value Description',
+    })
+  })
+
   it('connects the evaluation to the document', async () => {
     const prompt = factories.helpers.createPrompt({
       provider: 'Latitude',
@@ -512,7 +544,11 @@ describe('createEvaluation', () => {
       .then((r) => r.unwrap())
 
     expect(evaluation.metadata).toMatchObject({
-      prompt: 'Test Prompt',
+      prompt: `---
+provider: Test Provider
+model: gpt-4o-mini
+---
+Test Prompt`.trim(),
       templateId: null,
     })
@@ -547,7 +583,11 @@ describe('createEvaluation', () => {
       .then((r) => r.unwrap())
 
     expect(evaluation.metadata).toMatchObject({
-      prompt: 'Test Prompt',
+      prompt: `---
+provider: Test Provider
+model: gpt-4o-mini
+---
+Test Prompt`.trim(),
       templateId: null,
     })
@@ -583,7 +623,11 @@ describe('createEvaluation', () => {
       .then((r) => r.unwrap())
 
     expect(evaluation.metadata).toMatchObject({
-      prompt: 'Test Prompt',
+      prompt: `---
+provider: Test Provider
+model: gpt-4o-mini
+---
+Test Prompt`.trim(),
       templateId: null,
     })
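The new test above pins down the contract that the reworked service below implements: a simple evaluation created without provider fields gets providerApiKeyId and model filled in from the workspace's default provider. A sketch of a call that relies on that defaulting, with names taken from this diff and illustrative values:

  // Hypothetical call; no provider fields are passed in the metadata.
  const evaluation = await createEvaluation({
    workspace,
    user,
    name: 'Helpfulness',
    description: 'Grades how helpful the answer is',
    metadataType: EvaluationMetadataType.LlmAsJudgeSimple,
    metadata: { objective: 'Rate how helpful the answer is' },
    resultType: EvaluationResultableType.Number,
    resultConfiguration: { minValue: 1, maxValue: 5 },
  }).then((r) => r.unwrap())
  // evaluation.metadata now also carries providerApiKeyId and model,
  // e.g. 'gpt-4o-mini' when the default provider is OpenAI.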
diff --git a/packages/core/src/services/evaluations/create.ts b/packages/core/src/services/evaluations/create.ts
index 3daf89850..fe6182a47 100644
--- a/packages/core/src/services/evaluations/create.ts
+++ b/packages/core/src/services/evaluations/create.ts
@@ -13,7 +13,7 @@ import {
   User,
   Workspace,
 } from '../../browser'
-import { database } from '../../client'
+import { Database, database } from '../../client'
 import { findEvaluationTemplateById } from '../../data-access'
 import { publisher } from '../../events/publisher'
 import {
@@ -49,6 +49,29 @@ type EvaluationResultConfigurationBoolean = Partial<
   Omit
 >
 
+type CreateEvaluationMetadata<M extends EvaluationMetadataType> =
+  M extends EvaluationMetadataType.LlmAsJudgeSimple
+    ? { objective: string } & Partial<
+        Omit<EvaluationMetadataLlmAsJudgeSimple, 'id' | 'objective'>
+      >
+    : M extends EvaluationMetadataType.LlmAsJudgeAdvanced
+      ? { prompt: string } & Partial<
+          Omit<
+            EvaluationMetadataLlmAsJudgeAdvanced,
+            'id' | 'configuration' | 'prompt'
+          >
+        >
+      : never
+
+type CreateEvaluationResultConfiguration<R extends EvaluationResultableType> =
+  R extends EvaluationResultableType.Boolean
+    ? EvaluationResultConfigurationBoolean
+    : R extends EvaluationResultableType.Number
+      ? EvaluationResultConfigurationNumerical
+      : R extends EvaluationResultableType.Text
+        ? EvaluationResultConfigurationText
+        : never
+
 export async function createEvaluation<
   M extends EvaluationMetadataType,
   R extends EvaluationResultableType,
@@ -70,24 +93,9 @@ export async function createEvaluation<
     name: string
     description: string
     metadataType: M
-    metadata: M extends EvaluationMetadataType.LlmAsJudgeSimple
-      ? Omit<EvaluationMetadataLlmAsJudgeSimple, 'id'>
-      : M extends EvaluationMetadataType.LlmAsJudgeAdvanced
-        ? { prompt: string } & Partial<
-            Omit<
-              EvaluationMetadataLlmAsJudgeAdvanced,
-              'id' | 'configuration' | 'prompt'
-            >
-          >
-        : never
+    metadata: CreateEvaluationMetadata<M>
     resultType: R
-    resultConfiguration: R extends EvaluationResultableType.Boolean
-      ? EvaluationResultConfigurationBoolean
-      : R extends EvaluationResultableType.Number
-        ? EvaluationResultConfigurationNumerical
-        : R extends EvaluationResultableType.Text
-          ? EvaluationResultConfigurationText
-          : never
+    resultConfiguration: CreateEvaluationResultConfiguration<R>
     projectId?: number
     documentUuid?: string
   },
@@ -125,9 +133,18 @@ export async function createEvaluation<
   }
 
   return await Transaction.call(async (tx) => {
+    const enrichedMetadata = await enrichWithProvider(
+      {
+        metadata,
+        metadataType,
+        workspace,
+      },
+      tx,
+    )
+
     const metadataRow = (await tx
       .insert(metadataTables[metadataType])
-      .values([metadata])
+      .values([enrichedMetadata])
      .returning()
       .then((r) => r[0]!)) as IEvaluationMetadata
@@ -251,31 +268,13 @@ export async function createAdvancedEvaluation<
   },
   db = database,
 ): PromisedResult {
-  const provider = await findDefaultProvider(workspace, db)
-  if (!provider) {
-    return Result.error(
-      new NotFoundError(
-        'In order to create an evaluation you need to first create a provider API key from OpenAI or Anthropic',
-      ),
-    )
-  }
-
-  const promptWithProvider = provider
-    ? `---
-provider: ${provider.name}
-model: ${findFirstModelForProvider(provider.provider)}
----
-${metadata.prompt}
-`.trim()
-    : metadata.prompt
-
   return createEvaluation(
     {
       workspace,
       ...props,
       metadataType: EvaluationMetadataType.LlmAsJudgeAdvanced,
       metadata: {
-        prompt: promptWithProvider,
+        prompt: metadata.prompt,
         configuration: resultConfiguration,
         templateId: metadata.templateId ?? null,
      } as Omit<EvaluationMetadataLlmAsJudgeAdvanced, 'id'>,
@@ -311,3 +310,72 @@
 
   return Result.ok(resultConfiguration)
 }
+
+async function enrichWithProvider<M extends EvaluationMetadataType>(
+  {
+    metadata,
+    metadataType,
+    workspace,
+  }: {
+    metadata: CreateEvaluationMetadata<M>
+    metadataType: M
+    workspace: Workspace
+  },
+  db: Database,
+) {
+  if (
+    metadataType === EvaluationMetadataType.LlmAsJudgeSimple &&
+    // @ts-expect-error - Metadata is a union type and providerApiKeyId is not defined for the other types
+    !metadata.providerApiKeyId
+  ) {
+    const provider = await findDefaultProvider(workspace, db)
+    if (!provider) {
+      throw new NotFoundError(
+        `In order to create an evaluation you need to first create a provider API key from OpenAI or Anthropic`,
+      )
+    }
+
+    const model =
+      // @ts-expect-error - Metadata is a union type and model is not defined for the other types
+      metadata.model || findFirstModelForProvider(provider.provider)
+    if (!model)
+      throw new NotFoundError(
+        `In order to create an evaluation you need to first create a provider API key from OpenAI or Anthropic`,
+      )
+
+    metadata = {
+      ...metadata,
+      model,
+      providerApiKeyId: provider.id,
+    }
+  }
+
+  if (metadataType === EvaluationMetadataType.LlmAsJudgeAdvanced) {
+    const provider = await findDefaultProvider(workspace, db)
+    if (!provider) {
+      throw new NotFoundError(
+        `In order to create an evaluation you need to first create a provider API key from OpenAI or Anthropic`,
+      )
+    }
+
+    const promptWithProvider = provider
+      ? `---
+provider: ${provider.name}
+model: ${findFirstModelForProvider(provider.provider)}
+---
+${
+  // @ts-expect-error - Metadata is a union type and prompt is not defined for the other types
+  metadata.prompt
+}`.trim()
+      : // @ts-expect-error - Metadata is a union type and prompt is not defined for the other types
+        metadata.prompt
+
+    metadata = {
+      ...metadata,
+      prompt: promptWithProvider,
+    }
+  }
+
+  return metadata as M extends EvaluationMetadataType.LlmAsJudgeSimple
+    ? Omit<EvaluationMetadataLlmAsJudgeSimple, 'id'>
+    : Omit<EvaluationMetadataLlmAsJudgeAdvanced, 'id'>
+}

diff --git a/packages/core/src/services/evaluations/prompt/index.test.ts b/packages/core/src/services/evaluations/prompt/index.test.ts
index 98500de8b..45a0e8cfa 100644
--- a/packages/core/src/services/evaluations/prompt/index.test.ts
+++ b/packages/core/src/services/evaluations/prompt/index.test.ts
@@ -57,7 +57,8 @@ describe('getEvaluationPrompt', () => {
       evaluation,
     }).then((r) => r.unwrap())
 
-    expect(obtainedPrompt).toBe(prompt)
+    // @ts-expect-error - Metadata is a union type and prompt is not defined for the other types
+    expect(obtainedPrompt).toBe(evaluation.metadata.prompt)
   })
 
   it('Creates a compilable prompt for a simple evaluation', async () => {

diff --git a/packages/web-ui/src/ds/atoms/TypewriterText/index.tsx b/packages/web-ui/src/ds/atoms/TypewriterText/index.tsx
index a222734ec..9a74ad497 100644
--- a/packages/web-ui/src/ds/atoms/TypewriterText/index.tsx
+++ b/packages/web-ui/src/ds/atoms/TypewriterText/index.tsx
@@ -39,7 +39,10 @@ export const TypewriterText: React.FC = ({
   return (
     {displayedText}
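As the updated tests assert, the frontmatter prepend for advanced evaluations now happens in enrichWithProvider at insert time rather than in createAdvancedEvaluation. Given a default provider named 'Test Provider' whose first model is gpt-4o-mini (the fixture used in the tests above), a raw prompt of 'Test Prompt' would be stored as:

  ---
  provider: Test Provider
  model: gpt-4o-mini
  ---
  Test Prompt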