Skip to content

Commit

Permalink
Merge pull request #3443 from quantified-uncertainty/ai-tweaks
Browse files Browse the repository at this point in the history
AI cleanups
  • Loading branch information
berekuk authored Nov 19, 2024
2 parents af0cd96 + af2d43b commit d0d38f0
Show file tree
Hide file tree
Showing 17 changed files with 389 additions and 306 deletions.
201 changes: 121 additions & 80 deletions packages/ai/src/LLMStepInstance.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { BaseArtifact, makeArtifact } from "./Artifact.js";
import { ArtifactKind, BaseArtifact, makeArtifact } from "./Artifact.js";
import {
calculatePriceMultipleCalls,
LlmMetrics,
Expand All @@ -10,6 +10,7 @@ import {
Inputs,
IOShape,
LLMStepTemplate,
OutputKind,
Outputs,
StepState,
} from "./LLMStepTemplate.js";
Expand All @@ -27,14 +28,22 @@ export type StepParams<Shape extends IOShape> = {
id: string;
sequentialId: number;
template: LLMStepTemplate<Shape>;
state: StepState;
state: StepState<Shape>;
inputs: Inputs<Shape>;
outputs: Partial<Outputs<Shape>>;
startTime: number;
conversationMessages: Message[];
llmMetricsList: LlmMetrics[];
};

class FailError extends Error {
constructor(
public type: ErrorType,
message: string
) {
super(message);
}
}

export class LLMStepInstance<
const Shape extends IOShape = IOShape,
const WorkflowShape extends IOShape = IOShape,
Expand All @@ -43,10 +52,8 @@ export class LLMStepInstance<
public sequentialId: StepParams<Shape>["sequentialId"];
public readonly template: StepParams<Shape>["template"];

private state: StepParams<Shape>["state"];

// must be public for `instanceOf` type guard to work
public readonly _outputs: StepParams<Shape>["outputs"];
private _state: StepParams<Shape>["state"];

public readonly inputs: StepParams<Shape>["inputs"];

Expand All @@ -56,7 +63,7 @@ export class LLMStepInstance<

// These two fields are not serialized
private logger: Logger;
private workflow: Workflow<WorkflowShape>;
public workflow: Workflow<WorkflowShape>;

private constructor(
params: StepParams<Shape> & {
Expand All @@ -70,8 +77,7 @@ export class LLMStepInstance<
this.conversationMessages = params.conversationMessages;

this.startTime = params.startTime;
this.state = params.state;
this._outputs = params.outputs;
this._state = params.state;

this.template = params.template;
this.inputs = params.inputs;
Expand All @@ -93,7 +99,6 @@ export class LLMStepInstance<
llmMetricsList: [],
startTime: Date.now(),
state: { kind: "PENDING" },
outputs: {},
...params,
});
}
Expand All @@ -113,7 +118,7 @@ export class LLMStepInstance<
}

