Skip to content

Commit

Permalink
PromptL compiler v1 - Chapter 2: Content tags and tool calls (#565)
Browse files Browse the repository at this point in the history
  • Loading branch information
csansoon authored Nov 11, 2024
1 parent 92d2626 commit 73e5fa6
Show file tree
Hide file tree
Showing 64 changed files with 650 additions and 525 deletions.
4 changes: 2 additions & 2 deletions packages/promptl/rollup.config.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function isInternalCircularDependency(warning) {

const __dirname = path.dirname(url.fileURLToPath(import.meta.url))
const aliasEntries = {
entries: [{ find: '$compiler', replacement: path.resolve(__dirname, 'src') }],
entries: [{ find: '$promptl', replacement: path.resolve(__dirname, 'src') }],
}

/** @type {import('rollup').RollupOptions} */
Expand Down Expand Up @@ -54,7 +54,7 @@ export default [
'node:crypto',
'yaml',
'crypto',
'zod'
'zod',
],
},
{
Expand Down
2 changes: 1 addition & 1 deletion packages/promptl/src/compiler/base/nodes/comment.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Comment } from '$compiler/parser/interfaces'
import { Comment } from '$promptl/parser/interfaces'

import { CompileNodeContext } from '../types'

Expand Down
2 changes: 1 addition & 1 deletion packages/promptl/src/compiler/base/nodes/config.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Config as ConfigNode } from '$compiler/parser/interfaces'
import { Config as ConfigNode } from '$promptl/parser/interfaces'
import yaml from 'yaml'

import { CompileNodeContext } from '../types'
Expand Down
6 changes: 3 additions & 3 deletions packages/promptl/src/compiler/base/nodes/for.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { getExpectedError } from '$compiler/compiler/test/helpers'
import CompileError from '$compiler/error/error'
import { Message, MessageContent, TextContent } from '$compiler/types'
import CompileError from '$promptl/error/error'
import { getExpectedError } from '$promptl/test/helpers'
import { Message, MessageContent, TextContent } from '$promptl/types'
import { describe, expect, it } from 'vitest'

import { render } from '../..'
Expand Down
6 changes: 3 additions & 3 deletions packages/promptl/src/compiler/base/nodes/for.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { hasContent, isIterable } from '$compiler/compiler/utils'
import errors from '$compiler/error/errors'
import { ForBlock } from '$compiler/parser/interfaces'
import { hasContent, isIterable } from '$promptl/compiler/utils'
import errors from '$promptl/error/errors'
import { ForBlock } from '$promptl/parser/interfaces'

import { CompileNodeContext, TemplateNodeWithStatus } from '../types'

Expand Down
2 changes: 1 addition & 1 deletion packages/promptl/src/compiler/base/nodes/fragment.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Fragment } from '$compiler/parser/interfaces'
import { Fragment } from '$promptl/parser/interfaces'

import { CompileNodeContext } from '../types'

Expand Down
2 changes: 1 addition & 1 deletion packages/promptl/src/compiler/base/nodes/if.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { AssistantMessage, MessageRole, UserMessage } from '$compiler/types'
import { AssistantMessage, MessageRole, UserMessage } from '$promptl/types'
import { describe, expect, it, vi } from 'vitest'

import { render } from '../..'
Expand Down
2 changes: 1 addition & 1 deletion packages/promptl/src/compiler/base/nodes/if.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { IfBlock } from '$compiler/parser/interfaces'
import { IfBlock } from '$promptl/parser/interfaces'

import { CompileNodeContext } from '../types'

Expand Down
2 changes: 1 addition & 1 deletion packages/promptl/src/compiler/base/nodes/mustache.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { MustacheTag } from '$compiler/parser/interfaces'
import { MustacheTag } from '$promptl/parser/interfaces'

import { CompileNodeContext } from '../types'

Expand Down
14 changes: 3 additions & 11 deletions packages/promptl/src/compiler/base/nodes/tag.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,21 @@ import {
isContentTag,
isMessageTag,
isRefTag,
isToolCallTag,
} from '$compiler/compiler/utils'
import errors from '$compiler/error/errors'
} from '$promptl/compiler/utils'
import errors from '$promptl/error/errors'
import {
ChainStepTag,
ContentTag,
ElementTag,
MessageTag,
ReferenceTag,
ToolCallTag,
} from '$compiler/parser/interfaces'
} from '$promptl/parser/interfaces'

