Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AI cleanups #3443

Merged
merged 7 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 121 additions & 80 deletions packages/ai/src/LLMStepInstance.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { BaseArtifact, makeArtifact } from "./Artifact.js";
import { ArtifactKind, BaseArtifact, makeArtifact } from "./Artifact.js";
import {
calculatePriceMultipleCalls,
LlmMetrics,
Expand All @@ -10,6 +10,7 @@ import {
Inputs,
IOShape,
LLMStepTemplate,
OutputKind,
Outputs,
StepState,
} from "./LLMStepTemplate.js";
Expand All @@ -27,14 +28,22 @@ export type StepParams<Shape extends IOShape> = {
id: string;
sequentialId: number;
template: LLMStepTemplate<Shape>;
state: StepState;
state: StepState<Shape>;
inputs: Inputs<Shape>;
outputs: Partial<Outputs<Shape>>;
startTime: number;
conversationMessages: Message[];
llmMetricsList: LlmMetrics[];
};

/**
 * Error used to abort a step's `execute` call with an explicit error type.
 *
 * The `fail` callback on the execute context throws this error; `_run`
 * catches it and forwards `type` and `message` to the step's own `fail`
 * method, so the step ends up in the FAILED state with the requested
 * error type instead of the default one used for unexpected exceptions.
 */
class FailError extends Error {
  /**
   * @param type - error type that the failed step state should carry
   * @param message - human-readable failure description
   */
  constructor(
    public readonly type: ErrorType,
    message: string
  ) {
    super(message);
    // Without this, logs and stack traces would label the error as plain "Error".
    this.name = "FailError";
  }
}

export class LLMStepInstance<
const Shape extends IOShape = IOShape,
const WorkflowShape extends IOShape = IOShape,
Expand All @@ -43,10 +52,8 @@ export class LLMStepInstance<
public sequentialId: StepParams<Shape>["sequentialId"];
public readonly template: StepParams<Shape>["template"];

private state: StepParams<Shape>["state"];

// must be public for `instanceOf` type guard to work
public readonly _outputs: StepParams<Shape>["outputs"];
private _state: StepParams<Shape>["state"];

public readonly inputs: StepParams<Shape>["inputs"];

Expand All @@ -56,7 +63,7 @@ export class LLMStepInstance<

// These two fields are not serialized
private logger: Logger;
private workflow: Workflow<WorkflowShape>;
public workflow: Workflow<WorkflowShape>;

private constructor(
params: StepParams<Shape> & {
Expand All @@ -70,8 +77,7 @@ export class LLMStepInstance<
this.conversationMessages = params.conversationMessages;

this.startTime = params.startTime;
this.state = params.state;
this._outputs = params.outputs;
this._state = params.state;

this.template = params.template;
this.inputs = params.inputs;
Expand All @@ -93,7 +99,6 @@ export class LLMStepInstance<
llmMetricsList: [],
startTime: Date.now(),
state: { kind: "PENDING" },
outputs: {},
...params,
});
}
Expand All @@ -113,7 +118,7 @@ export class LLMStepInstance<
}

