diff --git a/packages/ai/src/LLMStepInstance.ts b/packages/ai/src/LLMStepInstance.ts index e72d2af3de..c638460f5f 100644 --- a/packages/ai/src/LLMStepInstance.ts +++ b/packages/ai/src/LLMStepInstance.ts @@ -1,4 +1,4 @@ -import { BaseArtifact, makeArtifact } from "./Artifact.js"; +import { ArtifactKind, BaseArtifact, makeArtifact } from "./Artifact.js"; import { calculatePriceMultipleCalls, LlmMetrics, @@ -10,6 +10,7 @@ import { Inputs, IOShape, LLMStepTemplate, + OutputKind, Outputs, StepState, } from "./LLMStepTemplate.js"; @@ -27,14 +28,22 @@ export type StepParams = { id: string; sequentialId: number; template: LLMStepTemplate; - state: StepState; + state: StepState; inputs: Inputs; - outputs: Partial>; startTime: number; conversationMessages: Message[]; llmMetricsList: LlmMetrics[]; }; +class FailError extends Error { + constructor( + public type: ErrorType, + message: string + ) { + super(message); + } +} + export class LLMStepInstance< const Shape extends IOShape = IOShape, const WorkflowShape extends IOShape = IOShape, @@ -43,10 +52,8 @@ export class LLMStepInstance< public sequentialId: StepParams["sequentialId"]; public readonly template: StepParams["template"]; - private state: StepParams["state"]; - // must be public for `instanceOf` type guard to work - public readonly _outputs: StepParams["outputs"]; + private _state: StepParams["state"]; public readonly inputs: StepParams["inputs"]; @@ -56,7 +63,7 @@ export class LLMStepInstance< // These two fields are not serialized private logger: Logger; - private workflow: Workflow; + public workflow: Workflow; private constructor( params: StepParams & { @@ -70,8 +77,7 @@ export class LLMStepInstance< this.conversationMessages = params.conversationMessages; this.startTime = params.startTime; - this.state = params.state; - this._outputs = params.outputs; + this._state = params.state; this.template = params.template; this.inputs = params.inputs; @@ -93,7 +99,6 @@ export class LLMStepInstance< llmMetricsList: [], startTime: Date.now(), state: { kind: "PENDING" }, - outputs: {}, ...params, }); } @@ -113,7 +118,7 @@ export class LLMStepInstance< } async _run() { - if (this.state.kind !== "PENDING") { + if (this._state.kind !== "PENDING") { return; } @@ -123,27 +128,51 @@ export class LLMStepInstance< return; } - const executeContext: ExecuteContext = { - setOutput: (key, value) => this.setOutput(key, value), + const executeContext: ExecuteContext = { log: (log) => this.log(log), queryLLM: (promptPair) => this.queryLLM(promptPair), - fail: (errorType, message) => this.fail(errorType, message), + fail: (errorType, message) => { + throw new FailError(errorType, message); + }, }; try { - await this.template.execute(executeContext, this.inputs); - } catch (error) { - this.fail( - "MINOR", - error instanceof Error ? error.message : String(error) - ); - return; - } - - const hasFailed = (this.state as StepState).kind === "FAILED"; + const result = await this.template.execute(executeContext, this.inputs); + + const outputs = {} as Outputs; // will be filled in below + + // This code is unfortunately full of `as any`, caused by `Object.keys` limitations in TypeScript. + Object.keys(result).forEach((key: keyof typeof result) => { + const value = result[key]; + if (value instanceof BaseArtifact) { + // already existing artifact - probably passed through from another step + outputs[key] = value as any; + } else if (!value) { + outputs[key] = undefined as any; + } else { + const outputKind = this.template.shape.outputs[key] as OutputKind; + const artifactKind = outputKind.endsWith("?") + ? (outputKind.slice(0, -1) as ArtifactKind) + : (outputKind as ArtifactKind); + + outputs[key] = makeArtifact(artifactKind, value as any, this) as any; + } + }); - if (!hasFailed) { - this.state = { kind: "DONE", durationMs: this.calculateDuration() }; + this._state = { + kind: "DONE", + durationMs: this.calculateDuration(), + outputs, + }; + } catch (error) { + if (error instanceof FailError) { + this.fail(error.type, error.message); + } else { + this.fail( + "MINOR", // TODO - critical? + error instanceof Error ? error.message : String(error) + ); + } } } @@ -155,8 +184,8 @@ export class LLMStepInstance< await this._run(); - const completionMessage = `Step "${this.template.name}" completed with status: ${this.state.kind}${ - this.state.kind !== "PENDING" && `, in ${this.state.durationMs / 1000}s` + const completionMessage = `Step "${this.template.name}" completed with status: ${this._state.kind}${ + this._state.kind !== "PENDING" && `, in ${this._state.durationMs / 1000}s` }`; this.log({ @@ -166,15 +195,11 @@ export class LLMStepInstance< } getState() { - return this.state; + return this._state; } getDuration() { - return this.state.kind === "PENDING" ? 0 : this.state.durationMs; - } - - getOutputs(): this["_outputs"] { - return this._outputs; + return this._state.kind === "PENDING" ? 0 : this._state.durationMs; } getInputs() { @@ -208,34 +233,21 @@ export class LLMStepInstance< } isDone() { - return this.state.kind === "DONE"; + return this._state.kind === "DONE"; } - // private methods - - private setOutput>( - key: K, - value: Outputs[K] | Outputs[K]["value"] - ): void { - if (key in this._outputs) { - this.fail( - "CRITICAL", - `Output ${key} is already set. This is a bug with the workflow code.` - ); - return; - } - - if (value instanceof BaseArtifact) { - // already existing artifact - probably passed through from another step - this._outputs[key] = value; - } else { - const kind = this.template.shape.outputs[ - key - ] as Outputs[K]["kind"]; - this._outputs[key] = makeArtifact(kind, value as any, this) as any; + getPreviousStep() { + const steps = this.workflow.getSteps(); + const index = steps.indexOf(this as LLMStepInstance); + if (index < 1) { + // first step or not found + return undefined; } + return steps[index - 1]; } + // private methods + private log(log: LogEntry): void { this.logger.log(log, { workflowId: this.workflow.id, @@ -245,7 +257,7 @@ export class LLMStepInstance< private fail(errorType: ErrorType, message: string) { this.log({ type: "error", message }); - this.state = { + this._state = { kind: "FAILED", durationMs: this.calculateDuration(), errorType, @@ -320,14 +332,13 @@ export class LLMStepInstance< // Serialization/deserialization // StepParams don't contain the workflow reference, to to avoid circular dependencies - toParams(): StepParams { + toParams(): StepParams { return { id: this.id, sequentialId: this.sequentialId, template: this.template, - state: this.state, + state: this._state, inputs: this.inputs, - outputs: this._outputs, startTime: this.startTime, conversationMessages: this.conversationMessages, llmMetricsList: this.llmMetricsList, @@ -342,7 +353,13 @@ export class LLMStepInstance< } static deserialize( - { templateName, inputIds, outputIds, ...params }: SerializedStep, + { + templateName, + inputIds, + outputIds: legacyOutputIds, + state: serializedState, + ...params + }: SerializedStep, visitor: AiDeserializationVisitor ): StepParams { const template: LLMStepTemplate = getStepTemplateByName(templateName); @@ -352,18 +369,28 @@ export class LLMStepInstance< visitor.artifact(inputId), ]) ); - const outputs = Object.fromEntries( - Object.entries(outputIds).map(([name, outputId]) => [ - name, - visitor.artifact(outputId), - ]) - ); + + let state: StepState; + + if (serializedState.kind === "DONE") { + const { outputIds, ...serializedStateWithoutOutputs } = serializedState; + state = { + ...serializedStateWithoutOutputs, + outputs: Object.fromEntries( + Object.entries(legacyOutputIds ?? outputIds).map( + ([name, outputId]) => [name, visitor.artifact(outputId)] + ) + ), + }; + } else { + state = serializedState; + } return { ...params, template, inputs, - outputs, + state, }; } } @@ -371,12 +398,24 @@ export class LLMStepInstance< export function serializeStepParams( params: StepParams, visitor: AiSerializationVisitor -) { +): SerializedStep { return { id: params.id, sequentialId: params.sequentialId, templateName: params.template.name, - state: params.state, + state: + params.state.kind === "DONE" + ? { + ...params.state, + outputIds: Object.fromEntries( + Object.entries(params.state.outputs) + .map(([key, output]) => + output ? [key, visitor.artifact(output)] : undefined + ) + .filter((x) => x !== undefined) + ), + } + : params.state, startTime: params.startTime, conversationMessages: params.conversationMessages, llmMetricsList: params.llmMetricsList, @@ -386,21 +425,23 @@ export function serializeStepParams( visitor.artifact(input), ]) ), - outputIds: Object.fromEntries( - Object.entries(params.outputs) - .map(([key, output]) => - output ? [key, visitor.artifact(output)] : undefined - ) - .filter((x) => x !== undefined) - ), }; } +type SerializedState = + | (Omit, { kind: "DONE" }>, "outputs"> & { + outputIds: Record; + }) + | Exclude, { kind: "DONE" }>; + export type SerializedStep = Omit< StepParams, - "inputs" | "outputs" | "template" + "inputs" | "outputs" | "template" | "state" > & { templateName: string; inputIds: Record; - outputIds: Record; + // Legacy - in the modern format we store outputs on SerializedState, but we still support this field for old workflows. + // Can be removed if we deserialize and serialize again all workflows in the existing database. + outputIds?: Record; + state: SerializedState; }; diff --git a/packages/ai/src/LLMStepTemplate.ts b/packages/ai/src/LLMStepTemplate.ts index f34105160c..ab8a2e4d76 100644 --- a/packages/ai/src/LLMStepTemplate.ts +++ b/packages/ai/src/LLMStepTemplate.ts @@ -4,13 +4,14 @@ import { PromptPair } from "./prompts.js"; export type ErrorType = "CRITICAL" | "MINOR"; -export type StepState = +export type StepState = | { kind: "PENDING"; } | { kind: "DONE"; durationMs: number; + outputs: Outputs; } | { kind: "FAILED"; @@ -19,9 +20,11 @@ export type StepState = message: string; }; +export type OutputKind = ArtifactKind | `${ArtifactKind}?`; + export type IOShape< I extends Record = Record, - O extends Record = Record, + O extends Record = Record, > = { inputs: I; outputs: O; @@ -31,32 +34,93 @@ export type Inputs> = { [K in keyof Shape["inputs"]]: Extract; }; +// Possible output value based on output kind. +// For example, if the step specifies `code: "code"`, then the value can be +// a code artifact or a string (the code itself). +// If the step specifies `code: "code?"`, then the value can be any of those or it can be undefined. +type AllowedOutputValue = Kind extends ArtifactKind + ? Extract + : Kind extends `${infer ArtifactKind}?` + ? Extract | undefined + : never; + export type Outputs> = { - [K in keyof Shape["outputs"]]: Extract< - Artifact, - { kind: Shape["outputs"][K] } - >; + [K in keyof Shape["outputs"]]: AllowedOutputValue; +}; + +// What the step implementation returns. +type StepExecuteResult = { + [K in keyof Shape["outputs"]]: // for any output field... + | AllowedOutputValue // ...return either the artifact object... + | NonNullable>["value"]; // or its value, as a shorthand }; // ExecuteContext is the context that's available to the step implementation. // We intentionally don't pass the reference to the step implementation, so that steps won't mess with their internal state. -export type ExecuteContext = { - setOutput>( - key: K, - value: Outputs[K] | Outputs[K]["value"] // can be either the artifact or the value inside the artifact - ): void; +export type ExecuteContext = { queryLLM(promptPair: PromptPair): Promise; log(log: LogEntry): void; - fail(errorType: ErrorType, message: string): void; + fail(errorType: ErrorType, message: string): never; +}; + +/* + * This is a type that's used to prepare the step for execution. `PreparedStep` + * will be converted by `Workflow.addStep()` to `LLMStepInstance`. + * + * Notes: + * 1) We can't just use LLMStepInstance here, this would cause nasty circular + * dependencies (sometimes even Node.js crashes! seriously). + * 2) We can't use a pair of template and inputs, because their generic + * parameters must be synchronized, and TypeScript can't express that. So we + * need to produce a PreparedStep from a template method. + * + * So, in total, we have three related types: + * - `LLMStepTemplate` + * - `LLMStepInstance` + * - `PreparedStep` + * + * (And also `StepParams` and `SerializedStep` in `LLMStepInstance`...) + * + * If this sounds messy, it is. I really tried to find other approaches, but + * fixing this would require giving up on `Shape` generic parameters, which has + * its own issues. + */ +export type PreparedStep = { + template: LLMStepTemplate; + inputs: Inputs; }; export class LLMStepTemplate { constructor( public readonly name: string, public readonly shape: Shape, + /** + * This function is the main API for implementing steps. + * + * It takes the context and inputs, and returns the outputs. + * + * The returned value must match the template's shape. For example, if the + * shape has `{ outputs: { result: 'code' }}`, then the result must be an + * object with a `result` property that's either a code artifact, or a value + * of such artifact. + * + * It's possible to use optional outputs, but only if the shape has `?` in + * the output kind. In this case, you must return `{ result: undefined }`. + * + * See the individual step implementations in `./steps/*.ts` for examples. + * + * If the step has failed, you can call `context.fail()` to stop execution. + */ public readonly execute: ( - context: ExecuteContext, + context: ExecuteContext, inputs: Inputs - ) => Promise + ) => Promise> ) {} + + prepare(inputs: Inputs): PreparedStep { + return { + template: this, + inputs, + }; + } } diff --git a/packages/ai/src/generateSummary.ts b/packages/ai/src/generateSummary.ts index f890a4592b..99cd91e3e6 100644 --- a/packages/ai/src/generateSummary.ts +++ b/packages/ai/src/generateSummary.ts @@ -1,6 +1,3 @@ -import fs from "fs"; -import path from "path"; - import { Artifact, ArtifactKind } from "./Artifact.js"; import { Code } from "./Code.js"; import { calculatePriceMultipleCalls } from "./LLMClient.js"; @@ -99,10 +96,15 @@ function generateDetailedStepLogs( detailedLogs += getFullArtifact(key, artifact); } - detailedLogs += "### Outputs:\n"; - for (const [key, artifact] of Object.entries(step.getOutputs())) { - if (!artifact) continue; - detailedLogs += getFullArtifact(key, artifact); + { + const state = step.getState(); + if (state.kind === "DONE") { + detailedLogs += "### Outputs:\n"; + for (const [key, artifact] of Object.entries(state.outputs)) { + if (!artifact) continue; + detailedLogs += getFullArtifact(key, artifact); + } + } } detailedLogs += "### Logs:\n"; @@ -227,16 +229,3 @@ ${code.source} `; } } - -export function saveSummaryToFile(summary: string): void { - const logDir = path.join(process.cwd(), "logs"); - if (!fs.existsSync(logDir)) { - fs.mkdirSync(logDir, { recursive: true }); - } - - const timestamp = new Date().toISOString().replace(/:/g, "-"); - const logFile = path.join(logDir, `squiggle_summary_${timestamp}.md`); - - fs.writeFileSync(logFile, summary); - console.log(`Summary saved to ${logFile}`); -} diff --git a/packages/ai/src/steps/adjustToFeedbackStep.ts b/packages/ai/src/steps/adjustToFeedbackStep.ts index 7b778b5a6b..0c53a170d6 100644 --- a/packages/ai/src/steps/adjustToFeedbackStep.ts +++ b/packages/ai/src/steps/adjustToFeedbackStep.ts @@ -59,7 +59,7 @@ export const adjustToFeedbackStep = new LLMStepTemplate( code: "code", }, outputs: { - code: "code", + code: "code?", }, }, async (context, { prompt, code }) => { @@ -73,7 +73,7 @@ export const adjustToFeedbackStep = new LLMStepTemplate( if (!completion) { // failed - return; + return context.fail("CRITICAL", "LLM failed to provide a response"); } // handle adjustment response @@ -87,7 +87,7 @@ export const adjustToFeedbackStep = new LLMStepTemplate( type: "info", message: "LLM determined no adjustment is needed", }); - return; + return { code: undefined }; } if ( @@ -101,19 +101,17 @@ export const adjustToFeedbackStep = new LLMStepTemplate( message: "FAIL: " + diffResponse.value, }); // try again - context.setOutput("code", code); - return; + return { code }; } const adjustedCode = await codeStringToCode(diffResponse.value); - context.setOutput("code", adjustedCode); - return; + return { code: adjustedCode }; } else { context.log({ type: "info", message: "No adjustments provided, considering process complete", }); - return; + return { code: undefined }; } } ); diff --git a/packages/ai/src/steps/fixCodeUntilItRunsStep.ts b/packages/ai/src/steps/fixCodeUntilItRunsStep.ts index 0689b504bc..b70705afe6 100644 --- a/packages/ai/src/steps/fixCodeUntilItRunsStep.ts +++ b/packages/ai/src/steps/fixCodeUntilItRunsStep.ts @@ -86,16 +86,19 @@ export const fixCodeUntilItRunsStep = new LLMStepTemplate( const promptPair = editExistingSquiggleCodePrompt(code.value); const completion = await context.queryLLM(promptPair); - if (completion) { - const newCodeResult = await diffCompletionContentToCode( - completion, - code.value - ); - if (newCodeResult.ok) { - context.setOutput("code", newCodeResult.value); - } else { - context.fail("MINOR", newCodeResult.value); - } + if (!completion) { + // failed + return context.fail("CRITICAL", "LLM failed to provide a response"); + } + + const newCodeResult = await diffCompletionContentToCode( + completion, + code.value + ); + if (newCodeResult.ok) { + return { code: newCodeResult.value }; + } else { + return context.fail("MINOR", newCodeResult.value); } } ); diff --git a/packages/ai/src/steps/generateCodeStep.ts b/packages/ai/src/steps/generateCodeStep.ts index fde584d932..d57d1e2f63 100644 --- a/packages/ai/src/steps/generateCodeStep.ts +++ b/packages/ai/src/steps/generateCodeStep.ts @@ -138,16 +138,15 @@ export const generateCodeStep = new LLMStepTemplate( const promptPair = generateNewSquiggleCodePrompt(prompt.value); const completion = await context.queryLLM(promptPair); - if (completion) { - const state = await generationCompletionContentToCode(completion); - if (state.ok) { - context.setOutput("code", state.value); - } else { - context.log({ - type: "error", - message: state.value, - }); - } + if (!completion) { + return context.fail("MINOR", "No completion"); + } + + const state = await generationCompletionContentToCode(completion); + if (state.ok) { + return { code: state.value }; + } else { + return context.fail("MINOR", state.value); } } ); diff --git a/packages/ai/src/steps/matchStyleGuideStep.ts b/packages/ai/src/steps/matchStyleGuideStep.ts index 0be376e6c8..d6f515c840 100644 --- a/packages/ai/src/steps/matchStyleGuideStep.ts +++ b/packages/ai/src/steps/matchStyleGuideStep.ts @@ -67,7 +67,7 @@ export const matchStyleGuideStep = new LLMStepTemplate( code: "code", }, outputs: { - code: "code", + code: "code?", }, }, async (context, { prompt, code }) => { @@ -81,7 +81,7 @@ export const matchStyleGuideStep = new LLMStepTemplate( if (!completion) { // failed - return; + return context.fail("MINOR", "No completion"); } // handle adjustment response @@ -95,7 +95,7 @@ export const matchStyleGuideStep = new LLMStepTemplate( type: "info", message: "LLM determined no adjustment is needed", }); - return; + return { code: undefined }; } if ( @@ -109,19 +109,17 @@ export const matchStyleGuideStep = new LLMStepTemplate( message: "FAIL: " + diffResponse.value, }); // try again - context.setOutput("code", code); - return; + return { code }; } const adjustedCode = await codeStringToCode(diffResponse.value); - context.setOutput("code", adjustedCode); - return; + return { code: adjustedCode }; } else { context.log({ type: "info", message: "No adjustments provided, considering process complete", }); - return; + return { code: undefined }; } } ); diff --git a/packages/ai/src/steps/runAndFormatCodeStep.ts b/packages/ai/src/steps/runAndFormatCodeStep.ts index deae894a65..0c5f3ec1fd 100644 --- a/packages/ai/src/steps/runAndFormatCodeStep.ts +++ b/packages/ai/src/steps/runAndFormatCodeStep.ts @@ -7,8 +7,8 @@ export const runAndFormatCodeStep = new LLMStepTemplate( inputs: { source: "source" }, outputs: { code: "code" }, }, - async (context, { source }) => { + async (_, { source }) => { const code = await codeStringToCode(source.value); - context.setOutput("code", code); + return { code }; } ); diff --git a/packages/ai/src/workflows/Workflow.ts b/packages/ai/src/workflows/Workflow.ts index fa9da62d36..db30ac33d3 100644 --- a/packages/ai/src/workflows/Workflow.ts +++ b/packages/ai/src/workflows/Workflow.ts @@ -8,7 +8,7 @@ import { Message, } from "../LLMClient.js"; import { LLMStepInstance } from "../LLMStepInstance.js"; -import { Inputs, IOShape, LLMStepTemplate } from "../LLMStepTemplate.js"; +import { Inputs, IOShape, PreparedStep } from "../LLMStepTemplate.js"; import { TimestampedLogEntry } from "../Logger.js"; import { LlmId } from "../modelConfigs.js"; import { @@ -96,6 +96,11 @@ export type WorkflowEventListener< Shape extends IOShape, > = (event: WorkflowEvent) => void; +export type StepTransitionRule = ( + step: LLMStepInstance, + helpers: WorkflowGuardHelpers +) => NextStepAction; + /** * This class is responsible for managing the steps in a workflow. * @@ -107,12 +112,12 @@ export class Workflow { public id: string; public readonly template: WorkflowTemplate; public readonly inputs: Inputs; - private started: boolean = false; public llmConfig: LlmConfig; public startTime: number; private steps: LLMStepInstance[]; + private transitionRule: StepTransitionRule; public llmClient: LLMClient; @@ -130,19 +135,20 @@ export class Workflow { params.openaiApiKey, params.anthropicApiKey ); + + this.transitionRule = params.template.getTransitionRule(this); } private startOrThrow() { - if (this.started) { + // This function just inserts the first step. + // Previously we had a `started` flag, but that wasn't very useful. + // But if we ever implement resumable workflows, we'll need to change this. + if (this.steps.length) { throw new Error("Workflow already started"); } - this.started = true; - } - - private configure() { - // we configure the controller loop first, so it has a chance to react to its initial step - this.template.configureControllerLoop(this, this.inputs); - this.template.configureInitialSteps(this, this.inputs); + // add the first step + const initialStep = this.template.getInitialStep(this); + this.addStep(initialStep); } // Run workflow to the ReadableStream, appropriate for streaming in Next.js routes @@ -157,10 +163,6 @@ export class Workflow { // handlers, but before we add any steps. this.dispatchEvent({ type: "workflowStarted" }); - // Important! `configure` should be called after all event listeners are - // set up. We want to capture `stepAdded` events. - this.configure(); - await this.runUntilComplete(); controller.close(); }, @@ -172,32 +174,20 @@ export class Workflow { // Run workflow without streaming, only capture the final result async runToResult(): Promise { this.startOrThrow(); - this.configure(); await this.runUntilComplete(); - // saveSummaryToFile(generateSummary(workflow)); return this.getFinalResult(); } - getPreviousStep( - step: LLMStepInstance - ): LLMStepInstance | undefined { - const index = this.steps.indexOf(step); - if (index < 1) { - return undefined; - } - return this.steps[index - 1]; - } - - addStep( - template: LLMStepTemplate, - inputs: Inputs + private addStep( + prepatedStep: PreparedStep ): LLMStepInstance { - // sorry for "any"; contravariance issues + // `any` is necessary because of countervariance issues. + // But that's not important because `PreparedStep` was already strictly typed. const step: LLMStepInstance = LLMStepInstance.create({ - template, - inputs, + template: prepatedStep.template, + inputs: prepatedStep.inputs, workflow: this, }); @@ -209,34 +199,10 @@ export class Workflow { return step; } - addLinearRule( - produce: ( - step: LLMStepInstance, - helpers: WorkflowGuardHelpers - ) => NextStepAction - ) { - this.addEventListener("stepFinished", ({ data: { step } }) => { - const result = produce(step, new WorkflowGuardHelpers(this, step)); - switch (result.kind) { - case "repeat": - this.addStep(step.template, step.inputs); - break; - case "step": - this.addStep(result.step, result.inputs); - break; - case "finish": - // no new steps to add - break; - case "fatal": - throw new Error(result.message); - } - }); - } - private async runNextStep(): Promise { const step = this.getCurrentStep(); - if (!step) { + if (!step || step.getState().kind !== "PENDING") { return; } @@ -251,9 +217,31 @@ export class Workflow { type: "stepFinished", payload: { step }, }); + + // apply the transition rule, produce the next step in PENDING state + // this code is inlined in this method, which guarantees that we always have one pending step + // (until the transition rule decides to finish) + const result = this.transitionRule(step, new WorkflowGuardHelpers(step)); + switch (result.kind) { + case "repeat": + this.addStep(step.template.prepare(step.inputs)); + break; + case "step": + this.addStep(result.step.prepare(result.inputs)); + break; + case "finish": + // no new steps to add + break; + case "fatal": + throw new Error(result.message); + } } async runUntilComplete() { + if (!this.steps.length) { + throw new Error("Workflow not started"); + } + while (!this.isProcessComplete()) { await this.runNextStep(); } @@ -306,18 +294,19 @@ export class Workflow { // Single pass through steps from most recent to oldest for (let i = this.steps.length - 1; i >= 0; i--) { const step = this.steps[i]; - const outputs = step.getOutputs(); + const stepState = step.getState(); + if (stepState.kind !== "DONE") { + continue; + } - for (const output of Object.values(outputs)) { + for (const output of Object.values(stepState.outputs)) { if (output?.kind === "code") { // If we find successful code, return immediately if (output.value.type === "success") { return { step, code: output.value.source }; } - // Otherwise store the first step with any code - if (!stepWithAnyCode) { - stepWithAnyCode = { step, code: output.value.source }; - } + // Store the first step with any code + stepWithAnyCode ??= { step, code: output.value.source }; } } } diff --git a/packages/ai/src/workflows/WorkflowGuardHelpers.ts b/packages/ai/src/workflows/WorkflowGuardHelpers.ts index 32e82c36ee..18981cc9db 100644 --- a/packages/ai/src/workflows/WorkflowGuardHelpers.ts +++ b/packages/ai/src/workflows/WorkflowGuardHelpers.ts @@ -31,10 +31,10 @@ addRule({ Where totalRepeats counts the number of template runs in the entire history, while receptRepeats counts the number of immediate repeats in the step -> parent -> parent chain. (I'm stubborn and still trying to make it parallelism-agnostic, so thinking in terms of graphs, not lists). */ export class WorkflowGuardHelpers { - constructor( - private readonly workflow: Workflow, - private readonly _step: LLMStepInstance - ) {} + private workflow: Workflow; + constructor(private readonly _step: LLMStepInstance) { + this.workflow = _step.workflow; + } totalRepeats(template: LLMStepTemplate) { return this.workflow.getSteps().filter((step) => step.instanceOf(template)) @@ -43,20 +43,20 @@ export class WorkflowGuardHelpers { recentRepeats(template: LLMStepTemplate) { let count = -1; - let step = this.workflow.getPreviousStep(this._step); + let step = this._step.getPreviousStep(); while (step?.instanceOf(template)) { count++; - step = this.workflow.getPreviousStep(step); + step = step.getPreviousStep(); } return count; } recentFailedRepeats(template: LLMStepTemplate) { let count = -1; - let step = this.workflow.getPreviousStep(this._step); + let step = this._step.getPreviousStep(); while (step?.instanceOf(template) && step.getState().kind === "FAILED") { count++; - step = this.workflow.getPreviousStep(step); + step = step.getPreviousStep(); } return count; } diff --git a/packages/ai/src/workflows/WorkflowTemplate.ts b/packages/ai/src/workflows/WorkflowTemplate.ts index 50c8be8ea4..ceccaacdb1 100644 --- a/packages/ai/src/workflows/WorkflowTemplate.ts +++ b/packages/ai/src/workflows/WorkflowTemplate.ts @@ -1,6 +1,6 @@ import { LLMStepInstance } from "../LLMStepInstance.js"; -import { Inputs, IOShape } from "../LLMStepTemplate.js"; -import { type LlmConfig, Workflow } from "./Workflow.js"; +import { Inputs, IOShape, PreparedStep } from "../LLMStepTemplate.js"; +import { type LlmConfig, StepTransitionRule, Workflow } from "./Workflow.js"; export type WorkflowInstanceParams = { id: string; @@ -19,27 +19,22 @@ export type WorkflowInstanceParams = { * It works similarly to LLMStepTemplate, but for workflows. */ export class WorkflowTemplate { - public readonly name: string; + readonly name: string; - public configureControllerLoop: ( - workflow: Workflow, - inputs: Inputs - ) => void; - public configureInitialSteps: ( - workflow: Workflow, - inputs: Inputs - ) => void; + getInitialStep: (workflow: Workflow) => PreparedStep; + getTransitionRule: (workflow: Workflow) => StepTransitionRule; // TODO - shape parameter constructor(params: { name: string; - // TODO - do we need two separate functions? we always call them together - configureControllerLoop: WorkflowTemplate["configureControllerLoop"]; - configureInitialSteps: WorkflowTemplate["configureInitialSteps"]; + // This function will be called to obtain the first step of the workflow. + getInitialStep: WorkflowTemplate["getInitialStep"]; + // This function will be called to obtain the next step of the workflow based on the current step. + getTransitionRule: WorkflowTemplate["getTransitionRule"]; }) { this.name = params.name; - this.configureInitialSteps = params.configureInitialSteps; - this.configureControllerLoop = params.configureControllerLoop; + this.getInitialStep = params.getInitialStep; + this.getTransitionRule = params.getTransitionRule; } instantiate( diff --git a/packages/ai/src/workflows/controllers.ts b/packages/ai/src/workflows/controllers.ts index b83aa00b4e..ad8cc6d2d3 100644 --- a/packages/ai/src/workflows/controllers.ts +++ b/packages/ai/src/workflows/controllers.ts @@ -1,12 +1,12 @@ import { CodeArtifact, PromptArtifact } from "../Artifact.js"; import { LLMStepInstance } from "../LLMStepInstance.js"; -import { IOShape } from "../LLMStepTemplate.js"; +import { IOShape, StepState } from "../LLMStepTemplate.js"; import { adjustToFeedbackStep } from "../steps/adjustToFeedbackStep.js"; import { fixCodeUntilItRunsStep } from "../steps/fixCodeUntilItRunsStep.js"; import { generateCodeStep } from "../steps/generateCodeStep.js"; import { matchStyleGuideStep } from "../steps/matchStyleGuideStep.js"; import { runAndFormatCodeStep } from "../steps/runAndFormatCodeStep.js"; -import { Workflow } from "./Workflow.js"; +import { StepTransitionRule, Workflow } from "./Workflow.js"; import { NextStepAction, WorkflowGuardHelpers, @@ -14,11 +14,6 @@ import { const MAX_MINOR_ERRORS = 5; -type config = { - maxNumericSteps: number; - maxStyleGuideSteps: number; -}; - // Error Messages const ERROR_MESSAGES = { FAILED_STEP: (stepName: string, errorType: string) => @@ -32,8 +27,7 @@ const ERROR_MESSAGES = { // Helper function to handle failed states function handleFailedState( step: LLMStepInstance, - h: WorkflowGuardHelpers, - config: config + h: WorkflowGuardHelpers ): NextStepAction | undefined { const state = step.getState(); @@ -50,26 +44,34 @@ function handleFailedState( ); } + return undefined; +} + +// assumes that the step is done +function getOutputs( + step: LLMStepInstance +): Extract, { kind: "DONE" }>["outputs"] { + const state = step.getState(); if (state.kind !== "DONE") { - return h.fatal(ERROR_MESSAGES.NOT_DONE(step.template.name, state.kind)); + throw new Error(ERROR_MESSAGES.NOT_DONE(step.template.name, state.kind)); } - - return undefined; + return state.outputs; } -export function fixAdjustRetryLoop( +export function getDefaultTransitionRule( workflow: Workflow, prompt: PromptArtifact -) { +): StepTransitionRule { const config = { maxNumericSteps: workflow.llmConfig.numericSteps, maxStyleGuideSteps: workflow.llmConfig.styleGuideSteps, }; - workflow.addLinearRule((step, h) => { - function getNextIntendedState( + + return (step, h) => { + const getNextIntendedState = ( intendedStep: "AdjustToFeedback" | "MatchStyleGuide", code: CodeArtifact - ): NextStepAction { + ): NextStepAction => { const nextSteps = [ // Only check AdjustToFeedback if it's intended and has runs remaining intendedStep === "AdjustToFeedback" && @@ -82,45 +84,40 @@ export function fixAdjustRetryLoop( : null, ].filter((r) => !!r); return nextSteps.length ? nextSteps[0] : h.finish(); - } + }; // process bad states - const failedState = handleFailedState(step, h, config); + const failedState = handleFailedState(step, h); if (failedState) return failedState; - function fixCodeOrAdjustToFeedback( - code: CodeArtifact | undefined - ): NextStepAction { - if (!code) { - return h.fatal(ERROR_MESSAGES.NO_CODE); - } + const fixCodeOrAdjustToFeedback = (code: CodeArtifact): NextStepAction => { if (code.value.type !== "success") { return h.step(fixCodeUntilItRunsStep, { code }); } return getNextIntendedState("AdjustToFeedback", code); - } + }; // generateCodeStep if (step.instanceOf(generateCodeStep)) { - const { code } = step.getOutputs(); + const { code } = getOutputs(step); return fixCodeOrAdjustToFeedback(code); } // runAndFormatCodeStep if (step.instanceOf(runAndFormatCodeStep)) { - const { code } = step.getOutputs(); + const { code } = getOutputs(step); return fixCodeOrAdjustToFeedback(code); } // fixCodeUntilItRunsStep if (step.instanceOf(fixCodeUntilItRunsStep)) { - const { code } = step.getOutputs(); + const { code } = getOutputs(step); return fixCodeOrAdjustToFeedback(code); } // adjustToFeedbackStep if (step.instanceOf(adjustToFeedbackStep)) { - const { code } = step.getOutputs(); + const { code } = getOutputs(step); // no code means no need for adjustment, apply style guide if (!code) { @@ -136,7 +133,7 @@ export function fixAdjustRetryLoop( // matchStyleGuideStep if (step.instanceOf(matchStyleGuideStep)) { - const { code } = step.getOutputs(); + const { code } = getOutputs(step); if (!code) { return h.finish(); } @@ -149,5 +146,5 @@ export function fixAdjustRetryLoop( } return h.fatal("Unknown step"); - }); + }; } diff --git a/packages/ai/src/workflows/createSquiggleWorkflowTemplate.ts b/packages/ai/src/workflows/createSquiggleWorkflowTemplate.ts index 89cbbf6c81..48532b29e7 100644 --- a/packages/ai/src/workflows/createSquiggleWorkflowTemplate.ts +++ b/packages/ai/src/workflows/createSquiggleWorkflowTemplate.ts @@ -1,5 +1,5 @@ import { generateCodeStep } from "../steps/generateCodeStep.js"; -import { fixAdjustRetryLoop } from "./controllers.js"; +import { getDefaultTransitionRule } from "./controllers.js"; import { WorkflowTemplate } from "./WorkflowTemplate.js"; /** @@ -16,10 +16,8 @@ export const createSquiggleWorkflowTemplate = new WorkflowTemplate<{ outputs: Record; }>({ name: "CreateSquiggle", - configureControllerLoop(workflow, inputs) { - fixAdjustRetryLoop(workflow, inputs.prompt); - }, - configureInitialSteps(workflow, inputs) { - workflow.addStep(generateCodeStep, { prompt: inputs.prompt }); - }, + getInitialStep: (workflow) => + generateCodeStep.prepare({ prompt: workflow.inputs.prompt }), + getTransitionRule: (workflow) => + getDefaultTransitionRule(workflow, workflow.inputs.prompt), }); diff --git a/packages/ai/src/workflows/fixSquiggleWorkflowTemplate.ts b/packages/ai/src/workflows/fixSquiggleWorkflowTemplate.ts index 1f4a4022b3..b768e9cccd 100644 --- a/packages/ai/src/workflows/fixSquiggleWorkflowTemplate.ts +++ b/packages/ai/src/workflows/fixSquiggleWorkflowTemplate.ts @@ -1,6 +1,6 @@ import { PromptArtifact } from "../Artifact.js"; import { runAndFormatCodeStep } from "../steps/runAndFormatCodeStep.js"; -import { fixAdjustRetryLoop } from "./controllers.js"; +import { getDefaultTransitionRule } from "./controllers.js"; import { WorkflowTemplate } from "./WorkflowTemplate.js"; export const fixSquiggleWorkflowTemplate = new WorkflowTemplate<{ @@ -10,12 +10,11 @@ export const fixSquiggleWorkflowTemplate = new WorkflowTemplate<{ outputs: Record; }>({ name: "FixSquiggle", - configureControllerLoop(workflow) { + getInitialStep: (workflow) => + runAndFormatCodeStep.prepare({ source: workflow.inputs.source }), + getTransitionRule: (workflow) => { // TODO - cache the prompt artifact once? maybe even as a global variable // (but it's better to just refactor steps to make the prompt optional, somehow) - fixAdjustRetryLoop(workflow, new PromptArtifact("")); - }, - configureInitialSteps(workflow, inputs) { - workflow.addStep(runAndFormatCodeStep, { source: inputs.source }); + return getDefaultTransitionRule(workflow, new PromptArtifact("")); }, }); diff --git a/packages/ai/src/workflows/streaming.ts b/packages/ai/src/workflows/streaming.ts index 230780be0b..2076399a45 100644 --- a/packages/ai/src/workflows/streaming.ts +++ b/packages/ai/src/workflows/streaming.ts @@ -45,6 +45,22 @@ export function artifactToClientArtifact(value: Artifact): ClientArtifact { } } +function getClientOutputs( + step: LLMStepInstance +): Record { + const stepState = step.getState(); + return stepState.kind === "DONE" + ? Object.fromEntries( + Object.entries(stepState.outputs) + .filter( + (pair): pair is [string, NonNullable<(typeof pair)[1]>] => + pair[1] !== undefined + ) + .map(([key, value]) => [key, artifactToClientArtifact(value)]) + ) + : {}; +} + export function stepToClientStep(step: LLMStepInstance): ClientStep { return { id: step.id, @@ -56,14 +72,7 @@ export function stepToClientStep(step: LLMStepInstance): ClientStep { artifactToClientArtifact(value), ]) ), - outputs: Object.fromEntries( - Object.entries(step.getOutputs()) - .filter( - (pair): pair is [string, NonNullable<(typeof pair)[1]>] => - pair[1] !== undefined - ) - .map(([key, value]) => [key, artifactToClientArtifact(value)]) - ), + outputs: getClientOutputs(step), messages: step.getConversationMessages(), }; } @@ -121,21 +130,12 @@ export function addStreamingListeners( content: { id: event.data.step.id, state: event.data.step.getState().kind, - outputs: Object.fromEntries( - Object.entries(event.data.step.getOutputs()) - .filter( - (pair): pair is [string, NonNullable<(typeof pair)[1]>] => - pair[1] !== undefined - ) - .map(([key, value]) => [key, artifactToClientArtifact(value)]) - ), + outputs: getClientOutputs(event.data.step), messages: event.data.step.getConversationMessages(), }, }); }); workflow.addEventListener("allStepsFinished", (event) => { - // saveSummaryToFile(generateSummary(prompt, workflow)); - send({ kind: "finalResult", content: event.workflow.getFinalResult(), diff --git a/packages/hub/src/server/ai/analytics/index.ts b/packages/hub/src/server/ai/analytics/index.ts index be6017b9f4..9e592cd4df 100644 --- a/packages/hub/src/server/ai/analytics/index.ts +++ b/packages/hub/src/server/ai/analytics/index.ts @@ -68,7 +68,10 @@ export async function getTypeStats() { (typeStats[stepName]["FAILED"] ?? 0) + 1; continue; } - const code = step.getOutputs()["code"]; + if (state.kind !== "DONE") { + continue; + } + const code = state.outputs["code"]; if (code && code instanceof CodeArtifact) { typeStats[stepName][code.value.type] = (typeStats[stepName][code.value.type] ?? 0) + 1; @@ -92,7 +95,11 @@ export async function getCodeErrors() { const errors: StepError[] = []; for (const workflow of getModernWorkflows(rows)) { for (const step of workflow.getSteps()) { - const code = step.getOutputs()["code"]; + const state = step.getState(); + if (state.kind !== "DONE") { + continue; + } + const code = state.outputs["code"]; if ( code && code instanceof CodeArtifact && diff --git a/packages/hub/src/server/ai/storage.ts b/packages/hub/src/server/ai/storage.ts index de7e6cb71f..b85050db48 100644 --- a/packages/hub/src/server/ai/storage.ts +++ b/packages/hub/src/server/ai/storage.ts @@ -14,8 +14,14 @@ export function decodeDbWorkflowToClientWorkflow( switch (row.format) { case "V1_0": return decodeV1_0JsonToClientWorkflow(row.workflow); - case "V2_0": - return decodeV2_0JsonToClientWorkflow(row.workflow); + case "V2_0": { + const clientWorkflow = decodeV2_0JsonToClientWorkflow(row.workflow); + // serialized workflow doesn't include full logs, but we store them in the database + if (clientWorkflow.status === "finished") { + clientWorkflow.result.logSummary = row.markdown; + } + return clientWorkflow; + } default: throw new Error( `Unknown workflow format: ${row.format satisfies never}`