async _run() {
if (this.state.kind !== "PENDING") {
if (this._state.kind !== "PENDING") {
return;
}

Expand All @@ -123,27 +128,51 @@ export class LLMStepInstance<
return;
}

const executeContext: ExecuteContext<Shape> = {
setOutput: (key, value) => this.setOutput(key, value),
const executeContext: ExecuteContext = {
log: (log) => this.log(log),
queryLLM: (promptPair) => this.queryLLM(promptPair),
fail: (errorType, message) => this.fail(errorType, message),
fail: (errorType, message) => {
throw new FailError(errorType, message);
},
};

try {
await this.template.execute(executeContext, this.inputs);
} catch (error) {
this.fail(
"MINOR",
error instanceof Error ? error.message : String(error)
);
return;
}

const hasFailed = (this.state as StepState).kind === "FAILED";
const result = await this.template.execute(executeContext, this.inputs);

const outputs = {} as Outputs<Shape>; // will be filled in below

// This code is unfortunately full of `as any`, caused by `Object.keys` limitations in TypeScript.
Object.keys(result).forEach((key: keyof typeof result) => {
const value = result[key];
if (value instanceof BaseArtifact) {
// already existing artifact - probably passed through from another step
outputs[key] = value as any;
} else if (!value) {
outputs[key] = undefined as any;
} else {
const outputKind = this.template.shape.outputs[key] as OutputKind;
const artifactKind = outputKind.endsWith("?")
? (outputKind.slice(0, -1) as ArtifactKind)
: (outputKind as ArtifactKind);

outputs[key] = makeArtifact(artifactKind, value as any, this) as any;
}
});

if (!hasFailed) {
this.state = { kind: "DONE", durationMs: this.calculateDuration() };
this._state = {
kind: "DONE",
durationMs: this.calculateDuration(),
outputs,
};
} catch (error) {
if (error instanceof FailError) {
this.fail(error.type, error.message);
} else {
this.fail(
"MINOR", // TODO - critical?
error instanceof Error ? error.message : String(error)
);
}
}
}

Expand All @@ -155,8 +184,8 @@ export class LLMStepInstance<

await this._run();

const completionMessage = `Step "${this.template.name}" completed with status: ${this.state.kind}${
this.state.kind !== "PENDING" && `, in ${this.state.durationMs / 1000}s`
const completionMessage = `Step "${this.template.name}" completed with status: ${this._state.kind}${
this._state.kind !== "PENDING" && `, in ${this._state.durationMs / 1000}s`
}`;

this.log({
Expand All @@ -166,15 +195,11 @@ export class LLMStepInstance<
}

getState() {
return this.state;
return this._state;
}

getDuration() {
return this.state.kind === "PENDING" ? 0 : this.state.durationMs;
}

getOutputs(): this["_outputs"] {
return this._outputs;
return this._state.kind === "PENDING" ? 0 : this._state.durationMs;
}

getInputs() {
Expand Down Expand Up @@ -208,34 +233,21 @@ export class LLMStepInstance<
}

isDone() {
return this.state.kind === "DONE";
return this._state.kind === "DONE";
}

// private methods

private setOutput<K extends Extract<keyof Shape["outputs"], string>>(
key: K,
value: Outputs<Shape>[K] | Outputs<Shape>[K]["value"]
): void {
if (key in this._outputs) {
this.fail(
"CRITICAL",
`Output ${key} is already set. This is a bug with the workflow code.`
);
return;
}

if (value instanceof BaseArtifact) {
// already existing artifact - probably passed through from another step
this._outputs[key] = value;
} else {
const kind = this.template.shape.outputs[
key
] as Outputs<Shape>[K]["kind"];
this._outputs[key] = makeArtifact(kind, value as any, this) as any;
getPreviousStep() {
const steps = this.workflow.getSteps();
const index = steps.indexOf(this as LLMStepInstance<any, WorkflowShape>);
if (index < 1) {
// first step or not found
return undefined;
}
return steps[index - 1];
}

// private methods

private log(log: LogEntry): void {
this.logger.log(log, {
workflowId: this.workflow.id,
Expand All @@ -245,7 +257,7 @@ export class LLMStepInstance<

private fail(errorType: ErrorType, message: string) {
this.log({ type: "error", message });
this.state = {
this._state = {
kind: "FAILED",
durationMs: this.calculateDuration(),
errorType,
Expand Down Expand Up @@ -320,14 +332,13 @@ export class LLMStepInstance<
// Serialization/deserialization

// StepParams don't contain the workflow reference, to to avoid circular dependencies
toParams(): StepParams<any> {
toParams(): StepParams<Shape> {
return {
id: this.id,
sequentialId: this.sequentialId,
template: this.template,
state: this.state,
state: this._state,
inputs: this.inputs,
outputs: this._outputs,
startTime: this.startTime,
conversationMessages: this.conversationMessages,
llmMetricsList: this.llmMetricsList,
Expand All @@ -342,7 +353,13 @@ export class LLMStepInstance<
}

static deserialize(
{ templateName, inputIds, outputIds, ...params }: SerializedStep,
{
templateName,
inputIds,
outputIds: legacyOutputIds,
state: serializedState,
...params
}: SerializedStep,
visitor: AiDeserializationVisitor
): StepParams<any> {
const template: LLMStepTemplate<any> = getStepTemplateByName(templateName);
Expand All @@ -352,31 +369,53 @@ export class LLMStepInstance<
visitor.artifact(inputId),
])
);
const outputs = Object.fromEntries(
Object.entries(outputIds).map(([name, outputId]) => [
name,
visitor.artifact(outputId),
])
);

let state: StepState<any>;

if (serializedState.kind === "DONE") {
const { outputIds, ...serializedStateWithoutOutputs } = serializedState;
state = {
...serializedStateWithoutOutputs,
outputs: Object.fromEntries(
Object.entries(legacyOutputIds ?? outputIds).map(
([name, outputId]) => [name, visitor.artifact(outputId)]
)
),
};
} else {
state = serializedState;
}

return {
...params,
template,
inputs,
outputs,
state,
};
}
}

export function serializeStepParams(
params: StepParams<IOShape>,
visitor: AiSerializationVisitor
) {
): SerializedStep {
return {
id: params.id,
sequentialId: params.sequentialId,
templateName: params.template.name,
state: params.state,
state:
params.state.kind === "DONE"
? {
...params.state,
outputIds: Object.fromEntries(
Object.entries(params.state.outputs)
.map(([key, output]) =>
output ? [key, visitor.artifact(output)] : undefined
)
.filter((x) => x !== undefined)
),
}
: params.state,
startTime: params.startTime,
conversationMessages: params.conversationMessages,
llmMetricsList: params.llmMetricsList,
Expand All @@ -386,21 +425,23 @@ export function serializeStepParams(
visitor.artifact(input),
])
),
outputIds: Object.fromEntries(
Object.entries(params.outputs)
.map(([key, output]) =>
output ? [key, visitor.artifact(output)] : undefined
)
.filter((x) => x !== undefined)
),
};
}

type SerializedState<Shape extends IOShape> =
| (Omit<Extract<StepState<Shape>, { kind: "DONE" }>, "outputs"> & {
outputIds: Record<string, number>;
})
| Exclude<StepState<Shape>, { kind: "DONE" }>;

export type SerializedStep = Omit<
StepParams<IOShape>,
"inputs" | "outputs" | "template"
"inputs" | "outputs" | "template" | "state"
> & {
templateName: string;
inputIds: Record<string, number>;
outputIds: Record<string, number>;
// Legacy - in the modern format we store outputs on SerializedState, but we still support this field for old workflows.
// Can be removed if we deserialize and serialize again all workflows in the existing database.
outputIds?: Record<string, number>;
state: SerializedState<IOShape>;
};
Loading

0 comments on commit d0d38f0

Please sign in to comment.