Skip to content

Commit

Permalink
Add validator to chat completion with json response
Browse files Browse the repository at this point in the history
  • Loading branch information
jakubknejzlik committed Aug 1, 2024
1 parent c9732b0 commit f74be96
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/chains/prompt-with-retry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export type PromptWithRetryOpts<T extends ShapeType> = {
maxRetries?: number
validator?: (obj: InferedType<T>) => Promise<boolean | void>
} & ThreadPromptWithJsonResponse<T>

export const promptWithRetry = async <T extends ShapeType>({
thread,
validator,
Expand Down
17 changes: 14 additions & 3 deletions src/completions/completion-with-json.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import {
CompletionOptsWithFunctionOpts,
completionWithFunctions
} from './completion-with-functions'
import { InferedType } from '../chains/prompt-with-retry'

export type CompletionOptsWithJsonResponse<T extends z.ZodRawShape> =
CompletionOptsWithFunctionOpts & {
responseObject: z.ZodObject<T>
validator?: (obj: z.infer<z.ZodObject<T>>) => Promise<boolean | void>
}

export const functionToOpenAIChatCompletionTool = <T extends z.ZodRawShape>(
Expand All @@ -27,9 +29,10 @@ export const functionToOpenAIChatCompletionTool = <T extends z.ZodRawShape>(
}
}

export const completionWithJsonResponse = async <T extends z.ZodRawShape>(
opts: CompletionOptsWithJsonResponse<T>
): Promise<z.infer<z.ZodObject<T>>> => {
export const completionWithJsonResponse = async <T extends z.ZodRawShape>({
validator,
...opts
}: CompletionOptsWithJsonResponse<T>): Promise<z.infer<z.ZodObject<T>>> => {
const { responseObject, prompt, ...rest } = opts
const responseObjectSchema = JSON.stringify(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand All @@ -54,6 +57,14 @@ export const completionWithJsonResponse = async <T extends z.ZodRawShape>(
parsedContent = parsedContent.properties
}
const parsed = responseObject.parse(parsedContent)

if (validator) {
const isValid = await validator(parsed)
if (isValid === false) {
throw new Error('Validation of the response failed. Please try again.')
}
}

return parsed
} catch (err) {
throw new Error(`Failed to parse response: ${err}, json: '${res.content}'`)
Expand Down

0 comments on commit f74be96

Please sign in to comment.