From 7ddcc9a1a0326fcf36683c321e3c9e0896ea5fbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Sans=C3=B3n?= Date: Thu, 25 Jul 2024 16:52:18 +0200 Subject: [PATCH] Compile iterator --- .../src/compiler/base/nodes/config.ts | 9 +- .../compiler/src/compiler/base/nodes/each.ts | 29 +- .../compiler/src/compiler/base/nodes/tag.ts | 13 + .../src/compiler/base/nodes/tags/chainStep.ts | 53 +++ packages/compiler/src/compiler/base/types.ts | 23 +- packages/compiler/src/compiler/chain.test.ts | 425 ++++++++++++++++++ packages/compiler/src/compiler/chain.ts | 81 ++++ packages/compiler/src/compiler/compile.ts | 178 ++++++-- packages/compiler/src/compiler/index.ts | 12 +- packages/compiler/src/compiler/scope.ts | 30 +- packages/compiler/src/compiler/types.ts | 1 + packages/compiler/src/compiler/utils.ts | 13 + packages/compiler/src/constants.ts | 2 + packages/compiler/src/error/errors.ts | 4 + packages/compiler/src/parser/interfaces.ts | 2 + 15 files changed, 822 insertions(+), 53 deletions(-) create mode 100644 packages/compiler/src/compiler/base/nodes/tags/chainStep.ts create mode 100644 packages/compiler/src/compiler/chain.test.ts create mode 100644 packages/compiler/src/compiler/chain.ts diff --git a/packages/compiler/src/compiler/base/nodes/config.ts b/packages/compiler/src/compiler/base/nodes/config.ts index f5a8e877d..977d8af70 100644 --- a/packages/compiler/src/compiler/base/nodes/config.ts +++ b/packages/compiler/src/compiler/base/nodes/config.ts @@ -1,7 +1,10 @@ -import { Config } from '$compiler/parser/interfaces' +import { Config as ConfigNode } from '$compiler/parser/interfaces' import { CompileNodeContext } from '../types' -export async function compile(_: CompileNodeContext) { - /* do nothing */ +export async function compile({ + node, + setConfig, +}: CompileNodeContext): Promise { + setConfig(node.value) } diff --git a/packages/compiler/src/compiler/base/nodes/each.ts b/packages/compiler/src/compiler/base/nodes/each.ts index 08d92697c..92bed01af 100644 --- a/packages/compiler/src/compiler/base/nodes/each.ts +++ b/packages/compiler/src/compiler/base/nodes/each.ts @@ -2,7 +2,13 @@ import { hasContent, isIterable } from '$compiler/compiler/utils' import errors from '$compiler/error/errors' import { EachBlock } from '$compiler/parser/interfaces' -import { CompileNodeContext } from '../types' +import { CompileNodeContext, TemplateNodeWithStatus } from '../types' + +type EachNodeWithStatus = TemplateNodeWithStatus & { + status: TemplateNodeWithStatus['status'] & { + loopIterationIndex: number + } +} export async function compile({ node, @@ -13,6 +19,12 @@ export async function compile({ resolveExpression, expressionError, }: CompileNodeContext) { + const nodeWithStatus = node as EachNodeWithStatus + nodeWithStatus.status = { + ...nodeWithStatus.status, + scopePointers: scope.getPointers(), + } + const iterableElement = await resolveExpression(node.expression, scope) if (!isIterable(iterableElement) || !(await hasContent(iterableElement))) { const childScope = scope.copy() @@ -44,7 +56,14 @@ export async function compile({ } let i = 0 + for await (const element of iterableElement) { + if (i < (nodeWithStatus.status.loopIterationIndex ?? 0)) { + i++ + continue + } + nodeWithStatus.status.loopIterationIndex = i + const localScope = scope.copy() localScope.set(contextVarName, element) if (indexVarName) { @@ -54,14 +73,22 @@ export async function compile({ } localScope.set(indexVarName, indexValue) } + for await (const childNode of node.children ?? []) { await resolveBaseNode({ node: childNode, scope: localScope, isInsideMessageTag, isInsideContentTag, + completedValue: `step_${i}`, }) } + i++ } + + nodeWithStatus.status = { + ...nodeWithStatus.status, + loopIterationIndex: 0, + } } diff --git a/packages/compiler/src/compiler/base/nodes/tag.ts b/packages/compiler/src/compiler/base/nodes/tag.ts index e793acaa8..324f9e426 100644 --- a/packages/compiler/src/compiler/base/nodes/tag.ts +++ b/packages/compiler/src/compiler/base/nodes/tag.ts @@ -1,4 +1,5 @@ import { + isChainStepTag, isContentTag, isMessageTag, isRefTag, @@ -6,6 +7,7 @@ import { } from '$compiler/compiler/utils' import errors from '$compiler/error/errors' import { + ChainStepTag, ContentTag, ElementTag, MessageTag, @@ -14,6 +16,7 @@ import { } from '$compiler/parser/interfaces' import { CompileNodeContext } from '../types' +import { compile as resolveChainStep } from './tags/chainStep' import { compile as resolveContent } from './tags/content' import { compile as resolveMessage } from './tags/message' import { compile as resolveRef } from './tags/ref' @@ -91,8 +94,18 @@ export async function compile(props: CompileNodeContext) { return } + if (isChainStepTag(node)) { + await resolveChainStep( + props as CompileNodeContext, + attributes, + ) + return + } + + //@ts-ignore - Linter knows there *should* not be another type of tag. baseNodeError(errors.unknownTag(node.name), node) + //@ts-ignore - ditto for await (const childNode of node.children ?? []) { await resolveBaseNode({ node: childNode, diff --git a/packages/compiler/src/compiler/base/nodes/tags/chainStep.ts b/packages/compiler/src/compiler/base/nodes/tags/chainStep.ts new file mode 100644 index 000000000..07da5d457 --- /dev/null +++ b/packages/compiler/src/compiler/base/nodes/tags/chainStep.ts @@ -0,0 +1,53 @@ +import { tagAttributeIsLiteral } from '$compiler/compiler/utils' +import errors from '$compiler/error/errors' +import { ChainStepTag } from '$compiler/parser/interfaces' +import { Config, ContentType } from '$compiler/types' + +import { CompileNodeContext } from '../../types' + +function isValidConfig(value: unknown): value is Config | undefined { + if (value === undefined) return true + if (Array.isArray(value)) return false + return typeof value === 'object' +} + +export async function compile( + { + node, + scope, + popStepResponse, + addMessage, + groupContent, + baseNodeError, + stop, + }: CompileNodeContext, + attributes: Record, +) { + const stepResponse = popStepResponse() + + const { as: varName, ...config } = attributes + + if (stepResponse === undefined) { + if (!isValidConfig(config)) { + baseNodeError(errors.invalidStepConfig, node) + } + + stop(config as Config) + } + + if ('as' in attributes) { + if (!tagAttributeIsLiteral(node, 'as')) { + baseNodeError(errors.invalidStaticAttribute('as'), node) + } + + const responseText = stepResponse?.content + .filter((c) => c.type === ContentType.text) + .map((c) => c.value) + .join(' ') + + scope.set(String(varName), responseText) + } + + groupContent() + addMessage(stepResponse!) +} diff --git a/packages/compiler/src/compiler/base/types.ts b/packages/compiler/src/compiler/base/types.ts index 2450bc79d..9ee489803 100644 --- a/packages/compiler/src/compiler/base/types.ts +++ b/packages/compiler/src/compiler/base/types.ts @@ -1,6 +1,11 @@ -import Scope from '$compiler/compiler/scope' +import Scope, { ScopePointers } from '$compiler/compiler/scope' import { TemplateNode } from '$compiler/parser/interfaces' -import { Message, MessageContent } from '$compiler/types' +import { + AssistantMessage, + Config, + Message, + MessageContent, +} from '$compiler/types' import type { Node as LogicalExpression } from 'estree' import { ResolveBaseNodeProps, ToolCallReference } from '../types' @@ -27,6 +32,16 @@ type RaiseErrorFn = ( node: N, ) => T +type NodeStatus = { + completedAs?: unknown + scopePointers?: ScopePointers | undefined + eachIterationIndex?: number +} + +export type TemplateNodeWithStatus = TemplateNode & { + status?: NodeStatus +} + export type CompileNodeContext = { node: N scope: Scope @@ -41,6 +56,7 @@ export type CompileNodeContext = { isInsideMessageTag: boolean isInsideContentTag: boolean + setConfig: (config: Config) => void addMessage: (message: Message) => void addStrayText: (text: string) => void popStrayText: () => string @@ -50,4 +66,7 @@ export type CompileNodeContext = { groupContent: () => void addToolCall: (toolCallRef: ToolCallReference) => void popToolCalls: () => ToolCallReference[] + popStepResponse: () => AssistantMessage | undefined + + stop: (config?: Config) => void } diff --git a/packages/compiler/src/compiler/chain.test.ts b/packages/compiler/src/compiler/chain.test.ts new file mode 100644 index 000000000..4f6a8e421 --- /dev/null +++ b/packages/compiler/src/compiler/chain.test.ts @@ -0,0 +1,425 @@ +import { CHAIN_STEP_TAG } from '$compiler/constants' +import CompileError from '$compiler/error/error' +import { + AssistantMessage, + ContentType, + Conversation, + MessageRole, +} from '$compiler/types' +import { describe, expect, it, vi } from 'vitest' + +import { Chain } from './chain' +import { removeCommonIndent } from './utils' + +const getExpectedError = async ( + action: () => Promise, + errorClass: new () => T, +): Promise => { + try { + await action() + } catch (err) { + expect(err).toBeInstanceOf(errorClass) + return err as T + } + throw new Error('Expected an error to be thrown') +} + +const assistantMessage = (content?: string): AssistantMessage => ({ + role: MessageRole.assistant, + content: [ + { + type: ContentType.text, + value: content ?? '', + }, + ], + toolCalls: [], +}) + +async function defaultCallback(): Promise { + return assistantMessage('') +} + +async function complete({ + chain, + callback, + maxSteps = 50, +}: { + chain: Chain + callback?: (convo: Conversation) => Promise + maxSteps?: number +}): Promise { + let steps = 0 + let response: AssistantMessage | undefined = undefined + while (true) { + const { completed, conversation } = await chain.step(response) + if (completed) return conversation + response = await (callback ?? defaultCallback)(conversation) + + steps++ + if (steps > maxSteps) throw new Error('too many chain steps') + } +} + +describe('chain', async () => { + it('computes in a single iteration when there is no step tag', async () => { + const prompt = removeCommonIndent(` + {{ foo = 'foo' }} + System messate + + {{#each [1, 2, 3] as element}} + + User message: {{element}} + + {{/each}} + + + Assistant message: {{foo}} + + `) + + const chain = new Chain({ + prompt: removeCommonIndent(prompt), + parameters: {}, + }) + const { completed } = await chain.step() + expect(completed).toBe(true) + }) + + it('correctly computes the whole prompt in a single iteration', async () => { + const prompt = removeCommonIndent(` + {{foo = 'foo'}} + System message + + {{#each [1, 2, 3] as element}} + + User message: {{element}} + + {{/each}} + + + Assistant message: {{foo}} + + `) + + const chain = new Chain({ + prompt: removeCommonIndent(prompt), + parameters: {}, + }) + + const { conversation } = await chain.step() + expect(conversation.messages.length).toBe(5) + + const systemMessage = conversation.messages[0]! + expect(systemMessage.role).toBe('system') + expect(systemMessage.content[0]!.value).toBe('System message') + + const userMessage = conversation.messages[1]! + expect(userMessage.role).toBe('user') + expect(userMessage.content[0]!.value).toBe('User message: 1') + + const userMessage2 = conversation.messages[2]! + expect(userMessage2.role).toBe('user') + expect(userMessage2.content[0]!.value).toBe('User message: 2') + + const userMessage3 = conversation.messages[3]! + expect(userMessage3.role).toBe('user') + expect(userMessage3.content[0]!.value).toBe('User message: 3') + + const assistantMessage = conversation.messages[4]! + expect(assistantMessage.role).toBe('assistant') + expect(assistantMessage.content[0]!.value).toBe('Assistant message: foo') + }) + + it('stops at a step tag', async () => { + const prompt = removeCommonIndent(` + Message 1 + + <${CHAIN_STEP_TAG} /> + + Message 2 + `) + + const chain = new Chain({ + prompt, + parameters: {}, + }) + + const { completed: completed1, conversation: conversation1 } = + await chain.step() + + expect(completed1).toBe(false) + expect(conversation1.messages.length).toBe(1) + expect(conversation1.messages[0]!.content[0]!.value).toBe('Message 1') + + const { completed: completed2, conversation: conversation2 } = + await chain.step(assistantMessage('response')) + + expect(completed2).toBe(true) + expect(conversation2.messages.length).toBe(3) + expect(conversation2.messages[0]!.content[0]!.value).toBe('Message 1') + expect(conversation2.messages[1]!.content[0]!.value).toBe('response') + expect(conversation2.messages[2]!.content[0]!.value).toBe('Message 2') + }) + + it('does not reevaluate nodes', async () => { + const prompt = removeCommonIndent(` + {{func1()}} + + <${CHAIN_STEP_TAG} /> + + {{func2()}} + `) + + const func1 = vi.fn().mockReturnValue('1') + const func2 = vi.fn().mockReturnValue('2') + + const chain = new Chain({ + prompt, + parameters: { + func1, + func2, + }, + }) + + const conversation = await complete({ chain }) + expect(conversation.messages[0]!.content[0]!.value).toBe('1') + expect(conversation.messages[1]!.content[0]!.value).toBe('') + expect(conversation.messages[2]!.content[0]!.value).toBe('2') + expect(func1).toHaveBeenCalledTimes(1) + expect(func2).toHaveBeenCalledTimes(1) + }) + + it('maintains the scope on simple structures', async () => { + const prompt = removeCommonIndent(` + {{foo = 5}} + + <${CHAIN_STEP_TAG} /> + + {{foo += 1}} + + <${CHAIN_STEP_TAG} /> + + {{foo}} + `) + + const chain = new Chain({ + prompt, + parameters: {}, + }) + + const conversation = await complete({ chain }) + expect( + conversation.messages[conversation.messages.length - 1]!.content[0]! + .value, + ).toBe('6') + }) + + it('maintains the scope in if statements', async () => { + const correctPrompt = removeCommonIndent(` + {{foo = 5}} + + {{#if true}} + <${CHAIN_STEP_TAG} /> + {{foo += 1}} + {{/if}} + + {{foo}} + `) + + const incorrectPrompt = removeCommonIndent(` + {{foo = 5}} + + {{#if true}} + {{bar = 1}} + <${CHAIN_STEP_TAG} /> + {{/if}} + + {{bar}} + `) + + const correctChain = new Chain({ + prompt: correctPrompt, + parameters: {}, + }) + + const conversation = await complete({ chain: correctChain }) + expect( + conversation.messages[conversation.messages.length - 1]!.content[0]! + .value, + ).toBe('6') + + const incorrectChain = new Chain({ + prompt: incorrectPrompt, + parameters: {}, + }) + + const action = () => complete({ chain: incorrectChain }) + const error = await getExpectedError(action, CompileError) + expect(error.code).toBe('variable-not-declared') + }) + + it('maintains the scope in each blocks', async () => { + const prompt = removeCommonIndent(` + {{ foo = 0 }} + + {{#each [1, 2, 3] as element}} + + {{foo}} + + + <${CHAIN_STEP_TAG} /> + + {{foo = element}} + {{/each}} + + {{foo}} + `) + + const chain = new Chain({ + prompt, + parameters: {}, + }) + + const conversation = await complete({ chain, maxSteps: 5 }) + expect(conversation.messages.length).toBe(7) + expect(conversation.messages[0]!.content[0]!.value).toBe('0') + expect(conversation.messages[2]!.content[0]!.value).toBe('1') + expect(conversation.messages[4]!.content[0]!.value).toBe('2') + expect(conversation.messages[6]!.content[0]!.value).toBe('3') + }) + + it('cannot access variables created in a loop outside its scope', async () => { + const prompt = removeCommonIndent(` + {{#each [1, 2, 3] as i}} + {{foo = i}} + <${CHAIN_STEP_TAG} /> + {{/each}} + + {{foo}} + `) + + const chain = new Chain({ + prompt, + parameters: {}, + }) + + const action = () => complete({ chain }) + const error = await getExpectedError(action, CompileError) + expect(error.code).toBe('variable-not-declared') + }) + + it('maintains the scope in nested loops', async () => { + const prompt = removeCommonIndent(` + {{ foo = 0 }} + + {{#each [1, 2, 3] as i}} + + {{#each [1, 2, 3] as j}} + + {{i}}.{{j}} + + + <${CHAIN_STEP_TAG} /> + + {{foo = i * j}} + {{/each}} + + {{ foo }} + {{/each}} + `) + + const chain = new Chain({ + prompt, + parameters: {}, + }) + + const conversation = await complete({ chain }) + const userMessages = conversation.messages.filter( + (m) => m.role === MessageRole.user, + ) + const userMessageText = userMessages + .map((m) => m.content.map((c) => c.value).join(' ')) + .join('\n') + expect(userMessageText).toBe( + removeCommonIndent(` + 1.1 + 1.2 + 1.3 + 2.1 + 2.2 + 2.3 + 3.1 + 3.2 + 3.3 + `), + ) + expect( + conversation.messages[conversation.messages.length - 1]!.content[0]! + .value, + ).toBe('9') + }) + + it('saves the response in a variable', async () => { + const prompt = removeCommonIndent(` + <${CHAIN_STEP_TAG} as="response" /> + + {{response}} + `) + + const chain = new Chain({ + prompt, + parameters: {}, + }) + + const response = { + role: MessageRole.assistant, + content: [ + { + type: ContentType.text, + value: 'foo', + }, + ], + } as AssistantMessage + + await chain.step() + const { conversation } = await chain.step(response) + + expect(conversation.messages.length).toBe(2) + expect(conversation.messages[0]!.content[0]!.value).toBe('foo') + expect(conversation.messages[1]!.content[0]!.value).toBe('foo') + }) + + it('returns the correct configuration in all steps', async () => { + const prompt = removeCommonIndent(` + --- + model: foo-1 + temperature: 0.5 + --- + /* step1 */ + /* step2 */ + /* step3 */ + `) + + const chain = new Chain({ + prompt, + parameters: {}, + }) + + const { conversation: step1 } = await chain.step() + expect(step1.config.model).toBe('foo-1') + expect(step1.config.temperature).toBe(0.5) + + const { conversation: step2 } = await chain.step(assistantMessage()) + expect(step2.config.model).toBe('foo-2') + expect(step2.config.temperature).toBe(0.5) + + const { conversation: step3 } = await chain.step(assistantMessage()) + expect(step3.config.model).toBe('foo-1') + expect(step3.config.temperature).toBe('1') + + const { conversation: finalConversation } = + await chain.step(assistantMessage()) + expect(finalConversation.config.model).toBe('foo-1') + expect(finalConversation.config.temperature).toBe(0.5) + }) +}) diff --git a/packages/compiler/src/compiler/chain.ts b/packages/compiler/src/compiler/chain.ts new file mode 100644 index 000000000..5edc238ae --- /dev/null +++ b/packages/compiler/src/compiler/chain.ts @@ -0,0 +1,81 @@ +import parse from '$compiler/parser' +import { Fragment } from '$compiler/parser/interfaces' +import { + AssistantMessage, + Config, + Conversation, + Message, +} from '$compiler/types' + +import { Compile } from './compile' +import Scope from './scope' + +type ChainStep = { + conversation: Conversation + completed: boolean +} + +export class Chain { + private rawText: string + private ast: Fragment + private scope: Scope + private didStart: boolean = false + private completed: boolean = false + + private messages: Message[] = [] + private config: Config | undefined + + constructor({ + prompt, + parameters, + }: { + prompt: string + parameters: Record + }) { + this.rawText = prompt + this.ast = parse(prompt) + this.scope = new Scope(parameters) + } + + async step(response?: AssistantMessage): Promise { + if (!this.didStart && response !== undefined) { + throw new Error('A response is not allowed before the chain has started') + } + if (this.didStart && response === undefined) { + throw new Error('A response is required to continue a chain') + } + if (this.completed) { + throw new Error('The chain has already completed') + } + this.didStart = true + + const compile = new Compile({ + ast: this.ast, + rawText: this.rawText, + globalScope: this.scope, + stepResponse: response, + }) + + const { completed, scopeStash, ast, messages, globalConfig, stepConfig } = + await compile.run() + + this.scope = Scope.withStash(scopeStash).copy(this.scope.getPointers()) + this.ast = ast + this.messages.push(...messages) + this.config = globalConfig ?? this.config + this.completed = completed || this.completed + + const config = { + ...this.config, + ...stepConfig, + } + + return { + conversation: { + messages: this.messages, + config, + }, + completed: this.completed, + } + } +} diff --git a/packages/compiler/src/compiler/compile.ts b/packages/compiler/src/compiler/compile.ts index b8faf1fc8..939c61fac 100644 --- a/packages/compiler/src/compiler/compile.ts +++ b/packages/compiler/src/compiler/compile.ts @@ -1,10 +1,14 @@ import { error } from '$compiler/error/error' import errors from '$compiler/error/errors' -import parse from '$compiler/parser/index' -import type { BaseNode, TemplateNode } from '$compiler/parser/interfaces' +import type { + BaseNode, + Fragment, + TemplateNode, +} from '$compiler/parser/interfaces' import { + AssistantMessage, + Config, ContentType, - Conversation, Message, MessageContent, MessageRole, @@ -20,68 +24,106 @@ import { compile as resolveIfBlock } from './base/nodes/if' import { compile as resolveMustache } from './base/nodes/mustache' import { compile as resolveElementTag } from './base/nodes/tag' import { compile as resolveText } from './base/nodes/text' -import { CompileNodeContext } from './base/types' -import { readConfig } from './config' +import { CompileNodeContext, TemplateNodeWithStatus } from './base/types' import { resolveLogicNode } from './logic' -import Scope from './scope' +import Scope, { ScopeStash } from './scope' import type { ResolveBaseNodeProps, ToolCallReference } from './types' import { removeCommonIndent } from './utils' +export type CompilationStatus = { + completed: boolean + scopeStash: ScopeStash + ast: Fragment + messages: Message[] + stepConfig: Config | undefined + globalConfig: Config | undefined +} + +class StopIteration extends Error { + public readonly config: Config | undefined + constructor(config: Config | undefined) { + super('StopIteration') + this.config = config + } +} + export type ReferencePromptFn = (prompt: string) => Promise export class Compile { + private ast: Fragment private rawText: string - - private initialScope: Scope + private globalScope: Scope private messages: Message[] = [] + private config: Config | undefined + private stepResponse: AssistantMessage | undefined + private accumulatedText: string = '' private accumulatedContent: MessageContent[] = [] private accumulatedToolCalls: ToolCallReference[] = [] constructor({ - prompt, - parameters, + ast, + rawText, + globalScope, + stepResponse, }: { - prompt: string - parameters: Record + rawText: string + globalScope: Scope + ast: Fragment + stepResponse?: AssistantMessage }) { - this.rawText = prompt - this.initialScope = new Scope(parameters) - } - - async run(): Promise { - const fragment = parse(this.rawText) - const config = readConfig(fragment) as Record - await this.resolveBaseNode({ - node: fragment, - scope: this.initialScope, - isInsideMessageTag: false, - isInsideContentTag: false, - }) + this.rawText = rawText + this.globalScope = globalScope + this.ast = ast + this.stepResponse = stepResponse + } + + async run(): Promise { + let completed = true + let stepConfig: Config | undefined = undefined + + try { + await this.resolveBaseNode({ + node: this.ast, + scope: this.globalScope, + isInsideMessageTag: false, + isInsideContentTag: false, + }) + } catch (e) { + if (e instanceof StopIteration) { + completed = false + stepConfig = e.config + } else { + throw e + } + } + this.groupContent() return { - config, + ast: this.ast, + scopeStash: this.globalScope.getStash(), messages: this.messages, + globalConfig: this.config, + stepConfig, + completed, } } - private async resolveExpression( - expression: LogicalExpression, - scope: Scope, - ): Promise { - return await resolveLogicNode({ - node: expression, - scope, - raiseError: this.expressionError.bind(this), - }) + private stop(config: Config | undefined): void { + throw new StopIteration(config) } private addMessage(message: Message): void { this.messages.push(message) } + private setConfig(config: Config): void { + if (this.config !== undefined) return + this.config = config + } + private addStrayText(text: string) { this.accumulatedText += text } @@ -143,21 +185,58 @@ export class Compile { return toolCalls } + private popStepResponse(): AssistantMessage | undefined { + const response = this.stepResponse + this.stepResponse = undefined + return response + } + + private async resolveExpression( + expression: LogicalExpression, + scope: Scope, + ): Promise { + return await resolveLogicNode({ + node: expression, + scope, + raiseError: this.expressionError.bind(this), + }) + } + private async resolveBaseNode({ node, scope, isInsideMessageTag, isInsideContentTag, + completedValue = true, }: ResolveBaseNodeProps): Promise { + const nodeWithStatus = node as TemplateNodeWithStatus + const compileStatus = nodeWithStatus.status + if (compileStatus?.completedAs === completedValue) { + return + } + + if (compileStatus?.scopePointers) { + scope.setPointers(compileStatus.scopePointers) + } + + const resolveBaseNodeFn = (props: ResolveBaseNodeProps) => { + const completedValueProp = props.completedValue ?? completedValue + return this.resolveBaseNode({ + ...props, + completedValue: completedValueProp, + }) + } + const context: CompileNodeContext = { node, scope, isInsideMessageTag, isInsideContentTag, - resolveBaseNode: this.resolveBaseNode.bind(this), + resolveBaseNode: resolveBaseNodeFn.bind(this), resolveExpression: this.resolveExpression.bind(this), baseNodeError: this.baseNodeError.bind(this), expressionError: this.expressionError.bind(this), + setConfig: this.setConfig.bind(this), addMessage: this.addMessage.bind(this), addStrayText: this.addStrayText.bind(this), popStrayText: this.popStrayText.bind(this), @@ -167,6 +246,8 @@ export class Compile { groupContent: this.groupContent.bind(this), addToolCall: this.addToolCall.bind(this), popToolCalls: this.popToolCalls.bind(this), + popStepResponse: this.popStepResponse.bind(this), + stop: this.stop.bind(this), } const nodeResolver = { @@ -184,10 +265,27 @@ export class Compile { this.baseNodeError(errors.unsupportedBaseNodeType(node.type), node) } - const resolverFn = nodeResolver[node.type] as ( - context: CompileNodeContext, - ) => Promise - await resolverFn(context) + try { + const resolverFn = nodeResolver[node.type] as ( + context: CompileNodeContext, + ) => Promise + await resolverFn(context) + } catch (e) { + if (e instanceof StopIteration) { + nodeWithStatus.status = { + ...(nodeWithStatus.status ?? {}), + scopePointers: scope.getPointers(), + } + } + + throw e + } + + nodeWithStatus.status = { + ...(nodeWithStatus.status ?? {}), + scopePointers: undefined, + completedAs: completedValue, + } } private baseNodeError( diff --git a/packages/compiler/src/compiler/index.ts b/packages/compiler/src/compiler/index.ts index aba609825..e1806bc9b 100644 --- a/packages/compiler/src/compiler/index.ts +++ b/packages/compiler/src/compiler/index.ts @@ -1,16 +1,22 @@ import { Conversation, ConversationMetadata } from '$compiler/types' -import { Compile, type ReferencePromptFn } from './compile' +import { Chain } from './chain' +import { type ReferencePromptFn } from './compile' import { ReadMetadata } from './readMetadata' -export function compile({ +export async function compile({ prompt, parameters, }: { prompt: string parameters: Record }): Promise { - return new Compile({ prompt, parameters }).run() + const iterator = new Chain({ prompt, parameters }) + const { conversation, completed } = await iterator.step() + if (!completed) { + throw new Error('Use a Chain to compile prompts with multiple steps') + } + return conversation } export function readMetadata({ diff --git a/packages/compiler/src/compiler/scope.ts b/packages/compiler/src/compiler/scope.ts index 0c259acc6..68b9a9b6f 100644 --- a/packages/compiler/src/compiler/scope.ts +++ b/packages/compiler/src/compiler/scope.ts @@ -1,3 +1,6 @@ +export type ScopePointers = { [key: string]: number } +export type ScopeStash = unknown[] + export default class Scope { /** * Global stash @@ -19,8 +22,8 @@ export default class Scope { * Local pointers * Every scope has its own local pointers that contains the indexes of the variables in the global stash. */ - private globalStash: unknown[] = [] // Stash of every variable value in the global scope - private localPointers: Record = {} // Index of every variable in the stash in the current scope + private globalStash: ScopeStash = [] // Stash of every variable value in the global scope + private localPointers: ScopePointers = {} // Index of every variable in the stash in the current scope constructor(initialState: Record = {}) { for (const [key, value] of Object.entries(initialState)) { @@ -28,9 +31,16 @@ export default class Scope { } } + static withStash(stash: ScopeStash): Scope { + const scope = new Scope() + scope.globalStash = stash + return scope + } + private readFromStash(index: number): unknown { return this.globalStash[index] } + private addToStash(value: unknown): number { this.globalStash.push(value) return this.globalStash.length - 1 @@ -60,12 +70,24 @@ export default class Scope { this.modifyStash(index, value) } - copy(): Scope { + copy(localPointers?: ScopePointers): Scope { const scope = new Scope() scope.globalStash = this.globalStash - scope.localPointers = { ...this.localPointers } + scope.localPointers = { ...(localPointers ?? this.localPointers) } return scope } + + getStash(): ScopeStash { + return this.globalStash + } + + getPointers(): ScopePointers { + return this.localPointers + } + + setPointers(pointers: ScopePointers): void { + this.localPointers = pointers + } } export type ScopeContext = { diff --git a/packages/compiler/src/compiler/types.ts b/packages/compiler/src/compiler/types.ts index 8b4304516..fe7713d7f 100644 --- a/packages/compiler/src/compiler/types.ts +++ b/packages/compiler/src/compiler/types.ts @@ -8,6 +8,7 @@ export type ResolveBaseNodeProps = { scope: Scope isInsideMessageTag: boolean isInsideContentTag: boolean + completedValue?: unknown } export type ToolCallReference = { node: ToolCallTag; value: ToolCall } diff --git a/packages/compiler/src/compiler/utils.ts b/packages/compiler/src/compiler/utils.ts index 7e32d1812..45518ba54 100644 --- a/packages/compiler/src/compiler/utils.ts +++ b/packages/compiler/src/compiler/utils.ts @@ -1,9 +1,11 @@ import { + CHAIN_STEP_TAG, CUSTOM_MESSAGE_TAG, REFERENCE_PROMPT_TAG, TOOL_CALL_TAG, } from '$compiler/constants' import { + ChainStepTag, ContentTag, ElementTag, MessageTag, @@ -53,6 +55,17 @@ export function isRefTag(tag: ElementTag): tag is ReferenceTag { return tag.name === REFERENCE_PROMPT_TAG } +export function isChainStepTag(tag: ElementTag): tag is ChainStepTag { + return tag.name === CHAIN_STEP_TAG +} + export function isToolCallTag(tag: ElementTag): tag is ToolCallTag { return tag.name === TOOL_CALL_TAG } + +export function tagAttributeIsLiteral(tag: ElementTag, name: string): boolean { + const attr = tag.attributes.find((attr) => attr.name === name) + if (!attr) return false + if (attr.value === true) return true + return attr.value.every((v) => v.type === 'Text') +} diff --git a/packages/compiler/src/constants.ts b/packages/compiler/src/constants.ts index ca296ad33..f58f27a0c 100644 --- a/packages/compiler/src/constants.ts +++ b/packages/compiler/src/constants.ts @@ -11,3 +11,5 @@ export const REFERENCE_PROMPT_ATTR = 'prompt' as const // { content } export const TOOL_CALL_TAG = 'tool-call' as const + +export const CHAIN_STEP_TAG = 'step' as const diff --git a/packages/compiler/src/error/errors.ts b/packages/compiler/src/error/errors.ts index 2f59be0b4..80f557564 100644 --- a/packages/compiler/src/error/errors.ts +++ b/packages/compiler/src/error/errors.ts @@ -229,4 +229,8 @@ export default { message: `Error calling function: \n${errorKlassName} ${error.message}`, } }, + invalidStepConfig: { + code: 'invalid-step-config', + message: 'Step config must be an object', + }, } diff --git a/packages/compiler/src/parser/interfaces.ts b/packages/compiler/src/parser/interfaces.ts index c53326dfc..0e4d9741b 100644 --- a/packages/compiler/src/parser/interfaces.ts +++ b/packages/compiler/src/parser/interfaces.ts @@ -1,4 +1,5 @@ import { + CHAIN_STEP_TAG, CUSTOM_MESSAGE_TAG, REFERENCE_PROMPT_TAG, TOOL_CALL_TAG, @@ -47,6 +48,7 @@ export type MessageTag = | IElementTag | IElementTag export type ReferenceTag = IElementTag +export type ChainStepTag = IElementTag export type ToolCallTag = IElementTag export type ElementTag = | ContentTag