diff --git a/packages/promptl/rollup.config.mjs b/packages/promptl/rollup.config.mjs index 88cf71a97..6489c8905 100644 --- a/packages/promptl/rollup.config.mjs +++ b/packages/promptl/rollup.config.mjs @@ -22,7 +22,7 @@ function isInternalCircularDependency(warning) { const __dirname = path.dirname(url.fileURLToPath(import.meta.url)) const aliasEntries = { - entries: [{ find: '$compiler', replacement: path.resolve(__dirname, 'src') }], + entries: [{ find: '$promptl', replacement: path.resolve(__dirname, 'src') }], } /** @type {import('rollup').RollupOptions} */ @@ -54,7 +54,7 @@ export default [ 'node:crypto', 'yaml', 'crypto', - 'zod' + 'zod', ], }, { diff --git a/packages/promptl/src/compiler/base/nodes/comment.ts b/packages/promptl/src/compiler/base/nodes/comment.ts index f65702c60..33add6c68 100644 --- a/packages/promptl/src/compiler/base/nodes/comment.ts +++ b/packages/promptl/src/compiler/base/nodes/comment.ts @@ -1,4 +1,4 @@ -import { Comment } from '$compiler/parser/interfaces' +import { Comment } from '$promptl/parser/interfaces' import { CompileNodeContext } from '../types' diff --git a/packages/promptl/src/compiler/base/nodes/config.ts b/packages/promptl/src/compiler/base/nodes/config.ts index cbd2941c6..e553c106d 100644 --- a/packages/promptl/src/compiler/base/nodes/config.ts +++ b/packages/promptl/src/compiler/base/nodes/config.ts @@ -1,4 +1,4 @@ -import { Config as ConfigNode } from '$compiler/parser/interfaces' +import { Config as ConfigNode } from '$promptl/parser/interfaces' import yaml from 'yaml' import { CompileNodeContext } from '../types' diff --git a/packages/promptl/src/compiler/base/nodes/for.test.ts b/packages/promptl/src/compiler/base/nodes/for.test.ts index e4fa695f9..518077a76 100644 --- a/packages/promptl/src/compiler/base/nodes/for.test.ts +++ b/packages/promptl/src/compiler/base/nodes/for.test.ts @@ -1,6 +1,6 @@ -import { getExpectedError } from '$compiler/compiler/test/helpers' -import CompileError from '$compiler/error/error' -import { Message, MessageContent, TextContent } from '$compiler/types' +import CompileError from '$promptl/error/error' +import { getExpectedError } from '$promptl/test/helpers' +import { Message, MessageContent, TextContent } from '$promptl/types' import { describe, expect, it } from 'vitest' import { render } from '../..' diff --git a/packages/promptl/src/compiler/base/nodes/for.ts b/packages/promptl/src/compiler/base/nodes/for.ts index 44013b0c2..ae3fc1762 100644 --- a/packages/promptl/src/compiler/base/nodes/for.ts +++ b/packages/promptl/src/compiler/base/nodes/for.ts @@ -1,6 +1,6 @@ -import { hasContent, isIterable } from '$compiler/compiler/utils' -import errors from '$compiler/error/errors' -import { ForBlock } from '$compiler/parser/interfaces' +import { hasContent, isIterable } from '$promptl/compiler/utils' +import errors from '$promptl/error/errors' +import { ForBlock } from '$promptl/parser/interfaces' import { CompileNodeContext, TemplateNodeWithStatus } from '../types' diff --git a/packages/promptl/src/compiler/base/nodes/fragment.ts b/packages/promptl/src/compiler/base/nodes/fragment.ts index dede7ad18..309537d91 100644 --- a/packages/promptl/src/compiler/base/nodes/fragment.ts +++ b/packages/promptl/src/compiler/base/nodes/fragment.ts @@ -1,4 +1,4 @@ -import { Fragment } from '$compiler/parser/interfaces' +import { Fragment } from '$promptl/parser/interfaces' import { CompileNodeContext } from '../types' diff --git a/packages/promptl/src/compiler/base/nodes/if.test.ts b/packages/promptl/src/compiler/base/nodes/if.test.ts index 37b1f98ef..a9f1a9b7a 100644 --- a/packages/promptl/src/compiler/base/nodes/if.test.ts +++ b/packages/promptl/src/compiler/base/nodes/if.test.ts @@ -1,4 +1,4 @@ -import { AssistantMessage, MessageRole, UserMessage } from '$compiler/types' +import { AssistantMessage, MessageRole, UserMessage } from '$promptl/types' import { describe, expect, it, vi } from 'vitest' import { render } from '../..' diff --git a/packages/promptl/src/compiler/base/nodes/if.ts b/packages/promptl/src/compiler/base/nodes/if.ts index cb146355c..af056e1a0 100644 --- a/packages/promptl/src/compiler/base/nodes/if.ts +++ b/packages/promptl/src/compiler/base/nodes/if.ts @@ -1,4 +1,4 @@ -import { IfBlock } from '$compiler/parser/interfaces' +import { IfBlock } from '$promptl/parser/interfaces' import { CompileNodeContext } from '../types' diff --git a/packages/promptl/src/compiler/base/nodes/mustache.ts b/packages/promptl/src/compiler/base/nodes/mustache.ts index d48bebf78..c5255825e 100644 --- a/packages/promptl/src/compiler/base/nodes/mustache.ts +++ b/packages/promptl/src/compiler/base/nodes/mustache.ts @@ -1,4 +1,4 @@ -import { MustacheTag } from '$compiler/parser/interfaces' +import { MustacheTag } from '$promptl/parser/interfaces' import { CompileNodeContext } from '../types' diff --git a/packages/promptl/src/compiler/base/nodes/tag.ts b/packages/promptl/src/compiler/base/nodes/tag.ts index d38c896bf..62072a0ba 100644 --- a/packages/promptl/src/compiler/base/nodes/tag.ts +++ b/packages/promptl/src/compiler/base/nodes/tag.ts @@ -3,24 +3,21 @@ import { isContentTag, isMessageTag, isRefTag, - isToolCallTag, -} from '$compiler/compiler/utils' -import errors from '$compiler/error/errors' +} from '$promptl/compiler/utils' +import errors from '$promptl/error/errors' import { ChainStepTag, ContentTag, ElementTag, MessageTag, ReferenceTag, - ToolCallTag, -} from '$compiler/parser/interfaces' +} from '$promptl/parser/interfaces' import { CompileNodeContext } from '../types' import { compile as resolveChainStep } from './tags/chainStep' import { compile as resolveContent } from './tags/content' import { compile as resolveMessage } from './tags/message' import { compile as resolveRef } from './tags/ref' -import { compile as resolveToolCall } from './tags/toolCall' async function resolveTagAttributes({ node: tagNode, @@ -81,11 +78,6 @@ export async function compile(props: CompileNodeContext) { const attributes = await resolveTagAttributes(props) - if (isToolCallTag(node)) { - await resolveToolCall(props as CompileNodeContext, attributes) - return - } - if (isContentTag(node)) { await resolveContent(props as CompileNodeContext, attributes) return diff --git a/packages/promptl/src/compiler/base/nodes/tags/chainStep.ts b/packages/promptl/src/compiler/base/nodes/tags/chainStep.ts index 9dc0aa4b0..853ca7c1e 100644 --- a/packages/promptl/src/compiler/base/nodes/tags/chainStep.ts +++ b/packages/promptl/src/compiler/base/nodes/tags/chainStep.ts @@ -1,7 +1,7 @@ -import { tagAttributeIsLiteral } from '$compiler/compiler/utils' -import errors from '$compiler/error/errors' -import { ChainStepTag } from '$compiler/parser/interfaces' -import { Config } from '$compiler/types' +import { tagAttributeIsLiteral } from '$promptl/compiler/utils' +import errors from '$promptl/error/errors' +import { ChainStepTag } from '$promptl/parser/interfaces' +import { Config } from '$promptl/types' import { CompileNodeContext } from '../../types' diff --git a/packages/promptl/src/compiler/base/nodes/tags/content.test.ts b/packages/promptl/src/compiler/base/nodes/tags/content.test.ts new file mode 100644 index 000000000..3d62ef47d --- /dev/null +++ b/packages/promptl/src/compiler/base/nodes/tags/content.test.ts @@ -0,0 +1,149 @@ +import { render } from '$promptl/compiler' +import { removeCommonIndent } from '$promptl/compiler/utils' +import CompileError from '$promptl/error/error' +import { getExpectedError } from '$promptl/test/helpers' +import { + ImageContent, + SystemMessage, + TextContent, + ToolCallContent, + UserMessage, +} from '$promptl/types' +import { describe, expect, it } from 'vitest' + +describe('content tags', async () => { + it('adds stray text at root level as text contents inside a system message', async () => { + const prompt = 'Test message' + const result = await render({ prompt }) + expect(result.messages.length).toBe(1) + const message = result.messages[0]! as SystemMessage + expect(message.role).toBe('system') + expect(message.content.length).toBe(1) + expect(message.content[0]!.type).toBe('text') + expect((message.content[0] as TextContent).text).toBe('Test message') + }) + + it('adds stray text inside a message as the message content', async () => { + const prompt = 'Test user message' + const result = await render({ prompt }) + expect(result.messages.length).toBe(1) + const message = result.messages[0]! as UserMessage + expect(message.role).toBe('user') + expect(message.content.length).toBe(1) + expect(message.content[0]!.type).toBe('text') + expect((message.content[0] as TextContent).text).toBe('Test user message') + }) + + it('Can add multiple text and image contents inside a message', async () => { + const prompt = removeCommonIndent(` + + Text 1 + Image 1 + Text 2 + Text 3 + + `) + const result = await render({ prompt }) + expect(result.messages.length).toBe(1) + const message = result.messages[0]! as UserMessage + expect(message.role).toBe('user') + expect(message.content.length).toBe(4) + expect(message.content[0]!.type).toBe('text') + expect((message.content[0] as TextContent).text).toBe('Text 1') + expect(message.content[1]!.type).toBe('image') + expect((message.content[1] as ImageContent).image).toBe('Image 1') + expect(message.content[2]!.type).toBe('text') + expect((message.content[2] as TextContent).text).toBe('Text 2') + expect(message.content[3]!.type).toBe('text') + expect((message.content[3] as TextContent).text).toBe('Text 3') + }) + + it('Can interpolate stray text and defined content tags', async () => { + const prompt = removeCommonIndent(` + Text 1 + Text 2 + Text 3 + `) + const result = await render({ prompt }) + + expect(result.messages.length).toBe(1) + const message = result.messages[0]! + expect(message.content.length).toBe(3) + expect(message.content[0]!.type).toBe('text') + expect((message.content[0] as TextContent).text).toBe('Text 1') + expect(message.content[1]!.type).toBe('text') + expect((message.content[1] as TextContent).text).toBe('Text 2') + expect(message.content[2]!.type).toBe('text') + expect((message.content[2] as TextContent).text).toBe('Text 3') + }) + + it('Can use the general "content" tag with a type attribute', async () => { + const prompt = removeCommonIndent(` + Text + Image + `) + const result = await render({ prompt }) + + expect(result.messages.length).toBe(1) + const message = result.messages[0]! + expect(message.content.length).toBe(2) + expect(message.content[0]!.type).toBe('text') + expect((message.content[0] as TextContent).text).toBe('Text') + expect(message.content[1]!.type).toBe('image') + expect((message.content[1] as ImageContent).image).toBe('Image') + }) + + it('Cannot include a content tag inside another content tag', async () => { + const prompt = removeCommonIndent(` + + Text + + `) + const error = await getExpectedError(() => render({ prompt }), CompileError) + expect(error.code).toBe('content-tag-inside-content') + }) +}) + +describe('tool-call tags', async () => { + it('returns tool calls in the content of assistant messages', async () => { + const prompt = removeCommonIndent(` + + + + `) + const result = await render({ prompt }) + + expect(result.messages.length).toBe(1) + const message = result.messages[0]! as SystemMessage + expect(message.role).toBe('assistant') + expect(message.content.length).toBe(1) + expect(message.content[0]!.type).toBe('tool-call') + const toolCall = message.content[0]! as ToolCallContent + expect(toolCall.toolName).toBe('get_weather') + expect(toolCall.toolCallId).toBe('123') + }) + + it('fails when not in an assistant message tag', async () => { + const prompt = removeCommonIndent(` + + + + `) + + const error = await getExpectedError(() => render({ prompt }), CompileError) + expect(error.code).toBe('invalid-tool-call-placement') + }) + + it('fails when a tool call is inside another tool call', async () => { + const prompt = removeCommonIndent(` + + + + + + `) + + const error = await getExpectedError(() => render({ prompt }), CompileError) + expect(error.code).toBe('content-tag-inside-content') + }) +}) diff --git a/packages/promptl/src/compiler/base/nodes/tags/content.ts b/packages/promptl/src/compiler/base/nodes/tags/content.ts index b52e28e53..a96862e2e 100644 --- a/packages/promptl/src/compiler/base/nodes/tags/content.ts +++ b/packages/promptl/src/compiler/base/nodes/tags/content.ts @@ -1,7 +1,11 @@ -import { removeCommonIndent } from '$compiler/compiler/utils' -import errors from '$compiler/error/errors' -import { ContentTag } from '$compiler/parser/interfaces' -import { ContentType } from '$compiler/types' +import { removeCommonIndent } from '$promptl/compiler/utils' +import { + CUSTOM_CONTENT_TAG, + CUSTOM_CONTENT_TYPE_ATTR, +} from '$promptl/constants' +import errors from '$promptl/error/errors' +import { ContentTag } from '$promptl/parser/interfaces' +import { ContentType, ContentTypeTagName } from '$promptl/types' import { CompileNodeContext } from '../../types' @@ -32,19 +36,64 @@ export async function compile( } const textContent = removeCommonIndent(popStrayText()) - // TODO: This if else is probably not required but the types enforce it. - // Improve types. - if (node.name === 'text') { + let type: ContentType + if (node.name === CUSTOM_CONTENT_TAG) { + if (attributes[CUSTOM_CONTENT_TYPE_ATTR] === undefined) { + baseNodeError(errors.messageTagWithoutRole, node) + } + type = attributes[CUSTOM_CONTENT_TYPE_ATTR] as ContentType + delete attributes[CUSTOM_CONTENT_TYPE_ATTR] + } else { + const contentTypeKeysFromTagName = Object.fromEntries( + Object.entries(ContentTypeTagName).map(([k, v]) => [v, k]), + ) + type = + ContentType[ + contentTypeKeysFromTagName[node.name] as keyof typeof ContentType + ] + } + + if (type === ContentType.text) { addContent({ - ...attributes, - type: ContentType.text, - text: textContent, + node, + content: { + ...attributes, + type: ContentType.text, + text: textContent, + }, }) - } else { + return + } + + if (type === ContentType.image) { addContent({ - ...attributes, - type: ContentType.image, - image: textContent, + node, + content: { + ...attributes, + type: ContentType.image, + image: textContent, + }, }) + return } + + if (type == ContentType.toolCall) { + const { id, name, ...rest } = attributes + if (!id) baseNodeError(errors.toolCallTagWithoutId, node) + if (!name) baseNodeError(errors.toolCallWithoutName, node) + + addContent({ + node, + content: { + ...rest, + type: ContentType.toolCall, + toolCallId: String(id), + toolName: String(name), + toolArguments: {}, // TODO: Issue for a future PR + }, + }) + return + } + + baseNodeError(errors.invalidContentType(type), node) } diff --git a/packages/promptl/src/compiler/base/nodes/tags/message.test.ts b/packages/promptl/src/compiler/base/nodes/tags/message.test.ts index 2750f456a..103b096e7 100644 --- a/packages/promptl/src/compiler/base/nodes/tags/message.test.ts +++ b/packages/promptl/src/compiler/base/nodes/tags/message.test.ts @@ -1,15 +1,15 @@ -import { render } from '$compiler/compiler' -import { getExpectedError } from '$compiler/compiler/test/helpers' -import { removeCommonIndent } from '$compiler/compiler/utils' -import { CUSTOM_TAG_END, CUSTOM_TAG_START } from '$compiler/constants' -import CompileError from '$compiler/error/error' +import { render } from '$promptl/compiler' +import { removeCommonIndent } from '$promptl/compiler/utils' +import { CUSTOM_TAG_END, CUSTOM_TAG_START } from '$promptl/constants' +import CompileError from '$promptl/error/error' +import { getExpectedError } from '$promptl/test/helpers' import { AssistantMessage, ImageContent, SystemMessage, TextContent, UserMessage, -} from '$compiler/types' +} from '$promptl/types' import { describe, expect, it } from 'vitest' describe('messages', async () => { @@ -171,7 +171,6 @@ describe('messages', async () => { }, ], foo: 'assistant_bar', - toolCalls: [], }, { role: 'system', @@ -191,9 +190,9 @@ describe('message contents', async () => { it('all messages can have multiple content tags', async () => { const prompt = ` - text content - image content - another text content + text content + image content + another text content ` const result = await render({ @@ -256,10 +255,10 @@ describe('message contents', async () => { it('allows content tag to have extra attributes', async () => { const prompt = ` - + Long text cached... - - Short text not cached + + Short text not cached ` const result = await render({ diff --git a/packages/promptl/src/compiler/base/nodes/tags/message.ts b/packages/promptl/src/compiler/base/nodes/tags/message.ts index 8ca081d32..462eb679c 100644 --- a/packages/promptl/src/compiler/base/nodes/tags/message.ts +++ b/packages/promptl/src/compiler/base/nodes/tags/message.ts @@ -1,19 +1,15 @@ -import { ToolCallReference } from '$compiler/compiler/types' import { CUSTOM_MESSAGE_ROLE_ATTR, CUSTOM_MESSAGE_TAG, -} from '$compiler/constants' -import errors from '$compiler/error/errors' -import { MessageTag } from '$compiler/parser/interfaces' +} from '$promptl/constants' +import errors from '$promptl/error/errors' +import { MessageTag, TemplateNode } from '$promptl/parser/interfaces' import { - AssistantMessage, + ContentType, Message, MessageContent, MessageRole, - SystemMessage, - ToolMessage, - UserMessage, -} from '$compiler/types' +} from '$promptl/types' import { CompileNodeContext } from '../../types' @@ -31,7 +27,6 @@ export async function compile( groupContent, groupStrayText, popContent, - popToolCalls, addMessage, } = props @@ -60,14 +55,12 @@ export async function compile( } groupStrayText() - const messageContent = popContent() - const toolCalls = popToolCalls() + const content = popContent() const message = buildMessage(props as CompileNodeContext, { role, attributes, - content: messageContent, - toolCalls, + content, })! addMessage(message) } @@ -75,44 +68,33 @@ export async function compile( type BuildProps = { role: R attributes: Record - content: R extends MessageRole.user ? MessageContent[] : string - toolCalls: ToolCallReference[] + content: { node?: TemplateNode; content: MessageContent }[] } function buildMessage( { node, baseNodeError }: CompileNodeContext, - { role, attributes, content, toolCalls }: BuildProps, + { role, attributes, content }: BuildProps, ): Message | undefined { + if (!Object.values(MessageRole).includes(role)) { + baseNodeError(errors.invalidMessageRole(role), node) + } + if (role !== MessageRole.assistant) { - toolCalls.forEach(({ node: toolNode }) => { - baseNodeError(errors.invalidToolCallPlacement, toolNode) + content.forEach((item) => { + if (item.content.type === ContentType.toolCall) { + baseNodeError(errors.invalidToolCallPlacement, item.node ?? node) + } }) } - if (role === MessageRole.system) { - return { - ...attributes, - role, - content, - } as SystemMessage - } + const message = { + ...attributes, + role, + content: content.map((item) => item.content), + } as Message if (role === MessageRole.user) { - return { - ...attributes, - role, - name: attributes.name ? String(attributes.name) : undefined, - content, - } as UserMessage - } - - if (role === MessageRole.assistant) { - return { - ...attributes, - role, - toolCalls: toolCalls.map(({ value }) => value), - content, - } as AssistantMessage + message.name = attributes.name ? String(attributes.name) : undefined } if (role === MessageRole.tool) { @@ -120,11 +102,8 @@ function buildMessage( baseNodeError(errors.toolMessageWithoutId, node) } - return { - role, - content, - } as ToolMessage + message.toolId = String(attributes.id) } - baseNodeError(errors.invalidMessageRole(role), node) + return message } diff --git a/packages/promptl/src/compiler/base/nodes/tags/ref.ts b/packages/promptl/src/compiler/base/nodes/tags/ref.ts index 537cc0d7d..74b39408d 100644 --- a/packages/promptl/src/compiler/base/nodes/tags/ref.ts +++ b/packages/promptl/src/compiler/base/nodes/tags/ref.ts @@ -1,5 +1,5 @@ -import errors from '$compiler/error/errors' -import { ReferenceTag } from '$compiler/parser/interfaces' +import errors from '$promptl/error/errors' +import { ReferenceTag } from '$promptl/parser/interfaces' import { CompileNodeContext } from '../../types' diff --git a/packages/promptl/src/compiler/base/nodes/tags/toolCall.ts b/packages/promptl/src/compiler/base/nodes/tags/toolCall.ts deleted file mode 100644 index 6e7c1e463..000000000 --- a/packages/promptl/src/compiler/base/nodes/tags/toolCall.ts +++ /dev/null @@ -1,61 +0,0 @@ -import errors from '$compiler/error/errors' -import { ToolCallTag } from '$compiler/parser/interfaces' - -import { CompileNodeContext } from '../../types' - -export async function compile( - { - node, - scope, - isInsideMessageTag, - isInsideContentTag, - resolveBaseNode, - baseNodeError, - popStrayText, - addToolCall, - }: CompileNodeContext, - attributes: Record, -) { - if (isInsideContentTag) { - baseNodeError(errors.toolCallTagInsideContent, node) - } - - if (attributes['id'] === undefined) { - baseNodeError(errors.toolCallTagWithoutId, node) - } - - if (attributes['name'] === undefined) { - baseNodeError(errors.toolCallWithoutName, node) - } - - for await (const childNode of node.children ?? []) { - await resolveBaseNode({ - node: childNode, - scope, - isInsideMessageTag, - isInsideContentTag: true, - }) - } - - const textContent = popStrayText() - - let jsonContent: Record = {} - if (textContent) { - try { - jsonContent = JSON.parse(textContent) - } catch (error: unknown) { - if (error instanceof SyntaxError) { - baseNodeError(errors.invalidToolCallArguments, node) - } - } - } - - addToolCall({ - node: node as ToolCallTag, - value: { - id: String(attributes['id']), - name: String(attributes['name']), - arguments: jsonContent, - }, - }) -} diff --git a/packages/promptl/src/compiler/base/nodes/text.ts b/packages/promptl/src/compiler/base/nodes/text.ts index 7cbd3d1dd..0feb7bef5 100644 --- a/packages/promptl/src/compiler/base/nodes/text.ts +++ b/packages/promptl/src/compiler/base/nodes/text.ts @@ -1,4 +1,4 @@ -import { Text } from '$compiler/parser/interfaces' +import { Text } from '$promptl/parser/interfaces' import { CompileNodeContext } from '../types' diff --git a/packages/promptl/src/compiler/base/types.ts b/packages/promptl/src/compiler/base/types.ts index 9ee489803..a6dcc04a7 100644 --- a/packages/promptl/src/compiler/base/types.ts +++ b/packages/promptl/src/compiler/base/types.ts @@ -1,14 +1,14 @@ -import Scope, { ScopePointers } from '$compiler/compiler/scope' -import { TemplateNode } from '$compiler/parser/interfaces' +import Scope, { ScopePointers } from '$promptl/compiler/scope' +import { TemplateNode } from '$promptl/parser/interfaces' import { AssistantMessage, Config, Message, MessageContent, -} from '$compiler/types' +} from '$promptl/types' import type { Node as LogicalExpression } from 'estree' -import { ResolveBaseNodeProps, ToolCallReference } from '../types' +import { ResolveBaseNodeProps } from '../types' export enum NodeType { Literal = 'Literal', @@ -61,11 +61,9 @@ export type CompileNodeContext = { addStrayText: (text: string) => void popStrayText: () => string groupStrayText: () => void - addContent: (content: MessageContent) => void - popContent: () => MessageContent[] + addContent: (item: { node?: TemplateNode; content: MessageContent }) => void + popContent: () => { node?: TemplateNode; content: MessageContent }[] groupContent: () => void - addToolCall: (toolCallRef: ToolCallReference) => void - popToolCalls: () => ToolCallReference[] popStepResponse: () => AssistantMessage | undefined stop: (config?: Config) => void diff --git a/packages/promptl/src/compiler/chain.test.ts b/packages/promptl/src/compiler/chain.test.ts index 3b7b7f26a..268430ae7 100644 --- a/packages/promptl/src/compiler/chain.test.ts +++ b/packages/promptl/src/compiler/chain.test.ts @@ -1,30 +1,19 @@ -import { CHAIN_STEP_TAG } from '$compiler/constants' -import CompileError from '$compiler/error/error' +import { CHAIN_STEP_TAG } from '$promptl/constants' +import CompileError from '$promptl/error/error' +import { getExpectedError } from '$promptl/test/helpers' import { AssistantMessage, + ContentType, Conversation, MessageRole, TextContent, UserMessage, -} from '$compiler/types' +} from '$promptl/types' import { describe, expect, it, vi } from 'vitest' import { Chain } from './chain' import { removeCommonIndent } from './utils' -const getExpectedError = async ( - action: () => Promise, - errorClass: new () => T, -): Promise => { - try { - await action() - } catch (err) { - expect(err).toBeInstanceOf(errorClass) - return err as T - } - throw new Error('Expected an error to be thrown') -} - async function defaultCallback(): Promise { return '' } @@ -179,7 +168,12 @@ describe('chain', async () => { }, ], }) - expect(conversation2.messages[1]!.content).toBe('response') + expect(conversation2.messages[1]!.content.length).toBe(1) + expect(conversation2.messages[1]!.content[0]!.type).toBe(ContentType.text) + expect((conversation2.messages[1]!.content[0] as TextContent).text).toBe( + 'response', + ) + expect(conversation2.messages[2]!).toEqual({ role: MessageRole.system, content: [ @@ -243,6 +237,7 @@ describe('chain', async () => { let { completed: stop } = await chain.step() while (!stop) { + await new Promise((resolve) => setTimeout(resolve, 100)) const { completed } = await chain.step('') stop = completed } @@ -284,8 +279,12 @@ describe('chain', async () => { }) expect(conversation.messages[1]).toEqual({ role: MessageRole.assistant, - toolCalls: [], - content: '', + content: [ + { + type: 'text', + text: '', + }, + ], }) expect(conversation.messages[2]).toEqual({ role: MessageRole.system, @@ -503,7 +502,7 @@ describe('chain', async () => { it('saves the response in a variable', async () => { const prompt = removeCommonIndent(` <${CHAIN_STEP_TAG} as="response" /> - + {{response}} `) @@ -516,13 +515,18 @@ describe('chain', async () => { const { conversation } = await chain.step('foo') expect(conversation.messages.length).toBe(2) - expect(conversation.messages[0]!.content).toBe('foo') - expect(conversation.messages[1]!.content).toEqual([ - { - type: 'text', - text: 'foo', - }, - ]) + + const responseMessage = conversation.messages[0]! + expect(responseMessage.role).toBe(MessageRole.assistant) + expect(responseMessage.content.length).toBe(1) + expect((responseMessage.content[0] as TextContent).text).toBe('foo') + + const additionalMessage = conversation.messages[1]! + expect(additionalMessage.role).toBe(MessageRole.system) + expect(additionalMessage.content.length).toBe(1) + expect((additionalMessage.content[0] as TextContent).text).toBe( + JSON.stringify(responseMessage.content), + ) }) it('returns the correct configuration in all steps', async () => { diff --git a/packages/promptl/src/compiler/chain.ts b/packages/promptl/src/compiler/chain.ts index d392c65c2..ed2e0d0f2 100644 --- a/packages/promptl/src/compiler/chain.ts +++ b/packages/promptl/src/compiler/chain.ts @@ -1,9 +1,16 @@ -import parse from '$compiler/parser' -import { Fragment } from '$compiler/parser/interfaces' -import { Config, Conversation, Message } from '$compiler/types' +import parse from '$promptl/parser' +import { Fragment } from '$promptl/parser/interfaces' +import { + Config, + ContentType, + Conversation, + Message, + MessageContent, +} from '$promptl/types' import { Compile } from './compile' import Scope from './scope' +import { CompileOptions } from './types' type ChainStep = { conversation: Conversation @@ -13,6 +20,7 @@ type ChainStep = { export class Chain { public rawText: string + private compileOptions: CompileOptions private ast: Fragment private scope: Scope private didStart: boolean = false @@ -24,16 +32,18 @@ export class Chain { constructor({ prompt, parameters, + ...compileOptions }: { prompt: string parameters: Record - }) { + } & CompileOptions) { this.rawText = prompt this.ast = parse(prompt) this.scope = new Scope(parameters) + this.compileOptions = compileOptions } - async step(response?: string): Promise { + async step(response?: MessageContent[] | string): Promise { if (this._completed) { throw new Error('The chain has already completed') } @@ -49,7 +59,8 @@ export class Chain { ast: this.ast, rawText: this.rawText, globalScope: this.scope, - stepResponse: response, + stepResponse: buildStepResponseContent(response), + ...this.compileOptions, }) const { completed, scopeStash, ast, messages, globalConfig, stepConfig } = @@ -79,3 +90,17 @@ export class Chain { return this._completed } } + +function buildStepResponseContent( + response?: MessageContent[] | string, +): MessageContent[] | undefined { + if (response == undefined) return response + if (Array.isArray(response)) return response + + return [ + { + type: ContentType.text, + text: response, + }, + ] +} diff --git a/packages/promptl/src/compiler/compile.test.ts b/packages/promptl/src/compiler/compile.test.ts index 3c2937d4f..757082a34 100644 --- a/packages/promptl/src/compiler/compile.test.ts +++ b/packages/promptl/src/compiler/compile.test.ts @@ -1,11 +1,12 @@ -import { getExpectedError } from '$compiler/compiler/test/helpers' -import CompileError from '$compiler/error/error' +import CompileError from '$promptl/error/error' +import { getExpectedError } from '$promptl/test/helpers' import { + ContentType, Message, MessageContent, MessageRole, TextContent, -} from '$compiler/types' +} from '$promptl/types' import { describe, expect, it } from 'vitest' import { render } from '.' @@ -32,6 +33,48 @@ async function getCompiledText( }, '') } +describe('automatic message grouping', async () => { + it('returns system messages by default', async () => { + const prompt = 'Hello world!' + const result = await render({ prompt }) + const message = result.messages[0]! + expect(message.role).toBe(MessageRole.system) + }) + + it('groups consecutive contents with the same role', async () => { + const prompt = ` + Hello world + + This is + + your + + Captain + + speaking + ` + const result = await render({ prompt }) + const messages = result.messages + + expect(messages.length).toBe(1) + const message = messages[0]! + expect(message.role).toBe(MessageRole.system) + expect(message.content.length).toBe(5) + expect(message.content[0]!.type).toBe(ContentType.text) + expect(message.content[1]!.type).toBe(ContentType.text) + expect(message.content[2]!.type).toBe(ContentType.text) + expect(message.content[3]!.type).toBe(ContentType.image) + expect(message.content[4]!.type).toBe(ContentType.text) + }) + + it('allows defining the default role', async () => { + const prompt = 'Hello world!' + const result = await render({ prompt, defaultRole: MessageRole.user }) + const message = result.messages[0]! + expect(message.role).toBe(MessageRole.user) + }) +}) + describe('config section', async () => { it('compiles the YAML written in the config section and returns it as the config attribute in the result', async () => { const prompt = ` diff --git a/packages/promptl/src/compiler/compile.ts b/packages/promptl/src/compiler/compile.ts index aeceb2df3..b70bb1f50 100644 --- a/packages/promptl/src/compiler/compile.ts +++ b/packages/promptl/src/compiler/compile.ts @@ -1,10 +1,10 @@ -import { error } from '$compiler/error/error' -import errors from '$compiler/error/errors' +import { error } from '$promptl/error/error' +import errors from '$promptl/error/errors' import type { BaseNode, Fragment, TemplateNode, -} from '$compiler/parser/interfaces' +} from '$promptl/parser/interfaces' import { AssistantMessage, Config, @@ -13,7 +13,7 @@ import { MessageContent, MessageRole, SystemMessage, -} from '$compiler/types' +} from '$promptl/types' import type { Node as LogicalExpression } from 'estree' import { compile as resolveComment } from './base/nodes/comment' @@ -27,7 +27,7 @@ import { compile as resolveText } from './base/nodes/text' import { CompileNodeContext, TemplateNodeWithStatus } from './base/types' import { resolveLogicNode } from './logic' import Scope, { ScopeStash } from './scope' -import type { ResolveBaseNodeProps, ToolCallReference } from './types' +import type { CompileOptions, ResolveBaseNodeProps } from './types' import { removeCommonIndent } from './utils' export type CompilationStatus = { @@ -51,30 +51,35 @@ export class Compile { private ast: Fragment private rawText: string private globalScope: Scope + private defaultRole: MessageRole private messages: Message[] = [] private config: Config | undefined - private stepResponse: string | undefined + private stepResponse: MessageContent[] | undefined private accumulatedText: string = '' - private accumulatedContent: MessageContent[] = [] - private accumulatedToolCalls: ToolCallReference[] = [] + private accumulatedContent: { + node?: TemplateNode + content: MessageContent + }[] = [] constructor({ ast, rawText, globalScope, stepResponse, + defaultRole = MessageRole.system, }: { rawText: string globalScope: Scope ast: Fragment - stepResponse?: string - }) { + stepResponse?: MessageContent[] + } & CompileOptions) { this.rawText = rawText this.globalScope = globalScope this.ast = ast this.stepResponse = stepResponse + this.defaultRole = defaultRole } async run(): Promise { @@ -126,70 +131,58 @@ export class Compile { this.accumulatedText += text } + private popStrayText(): string { + const text = this.accumulatedText + this.accumulatedText = '' + return text + } + private groupStrayText(): void { if (this.accumulatedText.trim() !== '') { this.accumulatedContent.push({ - type: ContentType.text, - text: removeCommonIndent(this.accumulatedText).trim(), + content: { + type: ContentType.text, + text: removeCommonIndent(this.accumulatedText).trim(), + }, }) } this.accumulatedText = '' } - private popStrayText(): string { - const text = this.accumulatedText - this.accumulatedText = '' - return text + private addContent(item: { + node?: TemplateNode + content: MessageContent + }): void { + this.groupStrayText() + this.accumulatedContent.push(item) } - private addContent(content: MessageContent): void { - this.groupStrayText() - this.accumulatedContent.push(content) + private popContent(): { node?: TemplateNode; content: MessageContent }[] { + const content = [...this.accumulatedContent] + this.accumulatedContent = [] + return content } private groupContent(): void { this.groupStrayText() - const toolCalls = this.popToolCalls() - const content = this.popContent() - - toolCalls.forEach(({ node: toolNode }) => { - this.baseNodeError(errors.invalidToolCallPlacement, toolNode) - }) + const contentItems = this.popContent() - if (!content.length) return + if (!contentItems.length) return const message = { - role: MessageRole.system, - content, + role: this.defaultRole, + content: contentItems.map((item) => item.content), } as SystemMessage this.addMessage(message) } - private popContent(): MessageContent[] { - const content = this.accumulatedContent - this.accumulatedContent = [] - return content - } - - private addToolCall(toolCallRef: ToolCallReference): void { - this.groupStrayText() - this.accumulatedToolCalls.push(toolCallRef) - } - - private popToolCalls(): ToolCallReference[] { - const toolCalls = this.accumulatedToolCalls - this.accumulatedToolCalls = [] - return toolCalls - } - private popStepResponse() { if (this.stepResponse === undefined) return undefined const response: AssistantMessage = { role: MessageRole.assistant, content: this.stepResponse, - toolCalls: [], } this.stepResponse = undefined @@ -253,8 +246,6 @@ export class Compile { addContent: this.addContent.bind(this), popContent: this.popContent.bind(this), groupContent: this.groupContent.bind(this), - addToolCall: this.addToolCall.bind(this), - popToolCalls: this.popToolCalls.bind(this), popStepResponse: this.popStepResponse.bind(this), stop: this.stop.bind(this), } diff --git a/packages/promptl/src/compiler/errors.test.ts b/packages/promptl/src/compiler/errors.test.ts index ab42e899e..97923192d 100644 --- a/packages/promptl/src/compiler/errors.test.ts +++ b/packages/promptl/src/compiler/errors.test.ts @@ -1,4 +1,4 @@ -import CompileError from '$compiler/error/error' +import CompileError from '$promptl/error/error' import { describe, expect, it } from 'vitest' import { readMetadata, render } from '.' @@ -140,11 +140,11 @@ describe(`all compilation errors that don't require value resolution are caught it('content-tag-inside-content', async () => { const prompt = ` - - + + Foo - - + + ` await expectBothErrors({ @@ -153,19 +153,6 @@ describe(`all compilation errors that don't require value resolution are caught }) }) - it('tool-call-tag-inside-content', async () => { - const prompt = ` - - - - ` - - await expectBothErrors({ - code: 'tool-call-tag-inside-content', - prompt, - }) - }) - it('tool-call-tag-without-id', async () => { const prompt = ` diff --git a/packages/promptl/src/compiler/index.ts b/packages/promptl/src/compiler/index.ts index dc3ded224..a34736498 100644 --- a/packages/promptl/src/compiler/index.ts +++ b/packages/promptl/src/compiler/index.ts @@ -1,4 +1,4 @@ -import { Conversation, ConversationMetadata } from '$compiler/types' +import { Conversation, ConversationMetadata } from '$promptl/types' import { z } from 'zod' import { Chain } from './chain' @@ -7,15 +7,17 @@ import { type Document, type ReferencePromptFn, } from './readMetadata' +import { CompileOptions } from './types' export async function render({ prompt, parameters = {}, + ...compileOptions }: { prompt: string parameters?: Record -}): Promise { - const iterator = new Chain({ prompt, parameters }) +} & CompileOptions): Promise { + const iterator = new Chain({ prompt, parameters, ...compileOptions }) const { conversation, completed } = await iterator.step() if (!completed) { throw new Error('Use a Chain to render prompts with multiple steps') diff --git a/packages/promptl/src/compiler/logic/nodes/arrayExpression.ts b/packages/promptl/src/compiler/logic/nodes/arrayExpression.ts index b2f75cddc..e7dd0bfdc 100644 --- a/packages/promptl/src/compiler/logic/nodes/arrayExpression.ts +++ b/packages/promptl/src/compiler/logic/nodes/arrayExpression.ts @@ -1,13 +1,13 @@ import { resolveLogicNode, updateScopeContextForNode, -} from '$compiler/compiler/logic' +} from '$promptl/compiler/logic' import type { ResolveNodeProps, UpdateScopeContextProps, -} from '$compiler/compiler/logic/types' -import { isIterable } from '$compiler/compiler/utils' -import errors from '$compiler/error/errors' +} from '$promptl/compiler/logic/types' +import { isIterable } from '$promptl/compiler/utils' +import errors from '$promptl/error/errors' import type { ArrayExpression } from 'estree' /** diff --git a/packages/promptl/src/compiler/logic/nodes/assignmentExpression.ts b/packages/promptl/src/compiler/logic/nodes/assignmentExpression.ts index 159224d03..ba8550de5 100644 --- a/packages/promptl/src/compiler/logic/nodes/assignmentExpression.ts +++ b/packages/promptl/src/compiler/logic/nodes/assignmentExpression.ts @@ -1,9 +1,9 @@ -import { ASSIGNMENT_OPERATOR_METHODS } from '$compiler/compiler/logic/operators' +import { ASSIGNMENT_OPERATOR_METHODS } from '$promptl/compiler/logic/operators' import type { ResolveNodeProps, UpdateScopeContextProps, -} from '$compiler/compiler/logic/types' -import errors from '$compiler/error/errors' +} from '$promptl/compiler/logic/types' +import errors from '$promptl/error/errors' import type { AssignmentExpression, AssignmentOperator, diff --git a/packages/promptl/src/compiler/logic/nodes/binaryExpression.ts b/packages/promptl/src/compiler/logic/nodes/binaryExpression.ts index b62eaba52..8e3624361 100644 --- a/packages/promptl/src/compiler/logic/nodes/binaryExpression.ts +++ b/packages/promptl/src/compiler/logic/nodes/binaryExpression.ts @@ -1,9 +1,9 @@ -import { BINARY_OPERATOR_METHODS } from '$compiler/compiler/logic/operators' +import { BINARY_OPERATOR_METHODS } from '$promptl/compiler/logic/operators' import type { ResolveNodeProps, UpdateScopeContextProps, -} from '$compiler/compiler/logic/types' -import errors from '$compiler/error/errors' +} from '$promptl/compiler/logic/types' +import errors from '$promptl/error/errors' import type { BinaryExpression, LogicalExpression } from 'estree' import { resolveLogicNode, updateScopeContextForNode } from '..' diff --git a/packages/promptl/src/compiler/logic/nodes/callExpression.ts b/packages/promptl/src/compiler/logic/nodes/callExpression.ts index 539368094..04f3824c5 100644 --- a/packages/promptl/src/compiler/logic/nodes/callExpression.ts +++ b/packages/promptl/src/compiler/logic/nodes/callExpression.ts @@ -1,9 +1,9 @@ import type { ResolveNodeProps, UpdateScopeContextProps, -} from '$compiler/compiler/logic/types' -import CompileError from '$compiler/error/error' -import errors from '$compiler/error/errors' +} from '$promptl/compiler/logic/types' +import CompileError from '$promptl/error/error' +import errors from '$promptl/error/errors' import type { SimpleCallExpression } from 'estree' import { resolveLogicNode, updateScopeContextForNode } from '..' diff --git a/packages/promptl/src/compiler/logic/nodes/chainExpression.ts b/packages/promptl/src/compiler/logic/nodes/chainExpression.ts index d0a34a69e..63f0b720e 100644 --- a/packages/promptl/src/compiler/logic/nodes/chainExpression.ts +++ b/packages/promptl/src/compiler/logic/nodes/chainExpression.ts @@ -1,7 +1,7 @@ import type { ResolveNodeProps, UpdateScopeContextProps, -} from '$compiler/compiler/logic/types' +} from '$promptl/compiler/logic/types' import type { ChainExpression } from 'estree' import { resolveLogicNode, updateScopeContextForNode } from '..' diff --git a/packages/promptl/src/compiler/logic/nodes/conditionalExpression.ts b/packages/promptl/src/compiler/logic/nodes/conditionalExpression.ts index ac8967021..50f4d7948 100644 --- a/packages/promptl/src/compiler/logic/nodes/conditionalExpression.ts +++ b/packages/promptl/src/compiler/logic/nodes/conditionalExpression.ts @@ -1,7 +1,7 @@ import type { ResolveNodeProps, UpdateScopeContextProps, -} from '$compiler/compiler/logic/types' +} from '$promptl/compiler/logic/types' import type { ConditionalExpression } from 'estree' import { resolveLogicNode, updateScopeContextForNode } from '..' diff --git a/packages/promptl/src/compiler/logic/nodes/identifier.ts b/packages/promptl/src/compiler/logic/nodes/identifier.ts index bfaa1a388..33bd632a0 100644 --- a/packages/promptl/src/compiler/logic/nodes/identifier.ts +++ b/packages/promptl/src/compiler/logic/nodes/identifier.ts @@ -1,8 +1,8 @@ import type { ResolveNodeProps, UpdateScopeContextProps, -} from '$compiler/compiler/logic/types' -import errors from '$compiler/error/errors' +} from '$promptl/compiler/logic/types' +import errors from '$promptl/error/errors' import type { Identifier } from 'estree' /** diff --git a/packages/promptl/src/compiler/logic/nodes/index.ts b/packages/promptl/src/compiler/logic/nodes/index.ts index 22e46987f..1b0c54e78 100644 --- a/packages/promptl/src/compiler/logic/nodes/index.ts +++ b/packages/promptl/src/compiler/logic/nodes/index.ts @@ -2,7 +2,7 @@ import { NodeType, ResolveNodeProps, UpdateScopeContextProps, -} from '$compiler/compiler/logic/types' +} from '$promptl/compiler/logic/types' import { Node } from 'estree' import { diff --git a/packages/promptl/src/compiler/logic/nodes/literal.ts b/packages/promptl/src/compiler/logic/nodes/literal.ts index 6e5f77df0..d10868181 100644 --- a/packages/promptl/src/compiler/logic/nodes/literal.ts +++ b/packages/promptl/src/compiler/logic/nodes/literal.ts @@ -1,4 +1,4 @@ -import { type ResolveNodeProps } from '$compiler/compiler/logic/types' +import { type ResolveNodeProps } from '$promptl/compiler/logic/types' import { type Literal } from 'estree' /** diff --git a/packages/promptl/src/compiler/logic/nodes/memberExpression.ts b/packages/promptl/src/compiler/logic/nodes/memberExpression.ts index c0d906d29..0b3bbd219 100644 --- a/packages/promptl/src/compiler/logic/nodes/memberExpression.ts +++ b/packages/promptl/src/compiler/logic/nodes/memberExpression.ts @@ -1,8 +1,8 @@ -import { MEMBER_EXPRESSION_METHOD } from '$compiler/compiler/logic/operators' +import { MEMBER_EXPRESSION_METHOD } from '$promptl/compiler/logic/operators' import type { ResolveNodeProps, UpdateScopeContextProps, -} from '$compiler/compiler/logic/types' +} from '$promptl/compiler/logic/types' import type { Identifier, MemberExpression } from 'estree' import { resolveLogicNode, updateScopeContextForNode } from '..' diff --git a/packages/promptl/src/compiler/logic/nodes/objectExpression.ts b/packages/promptl/src/compiler/logic/nodes/objectExpression.ts index 5a7b91513..c939dce43 100644 --- a/packages/promptl/src/compiler/logic/nodes/objectExpression.ts +++ b/packages/promptl/src/compiler/logic/nodes/objectExpression.ts @@ -1,12 +1,12 @@ import { resolveLogicNode, updateScopeContextForNode, -} from '$compiler/compiler/logic' +} from '$promptl/compiler/logic' import { UpdateScopeContextProps, type ResolveNodeProps, -} from '$compiler/compiler/logic/types' -import errors from '$compiler/error/errors' +} from '$promptl/compiler/logic/types' +import errors from '$promptl/error/errors' import { type Identifier, type ObjectExpression } from 'estree' /** diff --git a/packages/promptl/src/compiler/logic/nodes/sequenceExpression.ts b/packages/promptl/src/compiler/logic/nodes/sequenceExpression.ts index 59ecf4814..9f6c6ac81 100644 --- a/packages/promptl/src/compiler/logic/nodes/sequenceExpression.ts +++ b/packages/promptl/src/compiler/logic/nodes/sequenceExpression.ts @@ -1,11 +1,11 @@ import { resolveLogicNode, updateScopeContextForNode, -} from '$compiler/compiler/logic' +} from '$promptl/compiler/logic' import type { ResolveNodeProps, UpdateScopeContextProps, -} from '$compiler/compiler/logic/types' +} from '$promptl/compiler/logic/types' import type { SequenceExpression } from 'estree' /** diff --git a/packages/promptl/src/compiler/logic/nodes/unaryExpression.ts b/packages/promptl/src/compiler/logic/nodes/unaryExpression.ts index c60aea860..a77a8aa45 100644 --- a/packages/promptl/src/compiler/logic/nodes/unaryExpression.ts +++ b/packages/promptl/src/compiler/logic/nodes/unaryExpression.ts @@ -1,9 +1,9 @@ -import { UNARY_OPERATOR_METHODS } from '$compiler/compiler/logic/operators' +import { UNARY_OPERATOR_METHODS } from '$promptl/compiler/logic/operators' import type { ResolveNodeProps, UpdateScopeContextProps, -} from '$compiler/compiler/logic/types' -import errors from '$compiler/error/errors' +} from '$promptl/compiler/logic/types' +import errors from '$promptl/error/errors' import type { UnaryExpression } from 'estree' import { resolveLogicNode, updateScopeContextForNode } from '..' diff --git a/packages/promptl/src/compiler/logic/nodes/updateExpression.ts b/packages/promptl/src/compiler/logic/nodes/updateExpression.ts index 86fa7e3c5..cd2ab12b4 100644 --- a/packages/promptl/src/compiler/logic/nodes/updateExpression.ts +++ b/packages/promptl/src/compiler/logic/nodes/updateExpression.ts @@ -1,12 +1,12 @@ import { resolveLogicNode, updateScopeContextForNode, -} from '$compiler/compiler/logic' +} from '$promptl/compiler/logic' import type { ResolveNodeProps, UpdateScopeContextProps, -} from '$compiler/compiler/logic/types' -import errors from '$compiler/error/errors' +} from '$promptl/compiler/logic/types' +import errors from '$promptl/error/errors' import type { AssignmentExpression, UpdateExpression } from 'estree' /** diff --git a/packages/promptl/src/compiler/logic/types.ts b/packages/promptl/src/compiler/logic/types.ts index 274036aae..7b9eb9933 100644 --- a/packages/promptl/src/compiler/logic/types.ts +++ b/packages/promptl/src/compiler/logic/types.ts @@ -1,4 +1,4 @@ -import Scope, { ScopeContext } from '$compiler/compiler/scope' +import Scope, { ScopeContext } from '$promptl/compiler/scope' import { Node } from 'estree' export enum NodeType { diff --git a/packages/promptl/src/compiler/readMetadata.test.ts b/packages/promptl/src/compiler/readMetadata.test.ts index 41b4ac2e5..ec1611888 100644 --- a/packages/promptl/src/compiler/readMetadata.test.ts +++ b/packages/promptl/src/compiler/readMetadata.test.ts @@ -1,4 +1,4 @@ -import CompileError from '$compiler/error/error' +import CompileError from '$promptl/error/error' import { describe, expect, it } from 'vitest' import { z } from 'zod' diff --git a/packages/promptl/src/compiler/readMetadata.ts b/packages/promptl/src/compiler/readMetadata.ts index 7aab4ede3..9f3bae915 100644 --- a/packages/promptl/src/compiler/readMetadata.ts +++ b/packages/promptl/src/compiler/readMetadata.ts @@ -4,19 +4,24 @@ import { REFERENCE_DEPTH_LIMIT, REFERENCE_PROMPT_ATTR, REFERENCE_PROMPT_TAG, -} from '$compiler/constants' -import CompileError, { error } from '$compiler/error/error' -import errors from '$compiler/error/errors' -import parse from '$compiler/parser/index' +} from '$promptl/constants' +import CompileError, { error } from '$promptl/error/error' +import errors from '$promptl/error/errors' +import parse from '$promptl/parser/index' import type { Attribute, BaseNode, + ContentTag, ElementTag, Fragment, TemplateNode, - ToolCallTag, -} from '$compiler/parser/interfaces' -import { Config, ConversationMetadata, MessageRole } from '$compiler/types' +} from '$promptl/parser/interfaces' +import { + Config, + ContentTypeTagName, + ConversationMetadata, + MessageRole, +} from '$promptl/types' import { Node as LogicalExpression } from 'estree' import yaml, { Node as YAMLItem } from 'yaml' import { z } from 'zod' @@ -29,7 +34,6 @@ import { isContentTag, isMessageTag, isRefTag, - isToolCallTag, } from './utils' function copyScopeContext(scopeContext: ScopeContext): ScopeContext { @@ -61,7 +65,7 @@ export class ReadMetadata { private resolvedPromptOffset: number = 0 private hasContent: boolean = false - private accumulatedToolCalls: ToolCallTag[] = [] + private accumulatedToolCalls: ContentTag[] = [] private errors: CompileError[] = [] private references: { [from: string]: string[] } = {} @@ -404,44 +408,28 @@ export class ReadMetadata { if (node.type === 'ElementTag') { this.hasContent = true - if (isToolCallTag(node)) { - if (isInsideContentTag) { - this.baseNodeError(errors.toolCallTagInsideContent, node) - return - } - - const attributes = await this.listTagAttributes({ - tagNode: node, - scopeContext, - }) - if (!attributes.has('id')) { - this.baseNodeError(errors.toolCallTagWithoutId, node) + if (isContentTag(node)) { + if (isInsideContentTag) { + this.baseNodeError(errors.contentTagInsideContent, node) return } - if (!attributes.has('name')) { - this.baseNodeError(errors.toolCallWithoutName, node) - return - } + if (node.name === ContentTypeTagName.toolCall) { + this.accumulatedToolCalls.push(node) - for await (const childNode of node.children ?? []) { - await this.readBaseMetadata({ - node: childNode, + const attributes = await this.listTagAttributes({ + tagNode: node, scopeContext, - isInsideMessageTag, - isInsideContentTag: true, }) - } - this.accumulatedToolCalls.push(node as ToolCallTag) - return - } + if (!attributes.has('id')) { + this.baseNodeError(errors.toolCallTagWithoutId, node) + } - if (isContentTag(node)) { - if (isInsideContentTag) { - this.baseNodeError(errors.contentTagInsideContent, node) - return + if (!attributes.has('name')) { + this.baseNodeError(errors.toolCallWithoutName, node) + } } for await (const childNode of node.children ?? []) { diff --git a/packages/promptl/src/compiler/types.ts b/packages/promptl/src/compiler/types.ts index fe7713d7f..e4dcebee1 100644 --- a/packages/promptl/src/compiler/types.ts +++ b/packages/promptl/src/compiler/types.ts @@ -1,5 +1,5 @@ -import { TemplateNode, ToolCallTag } from '$compiler/parser/interfaces' -import { ToolCall } from '$compiler/types' +import { TemplateNode } from '$promptl/parser/interfaces' +import { MessageRole } from '$promptl/types' import type Scope from './scope' @@ -11,4 +11,6 @@ export type ResolveBaseNodeProps = { completedValue?: unknown } -export type ToolCallReference = { node: ToolCallTag; value: ToolCall } +export type CompileOptions = { + defaultRole?: MessageRole +} diff --git a/packages/promptl/src/compiler/utils.ts b/packages/promptl/src/compiler/utils.ts index cd6bed281..8538d2415 100644 --- a/packages/promptl/src/compiler/utils.ts +++ b/packages/promptl/src/compiler/utils.ts @@ -1,18 +1,17 @@ import { CHAIN_STEP_TAG, + CUSTOM_CONTENT_TAG, CUSTOM_MESSAGE_TAG, REFERENCE_PROMPT_TAG, - TOOL_CALL_TAG, -} from '$compiler/constants' +} from '$promptl/constants' import { ChainStepTag, ContentTag, ElementTag, MessageTag, ReferenceTag, - ToolCallTag, -} from '$compiler/parser/interfaces' -import { ContentType, MessageRole } from '$compiler/types' +} from '$promptl/parser/interfaces' +import { ContentTypeTagName, MessageRole } from '$promptl/types' import { Scalar, Node as YAMLItem, YAMLMap, YAMLSeq } from 'yaml' export function isIterable(obj: unknown): obj is Iterable { @@ -49,7 +48,10 @@ export function isMessageTag(tag: ElementTag): tag is MessageTag { } export function isContentTag(tag: ElementTag): tag is ContentTag { - return Object.values(ContentType).includes(tag.name as ContentType) + if (tag.name === CUSTOM_CONTENT_TAG) return true + return Object.values(ContentTypeTagName).includes( + tag.name as ContentTypeTagName, + ) } export function isRefTag(tag: ElementTag): tag is ReferenceTag { @@ -60,10 +62,6 @@ export function isChainStepTag(tag: ElementTag): tag is ChainStepTag { return tag.name === CHAIN_STEP_TAG } -export function isToolCallTag(tag: ElementTag): tag is ToolCallTag { - return tag.name === TOOL_CALL_TAG -} - export function tagAttributeIsLiteral(tag: ElementTag, name: string): boolean { const attr = tag.attributes.find((attr) => attr.name === name) if (!attr) return false diff --git a/packages/promptl/src/constants.ts b/packages/promptl/src/constants.ts index 19abe2870..a6948e94a 100644 --- a/packages/promptl/src/constants.ts +++ b/packages/promptl/src/constants.ts @@ -5,14 +5,14 @@ export const CUSTOM_TAG_END = '}}' export const CUSTOM_MESSAGE_TAG = 'message' as const export const CUSTOM_MESSAGE_ROLE_ATTR = 'role' as const +export const CUSTOM_CONTENT_TAG = 'content' as const +export const CUSTOM_CONTENT_TYPE_ATTR = 'type' as const + // export const REFERENCE_PROMPT_TAG = 'prompt' as const export const REFERENCE_PROMPT_ATTR = 'path' as const export const REFERENCE_DEPTH_LIMIT = 50 -// { content } -export const TOOL_CALL_TAG = 'tool-call' as const - // export const CHAIN_STEP_TAG = 'response' as const diff --git a/packages/promptl/src/error/error.ts b/packages/promptl/src/error/error.ts index 56128ce31..01f679626 100644 --- a/packages/promptl/src/error/error.ts +++ b/packages/promptl/src/error/error.ts @@ -1,4 +1,4 @@ -import { Fragment } from '$compiler/parser/interfaces' +import { Fragment } from '$promptl/parser/interfaces' import { locate } from 'locate-character' export interface Position { diff --git a/packages/promptl/src/error/errors.ts b/packages/promptl/src/error/errors.ts index 095bf433b..acd10a26c 100644 --- a/packages/promptl/src/error/errors.ts +++ b/packages/promptl/src/error/errors.ts @@ -1,4 +1,4 @@ -import { CUSTOM_TAG_END, CUSTOM_TAG_START } from '$compiler/constants' +import { CUSTOM_TAG_END, CUSTOM_TAG_START } from '$promptl/constants' import CompileError from './error' @@ -143,7 +143,7 @@ export default { }), invalidToolCallPlacement: { code: 'invalid-tool-call-placement', - message: 'All tool calls must be inside of an assistant message', + message: 'Only assistant messages can contain tool calls', }, messageTagInsideMessage: { code: 'message-tag-inside-message', @@ -153,10 +153,6 @@ export default { code: 'content-tag-inside-content', message: 'Content tags must be directly inside message tags', }, - toolCallTagInsideContent: { - code: 'tool-call-tag-inside-content', - message: 'Tool calls must be directly inside message tags', - }, toolCallTagWithoutId: { code: 'tool-call-tag-without-id', message: 'Tool call tags must have an id attribute', @@ -227,6 +223,10 @@ export default { code: 'invalid-message-role', message: `Invalid message role: ${name}`, }), + invalidContentType: (name: string) => ({ + code: 'invalid-content-type', + message: `Invalid content type: ${name}`, + }), variableNotDeclared: (name: string) => ({ code: 'variable-not-declared', message: `Variable '${name}' is not declared`, diff --git a/packages/promptl/src/parser/index.ts b/packages/promptl/src/parser/index.ts index 689d00a4d..8ba81ed8c 100644 --- a/packages/promptl/src/parser/index.ts +++ b/packages/promptl/src/parser/index.ts @@ -1,6 +1,6 @@ -import CompileError, { error } from '$compiler/error/error' -import PARSER_ERRORS from '$compiler/error/errors' -import { reserved } from '$compiler/utils/names' +import CompileError, { error } from '$promptl/error/error' +import PARSER_ERRORS from '$promptl/error/errors' +import { reserved } from '$promptl/utils/names' import { isIdentifierChar, isIdentifierStart } from 'acorn' import type { BaseNode, Fragment } from './interfaces' diff --git a/packages/promptl/src/parser/interfaces.ts b/packages/promptl/src/parser/interfaces.ts index 7263ef1a2..cf832451e 100644 --- a/packages/promptl/src/parser/interfaces.ts +++ b/packages/promptl/src/parser/interfaces.ts @@ -1,10 +1,10 @@ import { CHAIN_STEP_TAG, + CUSTOM_CONTENT_TAG, CUSTOM_MESSAGE_TAG, REFERENCE_PROMPT_TAG, - TOOL_CALL_TAG, -} from '$compiler/constants' -import { ContentType, MessageRole } from '$compiler/types' +} from '$promptl/constants' +import { ContentTypeTagName, MessageRole } from '$promptl/types' import { Identifier, type Node as LogicalExpression } from 'estree' export type BaseNode = { @@ -43,13 +43,15 @@ type IElementTag = BaseNode & { children: TemplateNode[] } -export type ContentTag = IElementTag export type MessageTag = | IElementTag | IElementTag +export type ContentTag = + | IElementTag + | IElementTag + export type ReferenceTag = IElementTag export type ChainStepTag = IElementTag -export type ToolCallTag = IElementTag export type ElementTag = | ContentTag | MessageTag @@ -81,7 +83,6 @@ export type ForBlock = BaseNode & { expression: LogicalExpression context: Identifier index: Identifier | null - key: LogicalExpression else: ElseBlock | null } diff --git a/packages/promptl/src/parser/parser.test.ts b/packages/promptl/src/parser/parser.test.ts index 097e9a3a2..b92c95fb2 100644 --- a/packages/promptl/src/parser/parser.test.ts +++ b/packages/promptl/src/parser/parser.test.ts @@ -1,32 +1,20 @@ -import { CUSTOM_TAG_END, CUSTOM_TAG_START } from '$compiler/constants' -import CompileError from '$compiler/error/error' +import { CUSTOM_TAG_END, CUSTOM_TAG_START } from '$promptl/constants' +import CompileError from '$promptl/error/error' +import { getExpectedError } from '$promptl/test/helpers' import { describe, expect, it } from 'vitest' import parse from '.' import { TemplateNode } from './interfaces' -const getExpectedError = ( - action: () => void, - errorClass: new () => T, -): T => { - try { - action() - } catch (err) { - expect(err).toBeInstanceOf(errorClass) - return err as T - } - throw new Error('Expected an error to be thrown') -} - -describe('Fragment', () => { - it('parses any string as a fragment', () => { +describe('Fragment', async () => { + it('parses any string as a fragment', async () => { const fragment = parse('hello world') expect(fragment.type).toBe('Fragment') }) }) -describe('Text Block', () => { - it('parses any regular string as a text block', () => { +describe('Text Block', async () => { + it('parses any regular string as a text block', async () => { const text = 'hello world' const fragment = parse(text) expect(fragment.children.length).toBe(1) @@ -36,7 +24,7 @@ describe('Text Block', () => { expect(textBlock.data).toBe(text) }) - it('keeps line breaks', () => { + it('keeps line breaks', async () => { const text = 'hello\nworld' const fragment = parse(text) expect(fragment.children.length).toBe(1) @@ -46,7 +34,7 @@ describe('Text Block', () => { expect(textBlock.data).toBe(text) }) - it('parses escaped brackets as text', () => { + it('parses escaped brackets as text', async () => { const text = `hello \\${CUSTOM_TAG_START} world` const expected = `hello ${CUSTOM_TAG_START} world` const fragment = parse(text) @@ -58,8 +46,8 @@ describe('Text Block', () => { }) }) -describe('Comments', () => { - it('parses a multiline comment block', () => { +describe('Comments', async () => { + it('parses a multiline comment block', async () => { const fragment = parse('/* hello\nworld */') expect(fragment.children.length).toBe(1) @@ -69,7 +57,7 @@ describe('Comments', () => { expect(commentBlock.raw).toBe('/* hello\nworld */') }) - it('ignores brackets and any other block within a comment', () => { + it('ignores brackets and any other block within a comment', async () => { const fragment = parse( ` /* hello @@ -86,7 +74,7 @@ world */ ) }) - it('Allows tag comments', () => { + it('Allows tag comments', async () => { const fragment = parse('') expect(fragment.children.length).toBe(1) @@ -96,8 +84,8 @@ world */ }) }) -describe('Tags', () => { - it('parses any HTML-like tag', () => { +describe('Tags', async () => { + it('parses any HTML-like tag', async () => { const fragment = parse('') expect(fragment.children.length).toBe(1) @@ -106,7 +94,7 @@ describe('Tags', () => { expect(tag.name).toBe('custom-tag') }) - it('parses self closing tags', () => { + it('parses self closing tags', async () => { const fragment = parse('') expect(fragment.children.length).toBe(1) @@ -115,28 +103,28 @@ describe('Tags', () => { expect(tag.name).toBe('custom-tag') }) - it('fails if there is no closing tag', () => { + it('fails if there is no closing tag', async () => { const action = () => parse('') - const error = getExpectedError(action, CompileError) + const error = await getExpectedError(action, CompileError) expect(error.code).toBe('unclosed-block') }) - it('fails if the tag is not opened', () => { + it('fails if the tag is not opened', async () => { const action = () => parse('') - const error = getExpectedError(action, CompileError) + const error = await getExpectedError(action, CompileError) expect(error.code).toBe('unexpected-tag-close') }) - it('fails if the tag is not closed', () => { + it('fails if the tag is not closed', async () => { const action = () => parse(' { + it('Parses tags within tags', async () => { const fragment = parse('') expect(fragment.children.length).toBe(1) @@ -150,7 +138,7 @@ describe('Tags', () => { expect(child.name).toBe('child') }) - it('parses all attributes', () => { + it('parses all attributes', async () => { const fragment = parse( '', ) @@ -178,7 +166,7 @@ describe('Tags', () => { expect(value2[0]!.data).toBe('value2') }) - it('Parses attribute vales as expressions when interpolated', () => { + it('Parses attribute vales as expressions when interpolated', async () => { const fragment = parse( ``, ) @@ -198,7 +186,7 @@ describe('Tags', () => { expect(value[0]!.expression).toBeTruthy() }) - it('Parses attributes with no value as true', () => { + it('Parses attributes with no value as true', async () => { const fragment = parse(``) expect(fragment.children.length).toBe(1) @@ -213,10 +201,10 @@ describe('Tags', () => { expect(attr.value).toBe(true) }) - it('Fails when adding a duplicate attribute', () => { + it('Fails when adding a duplicate attribute', async () => { const action = () => parse(``) - const error = getExpectedError(action, CompileError) + const error = await getExpectedError(action, CompileError) expect(error.code).toBe('duplicate-attribute') }) }) diff --git a/packages/promptl/src/parser/read/context.ts b/packages/promptl/src/parser/read/context.ts index caf7d07c2..628433138 100644 --- a/packages/promptl/src/parser/read/context.ts +++ b/packages/promptl/src/parser/read/context.ts @@ -1,13 +1,13 @@ -import type CompileError from '$compiler/error/error' -import PARSER_ERRORS from '$compiler/error/errors' -import { parseExpressionAt } from '$compiler/parser/utils/acorn' +import type CompileError from '$promptl/error/error' +import PARSER_ERRORS from '$promptl/error/errors' +import { parseExpressionAt } from '$promptl/parser/utils/acorn' import { getBracketClose, isBracketClose, isBracketOpen, isBracketPair, -} from '$compiler/parser/utils/bracket' -import fullCharCodeAt from '$compiler/parser/utils/full_char_code_at' +} from '$promptl/parser/utils/bracket' +import fullCharCodeAt from '$promptl/parser/utils/full_char_code_at' import { isIdentifierStart } from 'acorn' import { Pattern } from 'estree' diff --git a/packages/promptl/src/parser/read/expression.ts b/packages/promptl/src/parser/read/expression.ts index 749090113..0cd98b074 100644 --- a/packages/promptl/src/parser/read/expression.ts +++ b/packages/promptl/src/parser/read/expression.ts @@ -1,7 +1,7 @@ -import CompileError from '$compiler/error/error' -import PARSER_ERRORS from '$compiler/error/errors' -import { Parser } from '$compiler/parser' -import { parseExpressionAt } from '$compiler/parser/utils/acorn' +import CompileError from '$promptl/error/error' +import PARSER_ERRORS from '$promptl/error/errors' +import { Parser } from '$promptl/parser' +import { parseExpressionAt } from '$promptl/parser/utils/acorn' export default function readExpression(parser: Parser) { try { diff --git a/packages/promptl/src/parser/state/config.ts b/packages/promptl/src/parser/state/config.ts index 8dffac98a..6b9cb2c64 100644 --- a/packages/promptl/src/parser/state/config.ts +++ b/packages/promptl/src/parser/state/config.ts @@ -1,6 +1,6 @@ -import PARSER_ERRORS from '$compiler/error/errors' -import { Parser } from '$compiler/parser' -import type { Config } from '$compiler/parser/interfaces' +import PARSER_ERRORS from '$promptl/error/errors' +import { Parser } from '$promptl/parser' +import type { Config } from '$promptl/parser/interfaces' export function config(parser: Parser) { const start = parser.index diff --git a/packages/promptl/src/parser/state/fragment.ts b/packages/promptl/src/parser/state/fragment.ts index 924e6538c..80b6ff206 100644 --- a/packages/promptl/src/parser/state/fragment.ts +++ b/packages/promptl/src/parser/state/fragment.ts @@ -1,4 +1,4 @@ -import { CUSTOM_TAG_END, CUSTOM_TAG_START } from '$compiler/constants' +import { CUSTOM_TAG_END, CUSTOM_TAG_START } from '$promptl/constants' import { Parser } from '..' import { config } from './config' diff --git a/packages/promptl/src/parser/state/multi_line_comment.ts b/packages/promptl/src/parser/state/multi_line_comment.ts index 6cd1b479f..7759683dd 100644 --- a/packages/promptl/src/parser/state/multi_line_comment.ts +++ b/packages/promptl/src/parser/state/multi_line_comment.ts @@ -1,5 +1,5 @@ -import PARSER_ERRORS from '$compiler/error/errors' -import type { Comment } from '$compiler/parser/interfaces' +import PARSER_ERRORS from '$promptl/error/errors' +import type { Comment } from '$promptl/parser/interfaces' import { Parser } from '..' diff --git a/packages/promptl/src/parser/state/mustache.test.ts b/packages/promptl/src/parser/state/mustache.test.ts index e075b2cc6..1bb9f6992 100644 --- a/packages/promptl/src/parser/state/mustache.test.ts +++ b/packages/promptl/src/parser/state/mustache.test.ts @@ -1,50 +1,38 @@ -import CompileError from '$compiler/error/error' +import CompileError from '$promptl/error/error' +import { getExpectedError } from '$promptl/test/helpers' import { describe, expect, it } from 'vitest' import parse from '..' import { TemplateNode } from '../interfaces' -const getExpectedError = ( - action: () => void, - errorClass: new () => T, -): T => { - try { - action() - } catch (err) { - expect(err).toBeInstanceOf(errorClass) - return err as T - } - throw new Error('Expected an error to be thrown') -} - -describe('Mustache', () => { - it('parses content between the mustache tag delimiters as mustage nodes', () => { +describe('Mustache', async () => { + it('parses content between the mustache tag delimiters as mustage nodes', async () => { const prompt = '{{ test }}' const fragment = parse(prompt) expect(fragment.children.length).toBe(1) expect(fragment.children[0]!.type).toBe('MustacheTag') }) - it('throws an error if the mustache tag is not closed', () => { + it('throws an error if the mustache tag is not closed', async () => { const prompt = '{{ test' - const error = getExpectedError(() => parse(prompt), CompileError) + const error = await getExpectedError(() => parse(prompt), CompileError) expect(error.code).toBe('unexpected-eof') }) - it('returns an IfBlock node if the mustache tag contains an if statement', () => { + it('returns an IfBlock node if the mustache tag contains an if statement', async () => { const prompt = '{{ if test }} something {{ endif }}' const fragment = parse(prompt) expect(fragment.children.length).toBe(1) expect(fragment.children[0]!.type).toBe('IfBlock') }) - it('fails if an IfBlock has not been closed', () => { + it('fails if an IfBlock has not been closed', async () => { const prompt = '{{ if test }} something' - const error = getExpectedError(() => parse(prompt), CompileError) + const error = await getExpectedError(() => parse(prompt), CompileError) expect(error.code).toBe('unclosed-block') }) - it('returns the correct expression', () => { + it('returns the correct expression', async () => { const prompt = '{{ if test == 3 }} something {{ endif }}' const fragment = parse(prompt) const ifBlock = fragment.children[0] as TemplateNode @@ -59,7 +47,7 @@ describe('Mustache', () => { expect(right.value).toBe(3) }) - it('returns an ElseBlock node if the mustache tag contains an else statement', () => { + it('returns an ElseBlock node if the mustache tag contains an else statement', async () => { const prompt = '{{ if test }} something {{ else }} something else {{ endif }}' const fragment = parse(prompt) @@ -70,13 +58,13 @@ describe('Mustache', () => { expect(ifBlock.else!.type).toBe('ElseBlock') }) - it('fails if an else statement is not within an if block', () => { + it('fails if an else statement is not within an if block', async () => { const prompt = '{{ else }}' - const error = getExpectedError(() => parse(prompt), CompileError) + const error = await getExpectedError(() => parse(prompt), CompileError) expect(error.code).toBe('invalid-else-placement') }) - it('ElseBlock has a condition when followed by an if', () => { + it('ElseBlock has a condition when followed by an if', async () => { const prompt = '{{ if test }} something {{ else if test2 }} something else {{ endif }}' const fragment = parse(prompt) @@ -93,7 +81,7 @@ describe('Mustache', () => { expect(elseBlock.expression).toBeDefined() }) - it('returns a chain of else if as a family of IfBlocks', () => { + it('returns a chain of else if as a family of IfBlocks', async () => { const prompt = '{{ if a }} a {{ else if b }} b {{ else if c }} c {{ else if d }} d {{ endif }}' const fragment = parse(prompt) @@ -119,35 +107,35 @@ describe('Mustache', () => { expect(dBlock.else).toBeUndefined() }) - it('fails if there is another else statement after an else statement without an if', () => { + it('fails if there is another else statement after an else statement without an if', async () => { const prompt = '{{ if test }} a {{ else }} b {{ else }} c {{ endif }}' - const error = getExpectedError(() => parse(prompt), CompileError) + const error = await getExpectedError(() => parse(prompt), CompileError) expect(error.code).toBe('invalid-else-placement') }) - it('returns a ForBlock node if the mustache tag contains a for statement', () => { + it('returns a ForBlock node if the mustache tag contains a for statement', async () => { const prompt = '{{ for item in items }} something {{ endfor }}' const fragment = parse(prompt) expect(fragment.children.length).toBe(1) expect(fragment.children[0]!.type).toBe('ForBlock') }) - it('fails if a ForBlock has not been closed', () => { + it('fails if a ForBlock has not been closed', async () => { const prompt = '{{ for item in items }} something' - const error = getExpectedError(() => parse(prompt), CompileError) + const error = await getExpectedError(() => parse(prompt), CompileError) expect(error.code).toBe('unclosed-block') }) - it('fails when an IfBlock is closed with an endfor, and the other way around', () => { + it('fails when an IfBlock is closed with an endfor, and the other way around', async () => { const prompt1 = '{{ if test }} something {{ endfor }}' const prompt2 = '{{ for item in items }} something {{ endif }}' - const error1 = getExpectedError(() => parse(prompt1), CompileError) - const error2 = getExpectedError(() => parse(prompt2), CompileError) + const error1 = await getExpectedError(() => parse(prompt1), CompileError) + const error2 = await getExpectedError(() => parse(prompt2), CompileError) expect(error1.code).toBe('unexpected-block-close') expect(error2.code).toBe('unexpected-block-close') }) - it('returns the correct expression, index and context', () => { + it('returns the correct expression, index and context', async () => { const prompt = '{{ for item, i in list }} something {{ endfor }}' const fragment = parse(prompt) const forBlock = fragment.children[0] as TemplateNode @@ -166,7 +154,7 @@ describe('Mustache', () => { expect(forBlock.expression!.name).toBe('list') }) - it('returns an else block within a for block', () => { + it('returns an else block within a for block', async () => { const prompt = '{{ for item in list }} content {{ else }} empty {{ endfor }}' const fragment = parse(prompt) diff --git a/packages/promptl/src/parser/state/mustache.ts b/packages/promptl/src/parser/state/mustache.ts index 67912f629..837327cc3 100644 --- a/packages/promptl/src/parser/state/mustache.ts +++ b/packages/promptl/src/parser/state/mustache.ts @@ -1,14 +1,14 @@ -import { CUSTOM_TAG_END, CUSTOM_TAG_START, KEYWORDS } from '$compiler/constants' -import PARSER_ERRORS from '$compiler/error/errors' -import { type Parser } from '$compiler/parser' +import { CUSTOM_TAG_END, CUSTOM_TAG_START, KEYWORDS } from '$promptl/constants' +import PARSER_ERRORS from '$promptl/error/errors' +import { type Parser } from '$promptl/parser' import type { BaseNode, ElseBlock, ForBlock, IfBlock, -} from '$compiler/parser/interfaces' -import readContext from '$compiler/parser/read/context' -import readExpression from '$compiler/parser/read/expression' +} from '$promptl/parser/interfaces' +import readContext from '$promptl/parser/read/context' +import readExpression from '$promptl/parser/read/expression' export function mustache(parser: Parser) { if (parser.match(CUSTOM_TAG_END)) { diff --git a/packages/promptl/src/parser/state/tag.ts b/packages/promptl/src/parser/state/tag.ts index aab0e8049..6d379775e 100644 --- a/packages/promptl/src/parser/state/tag.ts +++ b/packages/promptl/src/parser/state/tag.ts @@ -1,15 +1,15 @@ -import { CUSTOM_TAG_END, CUSTOM_TAG_START } from '$compiler/constants' -import CompileError from '$compiler/error/error' -import PARSER_ERRORS from '$compiler/error/errors' -import { type Parser } from '$compiler/parser' +import { CUSTOM_TAG_END, CUSTOM_TAG_START } from '$promptl/constants' +import CompileError from '$promptl/error/error' +import PARSER_ERRORS from '$promptl/error/errors' +import { type Parser } from '$promptl/parser' import { Attribute, ElementTag, TemplateNode, Text, -} from '$compiler/parser/interfaces' -import read_expression from '$compiler/parser/read/expression' -import { decode_character_references } from '$compiler/parser/utils/html' +} from '$promptl/parser/interfaces' +import read_expression from '$promptl/parser/read/expression' +import { decode_character_references } from '$promptl/parser/utils/html' const validTagName = /^!?[a-zA-Z]{1,}:?[a-zA-Z0-9-]*/ diff --git a/packages/promptl/src/parser/state/text.ts b/packages/promptl/src/parser/state/text.ts index 7b6593e19..21e59644d 100644 --- a/packages/promptl/src/parser/state/text.ts +++ b/packages/promptl/src/parser/state/text.ts @@ -1,6 +1,6 @@ -import { CUSTOM_TAG_START } from '$compiler/constants' -import { type Parser } from '$compiler/parser' -import type { Text } from '$compiler/parser/interfaces' +import { CUSTOM_TAG_START } from '$promptl/constants' +import { type Parser } from '$promptl/parser' +import type { Text } from '$promptl/parser/interfaces' const ENDS_WITH_ESCAPE_REGEX = /(?( + action: () => unknown, + errorClass: new () => T, +): Promise { + try { + await action() + } catch (err) { + expect(err).toBeInstanceOf(errorClass) + return err as T + } + throw new Error('Expected an error to be thrown') +} diff --git a/packages/promptl/src/types/index.ts b/packages/promptl/src/types/index.ts index 5e139a442..0d3b4d052 100644 --- a/packages/promptl/src/types/index.ts +++ b/packages/promptl/src/types/index.ts @@ -1,4 +1,4 @@ -import CompileError from '$compiler/error/error' +import CompileError from '$promptl/error/error' import { Message } from './message' diff --git a/packages/promptl/src/types/message.ts b/packages/promptl/src/types/message.ts index 708db61e2..5cf806500 100644 --- a/packages/promptl/src/types/message.ts +++ b/packages/promptl/src/types/message.ts @@ -1,15 +1,16 @@ +/* Message Content */ + export enum ContentType { text = 'text', image = 'image', toolCall = 'tool-call', - toolResult = 'tool-result', } -export enum MessageRole { - system = 'system', - user = 'user', - assistant = 'assistant', - tool = 'tool', +export enum ContentTypeTagName { + // This is used to translate between the tag name and the actual tag value + text = 'content-text', + image = 'content-image', + toolCall = 'tool-call', } interface IMessageContent { @@ -27,31 +28,22 @@ export type ImageContent = IMessageContent & { image: string | Uint8Array | Buffer | ArrayBuffer | URL } -export type ToolContent = { - type: ContentType.toolResult - toolCallId: string - toolName: string - result: unknown - isError?: boolean -} - -export type ToolRequestContent = { +export type ToolCallContent = { type: ContentType.toolCall toolCallId: string toolName: string - args: Record + toolArguments: Record } -export type MessageContent = - | TextContent - | ImageContent - | ToolContent - | ToolRequestContent +export type MessageContent = TextContent | ImageContent | ToolCallContent + +/* Message */ -export type ToolCall = { - id: string - name: string - arguments: Record +export enum MessageRole { + system = 'system', + user = 'user', + assistant = 'assistant', + tool = 'tool', } interface IMessage { @@ -69,16 +61,13 @@ export type UserMessage = IMessage & { name?: string } -export type AssistantMessage = { +export type AssistantMessage = IMessage & { role: MessageRole.assistant - toolCalls: ToolCall[] - content: string | ToolRequestContent[] | MessageContent[] } -export type ToolMessage = { +export type ToolMessage = IMessage & { role: MessageRole.tool - content: ToolContent[] - [key: string]: unknown + toolId: string } export type Message = diff --git a/packages/promptl/tsconfig.json b/packages/promptl/tsconfig.json index 76ecb50d8..cf4317f1b 100644 --- a/packages/promptl/tsconfig.json +++ b/packages/promptl/tsconfig.json @@ -10,7 +10,7 @@ "outDir": "dist", "noEmit": false, "paths": { - "$compiler/*": ["./src/*"], + "$promptl/*": ["./src/*"], "acorn": ["node_modules/@latitude-data/typescript-config/types/acorn"] } }, diff --git a/packages/promptl/vitest.config.ts b/packages/promptl/vitest.config.ts index 10832c54b..765910eaa 100644 --- a/packages/promptl/vitest.config.ts +++ b/packages/promptl/vitest.config.ts @@ -9,7 +9,7 @@ const root = dirname(filename) export default defineConfig({ resolve: { alias: { - $compiler: `${root}/src`, + $promptl: `${root}/src`, }, }, test: {