import { CompileNodeContext } from '../types'
import { compile as resolveChainStep } from './tags/chainStep'
import { compile as resolveContent } from './tags/content'
import { compile as resolveMessage } from './tags/message'
import { compile as resolveRef } from './tags/ref'
import { compile as resolveToolCall } from './tags/toolCall'

async function resolveTagAttributes({
node: tagNode,
Expand Down Expand Up @@ -81,11 +78,6 @@ export async function compile(props: CompileNodeContext<ElementTag>) {

const attributes = await resolveTagAttributes(props)

if (isToolCallTag(node)) {
await resolveToolCall(props as CompileNodeContext<ToolCallTag>, attributes)
return
}

if (isContentTag(node)) {
await resolveContent(props as CompileNodeContext<ContentTag>, attributes)
return
Expand Down
8 changes: 4 additions & 4 deletions packages/promptl/src/compiler/base/nodes/tags/chainStep.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { tagAttributeIsLiteral } from '$compiler/compiler/utils'
import errors from '$compiler/error/errors'
import { ChainStepTag } from '$compiler/parser/interfaces'
import { Config } from '$compiler/types'
import { tagAttributeIsLiteral } from '$promptl/compiler/utils'
import errors from '$promptl/error/errors'
import { ChainStepTag } from '$promptl/parser/interfaces'
import { Config } from '$promptl/types'

import { CompileNodeContext } from '../../types'

Expand Down
149 changes: 149 additions & 0 deletions packages/promptl/src/compiler/base/nodes/tags/content.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import { render } from '$promptl/compiler'
import { removeCommonIndent } from '$promptl/compiler/utils'
import CompileError from '$promptl/error/error'
import { getExpectedError } from '$promptl/test/helpers'
import {
ImageContent,
SystemMessage,
TextContent,
ToolCallContent,
UserMessage,
} from '$promptl/types'
import { describe, expect, it } from 'vitest'

describe('content tags', async () => {
it('adds stray text at root level as text contents inside a system message', async () => {
const prompt = 'Test message'
const result = await render({ prompt })
expect(result.messages.length).toBe(1)
const message = result.messages[0]! as SystemMessage
expect(message.role).toBe('system')
expect(message.content.length).toBe(1)
expect(message.content[0]!.type).toBe('text')
expect((message.content[0] as TextContent).text).toBe('Test message')
})

it('adds stray text inside a message as the message content', async () => {
const prompt = '<user>Test user message</user>'
const result = await render({ prompt })
expect(result.messages.length).toBe(1)
const message = result.messages[0]! as UserMessage
expect(message.role).toBe('user')
expect(message.content.length).toBe(1)
expect(message.content[0]!.type).toBe('text')
expect((message.content[0] as TextContent).text).toBe('Test user message')
})

it('Can add multiple text and image contents inside a message', async () => {
const prompt = removeCommonIndent(`
<user>
<content-text> Text 1 </content-text>
<content-image> Image 1 </content-image>
<content-text> Text 2 </content-text>
<content-text> Text 3 </content-text>
</user>
`)
const result = await render({ prompt })
expect(result.messages.length).toBe(1)
const message = result.messages[0]! as UserMessage
expect(message.role).toBe('user')
expect(message.content.length).toBe(4)
expect(message.content[0]!.type).toBe('text')
expect((message.content[0] as TextContent).text).toBe('Text 1')
expect(message.content[1]!.type).toBe('image')
expect((message.content[1] as ImageContent).image).toBe('Image 1')
expect(message.content[2]!.type).toBe('text')
expect((message.content[2] as TextContent).text).toBe('Text 2')
expect(message.content[3]!.type).toBe('text')
expect((message.content[3] as TextContent).text).toBe('Text 3')
})

it('Can interpolate stray text and defined content tags', async () => {
const prompt = removeCommonIndent(`
Text 1
<content-text> Text 2 </content-text>
Text 3
`)
const result = await render({ prompt })

expect(result.messages.length).toBe(1)
const message = result.messages[0]!
expect(message.content.length).toBe(3)
expect(message.content[0]!.type).toBe('text')
expect((message.content[0] as TextContent).text).toBe('Text 1')
expect(message.content[1]!.type).toBe('text')
expect((message.content[1] as TextContent).text).toBe('Text 2')
expect(message.content[2]!.type).toBe('text')
expect((message.content[2] as TextContent).text).toBe('Text 3')
})

it('Can use the general "content" tag with a type attribute', async () => {
const prompt = removeCommonIndent(`
<content type="text"> Text </content>
<content type="image"> Image </content>
`)
const result = await render({ prompt })

expect(result.messages.length).toBe(1)
const message = result.messages[0]!
expect(message.content.length).toBe(2)
expect(message.content[0]!.type).toBe('text')
expect((message.content[0] as TextContent).text).toBe('Text')
expect(message.content[1]!.type).toBe('image')
expect((message.content[1] as ImageContent).image).toBe('Image')
})

it('Cannot include a content tag inside another content tag', async () => {
const prompt = removeCommonIndent(`
<content type="text">
<content type="text"> Text </content>
</content>
`)
const error = await getExpectedError(() => render({ prompt }), CompileError)
expect(error.code).toBe('content-tag-inside-content')
})
})

describe('tool-call tags', async () => {
it('returns tool calls in the content of assistant messages', async () => {
const prompt = removeCommonIndent(`
<assistant>
<tool-call name="get_weather" id="123" />
</assistant>
`)
const result = await render({ prompt })

expect(result.messages.length).toBe(1)
const message = result.messages[0]! as SystemMessage
expect(message.role).toBe('assistant')
expect(message.content.length).toBe(1)
expect(message.content[0]!.type).toBe('tool-call')
const toolCall = message.content[0]! as ToolCallContent
expect(toolCall.toolName).toBe('get_weather')
expect(toolCall.toolCallId).toBe('123')
})

it('fails when not in an assistant message tag', async () => {
const prompt = removeCommonIndent(`
<user>
<tool-call name="get_weather" id="123" />
</user>
`)

const error = await getExpectedError(() => render({ prompt }), CompileError)
expect(error.code).toBe('invalid-tool-call-placement')
})

it('fails when a tool call is inside another tool call', async () => {
const prompt = removeCommonIndent(`
<assistant>
<tool-call name="get_weather" id="123">
<tool-call name="get_weather" id="456" />
</tool-call>
</assistant>
`)

const error = await getExpectedError(() => render({ prompt }), CompileError)
expect(error.code).toBe('content-tag-inside-content')
})
})
77 changes: 63 additions & 14 deletions packages/promptl/src/compiler/base/nodes/tags/content.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import { removeCommonIndent } from '$compiler/compiler/utils'
import errors from '$compiler/error/errors'
import { ContentTag } from '$compiler/parser/interfaces'
import { ContentType } from '$compiler/types'
import { removeCommonIndent } from '$promptl/compiler/utils'
import {
CUSTOM_CONTENT_TAG,
CUSTOM_CONTENT_TYPE_ATTR,
} from '$promptl/constants'
import errors from '$promptl/error/errors'
import { ContentTag } from '$promptl/parser/interfaces'
import { ContentType, ContentTypeTagName } from '$promptl/types'

import { CompileNodeContext } from '../../types'

Expand Down Expand Up @@ -32,19 +36,64 @@ export async function compile(
}
const textContent = removeCommonIndent(popStrayText())

// TODO: This if else is probably not required but the types enforce it.
// Improve types.
if (node.name === 'text') {
let type: ContentType
if (node.name === CUSTOM_CONTENT_TAG) {
if (attributes[CUSTOM_CONTENT_TYPE_ATTR] === undefined) {
baseNodeError(errors.messageTagWithoutRole, node)
}
type = attributes[CUSTOM_CONTENT_TYPE_ATTR] as ContentType
delete attributes[CUSTOM_CONTENT_TYPE_ATTR]
} else {
const contentTypeKeysFromTagName = Object.fromEntries(
Object.entries(ContentTypeTagName).map(([k, v]) => [v, k]),
)
type =
ContentType[
contentTypeKeysFromTagName[node.name] as keyof typeof ContentType
]
}

if (type === ContentType.text) {
addContent({
...attributes,
type: ContentType.text,
text: textContent,
node,
content: {
...attributes,
type: ContentType.text,
text: textContent,
},
})
} else {
return
}

if (type === ContentType.image) {
addContent({
...attributes,
type: ContentType.image,
image: textContent,
node,
content: {
...attributes,
type: ContentType.image,
image: textContent,
},
})
return
}

if (type == ContentType.toolCall) {
const { id, name, ...rest } = attributes
if (!id) baseNodeError(errors.toolCallTagWithoutId, node)
if (!name) baseNodeError(errors.toolCallWithoutName, node)

addContent({
node,
content: {
...rest,
type: ContentType.toolCall,
toolCallId: String(id),
toolName: String(name),
toolArguments: {}, // TODO: Issue for a future PR
},
})
return
}

baseNodeError(errors.invalidContentType(type), node)
}
Loading

0 comments on commit 73e5fa6

Please sign in to comment.