Skip to content

Commit

Permalink
Compile iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
csansoon committed Jul 29, 2024
1 parent a4b833e commit 7ddcc9a
Show file tree
Hide file tree
Showing 15 changed files with 822 additions and 53 deletions.
9 changes: 6 additions & 3 deletions packages/compiler/src/compiler/base/nodes/config.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import { Config } from '$compiler/parser/interfaces'
import { Config as ConfigNode } from '$compiler/parser/interfaces'

import { CompileNodeContext } from '../types'

export async function compile(_: CompileNodeContext<Config>) {
/* do nothing */
export async function compile({
node,
setConfig,
}: CompileNodeContext<ConfigNode>): Promise<void> {
setConfig(node.value)
}
29 changes: 28 additions & 1 deletion packages/compiler/src/compiler/base/nodes/each.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@ import { hasContent, isIterable } from '$compiler/compiler/utils'
import errors from '$compiler/error/errors'
import { EachBlock } from '$compiler/parser/interfaces'

import { CompileNodeContext } from '../types'
import { CompileNodeContext, TemplateNodeWithStatus } from '../types'

type EachNodeWithStatus = TemplateNodeWithStatus & {
status: TemplateNodeWithStatus['status'] & {
loopIterationIndex: number
}
}

export async function compile({
node,
Expand All @@ -13,6 +19,12 @@ export async function compile({
resolveExpression,
expressionError,
}: CompileNodeContext<EachBlock>) {
const nodeWithStatus = node as EachNodeWithStatus
nodeWithStatus.status = {
...nodeWithStatus.status,
scopePointers: scope.getPointers(),
}

const iterableElement = await resolveExpression(node.expression, scope)
if (!isIterable(iterableElement) || !(await hasContent(iterableElement))) {
const childScope = scope.copy()
Expand Down Expand Up @@ -44,7 +56,14 @@ export async function compile({
}

let i = 0

for await (const element of iterableElement) {
if (i < (nodeWithStatus.status.loopIterationIndex ?? 0)) {
i++
continue
}
nodeWithStatus.status.loopIterationIndex = i

const localScope = scope.copy()
localScope.set(contextVarName, element)
if (indexVarName) {
Expand All @@ -54,14 +73,22 @@ export async function compile({
}
localScope.set(indexVarName, indexValue)
}

for await (const childNode of node.children ?? []) {
await resolveBaseNode({
node: childNode,
scope: localScope,
isInsideMessageTag,
isInsideContentTag,
completedValue: `step_${i}`,
})
}

i++
}

nodeWithStatus.status = {
...nodeWithStatus.status,
loopIterationIndex: 0,
}
}
13 changes: 13 additions & 0 deletions packages/compiler/src/compiler/base/nodes/tag.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import {
isChainStepTag,
isContentTag,
isMessageTag,
isRefTag,
isToolCallTag,
} from '$compiler/compiler/utils'
import errors from '$compiler/error/errors'
import {
ChainStepTag,
ContentTag,
ElementTag,
MessageTag,
Expand All @@ -14,6 +16,7 @@ import {
} from '$compiler/parser/interfaces'

import { CompileNodeContext } from '../types'
import { compile as resolveChainStep } from './tags/chainStep'
import { compile as resolveContent } from './tags/content'
import { compile as resolveMessage } from './tags/message'
import { compile as resolveRef } from './tags/ref'
Expand Down Expand Up @@ -91,8 +94,18 @@ export async function compile(props: CompileNodeContext<ElementTag>) {
return
}

if (isChainStepTag(node)) {
await resolveChainStep(
props as CompileNodeContext<ChainStepTag>,
attributes,
)
return
}

//@ts-ignore - Linter knows there *should* not be another type of tag.
baseNodeError(errors.unknownTag(node.name), node)

//@ts-ignore - ditto
for await (const childNode of node.children ?? []) {
await resolveBaseNode({
node: childNode,
Expand Down
53 changes: 53 additions & 0 deletions packages/compiler/src/compiler/base/nodes/tags/chainStep.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import { tagAttributeIsLiteral } from '$compiler/compiler/utils'
import errors from '$compiler/error/errors'
import { ChainStepTag } from '$compiler/parser/interfaces'
import { Config, ContentType } from '$compiler/types'

import { CompileNodeContext } from '../../types'

function isValidConfig(value: unknown): value is Config | undefined {
if (value === undefined) return true
if (Array.isArray(value)) return false
return typeof value === 'object'
}

export async function compile(
{
node,
scope,
popStepResponse,
addMessage,
groupContent,
baseNodeError,
stop,
}: CompileNodeContext<ChainStepTag>,
attributes: Record<string, unknown>,
) {
const stepResponse = popStepResponse()

const { as: varName, ...config } = attributes

if (stepResponse === undefined) {
if (!isValidConfig(config)) {
baseNodeError(errors.invalidStepConfig, node)
}

stop(config as Config)
}

if ('as' in attributes) {
if (!tagAttributeIsLiteral(node, 'as')) {
baseNodeError(errors.invalidStaticAttribute('as'), node)
}

const responseText = stepResponse?.content
.filter((c) => c.type === ContentType.text)
.map((c) => c.value)
.join(' ')

scope.set(String(varName), responseText)
}

groupContent()
addMessage(stepResponse!)
}
23 changes: 21 additions & 2 deletions packages/compiler/src/compiler/base/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import Scope from '$compiler/compiler/scope'
import Scope, { ScopePointers } from '$compiler/compiler/scope'
import { TemplateNode } from '$compiler/parser/interfaces'
import { Message, MessageContent } from '$compiler/types'
import {
AssistantMessage,
Config,
Message,
MessageContent,
} from '$compiler/types'
import type { Node as LogicalExpression } from 'estree'

import { ResolveBaseNodeProps, ToolCallReference } from '../types'
Expand All @@ -27,6 +32,16 @@ type RaiseErrorFn<T = void | never, N = TemplateNode | LogicalExpression> = (
node: N,
) => T

type NodeStatus = {
completedAs?: unknown
scopePointers?: ScopePointers | undefined
eachIterationIndex?: number
}

export type TemplateNodeWithStatus = TemplateNode & {
status?: NodeStatus
}

export type CompileNodeContext<N extends TemplateNode> = {
node: N
scope: Scope
Expand All @@ -41,6 +56,7 @@ export type CompileNodeContext<N extends TemplateNode> = {
isInsideMessageTag: boolean
isInsideContentTag: boolean

setConfig: (config: Config) => void
addMessage: (message: Message) => void
addStrayText: (text: string) => void
popStrayText: () => string
Expand All @@ -50,4 +66,7 @@ export type CompileNodeContext<N extends TemplateNode> = {
groupContent: () => void
addToolCall: (toolCallRef: ToolCallReference) => void
popToolCalls: () => ToolCallReference[]
popStepResponse: () => AssistantMessage | undefined

stop: (config?: Config) => void
}
Loading

0 comments on commit 7ddcc9a

Please sign in to comment.