async _run() {
if (this.state.kind !== "PENDING") {
if (this._state.kind !== "PENDING") {
return;
}

Expand All @@ -123,27 +128,51 @@ export class LLMStepInstance<
return;
}

const executeContext: ExecuteContext<Shape> = {
setOutput: (key, value) => this.setOutput(key, value),
const executeContext: ExecuteContext = {
log: (log) => this.log(log),
queryLLM: (promptPair) => this.queryLLM(promptPair),
fail: (errorType, message) => this.fail(errorType, message),
fail: (errorType, message) => {
throw new FailError(errorType, message);
},
};

try {
await this.template.execute(executeContext, this.inputs);
} catch (error) {
this.fail(
"MINOR",
error instanceof Error ? error.message : String(error)
);
return;
}

const hasFailed = (this.state as StepState).kind === "FAILED";
const result = await this.template.execute(executeContext, this.inputs);

const outputs = {} as Outputs<Shape>; // will be filled in below

// This code is unfortunately full of `as any`, caused by `Object.keys` limitations in TypeScript.
Object.keys(result).forEach((key: keyof typeof result) => {
const value = result[key];
if (value instanceof BaseArtifact) {
// already existing artifact - probably passed through from another step
outputs[key] = value as any;
} else if (!value) {
outputs[key] = undefined as any;
} else {
const outputKind = this.template.shape.outputs[key] as OutputKind;
const artifactKind = outputKind.endsWith("?")
? (outputKind.slice(0, -1) as ArtifactKind)
: (outputKind as ArtifactKind);

outputs[key] = makeArtifact(artifactKind, value as any, this) as any;
}
});

if (!hasFailed) {
this.state = { kind: "DONE", durationMs: this.calculateDuration() };
this._state = {
kind: "DONE",
durationMs: this.calculateDuration(),
outputs,
};
} catch (error) {
if (error instanceof FailError) {
this.fail(error.type, error.message);
} else {
this.fail(
"MINOR", // TODO - critical?
error instanceof Error ? error.message : String(error)
);
}
}
}

Expand All @@ -155,8 +184,8 @@ export class LLMStepInstance<

await this._run();

const completionMessage = `Step "${this.template.name}" completed with status: ${this.state.kind}${
this.state.kind !== "PENDING" && `, in ${this.state.durationMs / 1000}s`
const completionMessage = `Step "${this.template.name}" completed with status: ${this._state.kind}${
this._state.kind !== "PENDING" && `, in ${this._state.durationMs / 1000}s`
}`;

this.log({
Expand All @@ -166,15 +195,11 @@ export class LLMStepInstance<
}

getState() {
return this.state;
return this._state;
}

getDuration() {
return this.state.kind === "PENDING" ? 0 : this.state.durationMs;
}

getOutputs(): this["_outputs"] {
return this._outputs;
return this._state.kind === "PENDING" ? 0 : this._state.durationMs;
}

getInputs() {
Expand Down Expand Up @@ -208,34 +233,21 @@ export class LLMStepInstance<
}

isDone() {
return this.state.kind === "DONE";
return this._state.kind === "DONE";
}

// private methods

private setOutput<K extends Extract<keyof Shape["outputs"], string>>(
key: K,
value: Outputs<Shape>[K] | Outputs<Shape>[K]["value"]
): void {
if (key in this._outputs) {
this.fail(
"CRITICAL",
`Output ${key} is already set. This is a bug with the workflow code.`
);
return;
}

if (value instanceof BaseArtifact) {
// already existing artifact - probably passed through from another step
this._outputs[key] = value;
} else {
const kind = this.template.shape.outputs[
key
] as Outputs<Shape>[K]["kind"];
this._outputs[key] = makeArtifact(kind, value as any, this) as any;
/**
 * Looks up this step in the workflow's step list and returns the entry
 * right before it.
 *
 * @returns the preceding step, or `undefined` when this step is the first
 * one in the list or is not present in the list at all.
 */
getPreviousStep() {
  const allSteps = this.workflow.getSteps();
  const position = allSteps.indexOf(
    this as LLMStepInstance<any, WorkflowShape>
  );
  // position 0 means "first step", -1 means "not found"; neither has a predecessor
  return position >= 1 ? allSteps[position - 1] : undefined;
}

// private methods

private log(log: LogEntry): void {
this.logger.log(log, {
workflowId: this.workflow.id,
Expand All @@ -245,7 +257,7 @@ export class LLMStepInstance<

private fail(errorType: ErrorType, message: string) {
this.log({ type: "error", message });
this.state = {
this._state = {
kind: "FAILED",
durationMs: this.calculateDuration(),
errorType,
Expand Down Expand Up @@ -320,14 +332,13 @@ export class LLMStepInstance<
// Serialization/deserialization

// StepParams don't contain the workflow reference, to avoid circular dependencies
toParams(): StepParams<any> {
toParams(): StepParams<Shape> {
return {
id: this.id,
sequentialId: this.sequentialId,
template: this.template,
state: this.state,
state: this._state,
inputs: this.inputs,
outputs: this._outputs,
startTime: this.startTime,
conversationMessages: this.conversationMessages,
llmMetricsList: this.llmMetricsList,
Expand All @@ -342,7 +353,13 @@ export class LLMStepInstance<
}

static deserialize(
{ templateName, inputIds, outputIds, ...params }: SerializedStep,
{
templateName,
inputIds,
outputIds: legacyOutputIds,
state: serializedState,
...params
}: SerializedStep,
visitor: AiDeserializationVisitor
): StepParams<any> {
const template: LLMStepTemplate<any> = getStepTemplateByName(templateName);
Expand All @@ -352,31 +369,53 @@ export class LLMStepInstance<
visitor.artifact(inputId),
])
);
const outputs = Object.fromEntries(
Object.entries(outputIds).map(([name, outputId]) => [
name,
visitor.artifact(outputId),
])
);

let state: StepState<any>;

if (serializedState.kind === "DONE") {
const { outputIds, ...serializedStateWithoutOutputs } = serializedState;
state = {
...serializedStateWithoutOutputs,
outputs: Object.fromEntries(
Object.entries(legacyOutputIds ?? outputIds).map(
([name, outputId]) => [name, visitor.artifact(outputId)]
)
),
};
} else {
state = serializedState;
}

return {
...params,
template,
inputs,
outputs,
state,
};
}
}

export function serializeStepParams(
params: StepParams<IOShape>,
visitor: AiSerializationVisitor
) {
): SerializedStep {
return {
id: params.id,
sequentialId: params.sequentialId,
templateName: params.template.name,
state: params.state,
state:
params.state.kind === "DONE"
? {
...params.state,
outputIds: Object.fromEntries(
Object.entries(params.state.outputs)
.map(([key, output]) =>
output ? [key, visitor.artifact(output)] : undefined
)
.filter((x) => x !== undefined)
),
}
: params.state,
startTime: params.startTime,
conversationMessages: params.conversationMessages,
llmMetricsList: params.llmMetricsList,
Expand All @@ -386,21 +425,23 @@ export function serializeStepParams(
visitor.artifact(input),
])
),
outputIds: Object.fromEntries(
Object.entries(params.outputs)
.map(([key, output]) =>
output ? [key, visitor.artifact(output)] : undefined
)
.filter((x) => x !== undefined)
),
};
}

/**
 * Serialized counterpart of `StepState`.
 *
 * Every state other than "DONE" is stored unchanged. The "DONE" state drops
 * its `outputs` field and carries `outputIds` instead: a map from output
 * name to a numeric id (in `serializeStepParams` these ids come from
 * `visitor.artifact(...)`).
 */
type SerializedState<Shape extends IOShape> =
  | Exclude<StepState<Shape>, { kind: "DONE" }>
  | (Omit<Extract<StepState<Shape>, { kind: "DONE" }>, "outputs"> & {
      outputIds: Record<string, number>;
    });

export type SerializedStep = Omit<
StepParams<IOShape>,
"inputs" | "outputs" | "template"
"inputs" | "outputs" | "template" | "state"
> & {
templateName: string;
inputIds: Record<string, number>;
outputIds: Record<string, number>;
// Legacy - in the modern format we store outputs on SerializedState, but we still support this field for old workflows.
// Can be removed once all workflows in the existing database have been deserialized and re-serialized.
outputIds?: Record<string, number>;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope I haven't broken old workflows with this change (moving outputs to step.state, and changing serialization code accordingly); as far as I can tell it's fine, i.e. old workflows still load in my dev hub instance.

state: SerializedState<IOShape>;
};
Loading