Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Refactor message types to use unified schema #456

Merged
merged 3 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 37 additions & 32 deletions control-plane/src/modules/contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ export const integrationSchema = z.object({
.nullable(),
});

export const genericMessageDataSchema = z
const genericMessageDataSchema = z
.object({
message: z.string(),
details: z.object({}).passthrough().optional(),
})
.strict();

export const resultDataSchema = z
const resultDataSchema = z
.object({
id: z.string(),
result: z.object({}).passthrough(),
Expand All @@ -124,7 +124,7 @@ export const learningSchema = z.object({
}),
});

export const agentDataSchema = z
const agentDataSchema = z
.object({
done: z.boolean().optional(),
result: anyObject.optional(),
Expand All @@ -144,10 +144,31 @@ export const agentDataSchema = z
})
.strict();

export const messageDataSchema = z.union([
resultDataSchema,
agentDataSchema,
genericMessageDataSchema,
export const unifiedMessageDataSchema = z.discriminatedUnion("type", [
z.object({
type: z.literal("agent"),
data: agentDataSchema,
}),
z.object({
type: z.literal("invocation-result"),
data: resultDataSchema,
}),
z.object({
type: z.literal("human"),
data: genericMessageDataSchema,
}),
z.object({
type: z.literal("template"),
data: genericMessageDataSchema,
}),
z.object({
type: z.literal("supervisor"),
data: genericMessageDataSchema,
}),
z.object({
type: z.literal("agent-invalid"),
data: genericMessageDataSchema,
}),
]);

export const FunctionConfigSchema = z.object({
Expand Down Expand Up @@ -623,21 +644,14 @@ export const definition = {
}),
responses: {
200: z.array(
z.object({
id: z.string(),
data: messageDataSchema,
type: z.enum([
"human",
"template",
"invocation-result",
"agent",
"agent-invalid",
"supervisor",
]),
createdAt: z.date(),
pending: z.boolean().default(false),
displayableContext: z.record(z.string()).nullable(),
})
z
.object({
id: z.string(),
createdAt: z.date(),
pending: z.boolean().default(false),
displayableContext: z.record(z.string()).nullable(),
})
.merge(z.object({ data: unifiedMessageDataSchema }))
),
401: z.undefined(),
},
Expand Down Expand Up @@ -935,16 +949,7 @@ export const definition = {
messages: z.array(
z.object({
id: z.string(),
data: messageDataSchema,
type: z.enum([
// TODO: Remove 'template' type
"template",
"invocation-result",
"human",
"agent",
"agent-invalid",
"supervisor",
]),
data: unifiedMessageDataSchema,
createdAt: z.date(),
pending: z.boolean().default(false),
displayableContext: z.record(z.string()).nullable(),
Expand Down
3 changes: 1 addition & 2 deletions control-plane/src/modules/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import {
import { Pool } from "pg";
import { env } from "../utilities/env";
import { logger } from "./observability/logger";
import { MessageData } from "./workflows/workflow-messages";

export const createMutex = advisoryLock(env.DATABASE_URL);

Expand Down Expand Up @@ -323,7 +322,7 @@ export const workflowMessages = pgTable(
withTimezone: true,
precision: 6,
}),
data: json("data").$type<MessageData>().notNull(),
data: json("data").$type<unknown>().notNull(),
type: text("type", {
enum: ["human", "invocation-result", "template", "agent", "agent-invalid", "supervisor"],
}).notNull(),
Expand Down
81 changes: 42 additions & 39 deletions control-plane/src/modules/email/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { workflowMessages } from "../data";
import { ses } from "../ses";
import { getMessageByReference, updateMessageReference } from "../workflows/workflow-messages";
import { ulid } from "ulid";
import { unifiedMessageDataSchema } from "../contract";

const EMAIL_INIT_MESSAGE_ID_META_KEY = "emailInitMessageId";
const EMAIL_SUBJECT_META_KEY = "emailSubject";
Expand All @@ -33,7 +34,7 @@ const sesMessageSchema = z.object({
dmarcVerdict: z.object({ status: z.string() }),
}),
content: z.string(),
})
});

const snsNotificationSchema = z.object({
Type: z.literal("Notification"),
Expand All @@ -54,7 +55,7 @@ const emailIngestionConsumer = env.SQS_EMAIL_INGESTION_QUEUE_URL
: undefined;

export const start = async () => {
emailIngestionConsumer?.start()
emailIngestionConsumer?.start();
};

export const stop = async () => {
Expand All @@ -71,25 +72,28 @@ export const handleNewRunMessage = async ({
runId: string;
type: InferSelectModel<typeof workflowMessages>["type"];
data: InferSelectModel<typeof workflowMessages>["data"];

};
runMetadata?: Record<string, string>;
}) => {
if (message.type !== "agent") {
return;
}

if (!runMetadata?.[EMAIL_INIT_MESSAGE_ID_META_KEY] || !runMetadata?.[EMAIL_SUBJECT_META_KEY] || !runMetadata?.[EMAIL_SOURCE_META_KEY]) {
if (
!runMetadata?.[EMAIL_INIT_MESSAGE_ID_META_KEY] ||
!runMetadata?.[EMAIL_SUBJECT_META_KEY] ||
!runMetadata?.[EMAIL_SOURCE_META_KEY]
) {
return;
}

if ("message" in message.data && message.data.message) {
const messageData = unifiedMessageDataSchema.parse(message.data).data;

if ("message" in messageData && messageData.message) {
const result = await ses.sendEmail({
Source: `"Inferable" <${message.clusterId}@${env.INFERABLE_EMAIL_DOMAIN}>`,
Destination: {
ToAddresses: [
runMetadata[EMAIL_SOURCE_META_KEY]
],
ToAddresses: [runMetadata[EMAIL_SOURCE_META_KEY]],
},
Message: {
Subject: {
Expand All @@ -99,11 +103,11 @@ export const handleNewRunMessage = async ({
Body: {
Text: {
Charset: "UTF-8",
Data: message.data.message,
Data: messageData.message,
},
},
},
})
});

if (!result.MessageId) {
throw new Error("SES did not return a message ID");
Expand All @@ -119,7 +123,6 @@ export const handleNewRunMessage = async ({
clusterId: message.clusterId,
messageId: message.id,
});

} else {
logger.warn("Email thread message does not have content");
}
Expand All @@ -130,7 +133,7 @@ export async function parseMessage(message: unknown) {
if (!notification.success) {
logger.error("Could not parse SNS notification message", {
error: notification.error,
})
});
throw new Error("Could not parse SNS notification message");
}

Expand All @@ -143,19 +146,22 @@ export async function parseMessage(message: unknown) {
if (!sesMessage.success) {
logger.error("Could not parse SES message", {
error: sesMessage.error,
})
});
throw new Error("Could not parse SES message");
}

const ingestionAddresses = sesMessage.data.mail.destination.filter(
(email) => email.endsWith(env.INFERABLE_EMAIL_DOMAIN)
)
const ingestionAddresses = sesMessage.data.mail.destination.filter(email =>
email.endsWith(env.INFERABLE_EMAIL_DOMAIN)
);

if (ingestionAddresses.length > 1) {
throw new Error("Found multiple Inferable email addresses in destination");
}

const clusterId = ingestionAddresses.pop()?.replace(env.INFERABLE_EMAIL_DOMAIN, "").replace("@", "");
const clusterId = ingestionAddresses
.pop()
?.replace(env.INFERABLE_EMAIL_DOMAIN, "")
.replace("@", "");

if (!clusterId) {
throw new Error("Could not extract clusterId from email address");
Expand All @@ -166,9 +172,9 @@ export async function parseMessage(message: unknown) {
throw new Error("Could not parse email content");
}

let body = mail.text
let body = mail.text;
if (!body && mail.html) {
body = mail.html
body = mail.html;
}

return {
Expand All @@ -179,26 +185,25 @@ export async function parseMessage(message: unknown) {
messageId: mail.messageId,
source: sesMessage.data.mail.source,
inReplyTo: mail.inReplyTo,
references: (typeof mail.references === "string") ? [mail.references] : mail.references ?? [],
}
references: typeof mail.references === "string" ? [mail.references] : (mail.references ?? []),
};
}

// Strip trailing email chain quotes ">"
export const stripQuoteTail = (message: string) => {
const lines = message.split("\n").reverse();

while (lines[0] && lines[0].startsWith(">") || lines[0].trim() === "") {
while ((lines[0] && lines[0].startsWith(">")) || lines[0].trim() === "") {
lines.shift();
}

return lines.reverse().join("\n");
}
};

async function handleEmailIngestion(raw: unknown) {
const message = await parseMessage(raw);
if (!message.body) {
logger.info("Email had no body. Skipping", {
});
logger.info("Email had no body. Skipping", {});
return;
}

Expand Down Expand Up @@ -239,7 +244,7 @@ async function handleEmailIngestion(raw: unknown) {
body: message.body,
clusterId: message.clusterId,
messageId: message.messageId,
runId: existing.runId
runId: existing.runId,
});
}

Expand All @@ -249,12 +254,11 @@ async function handleEmailIngestion(raw: unknown) {
clusterId: message.clusterId,
messageId: message.messageId,
subject: message.subject,
source: message.source
source: message.source,
});
}


const parseMailContent = (message: string): Promise<ParsedMail> => {
const parseMailContent = (message: string): Promise<ParsedMail> => {
return new Promise((resolve, reject) => {
simpleParser(message, (error, parsed) => {
if (error) {
Expand All @@ -263,10 +267,9 @@ const parseMailContent = (message: string): Promise<ParsedMail> => {
resolve(parsed);
}
});
})
});
};


const authenticateUser = async (emailAddress: string, clusterId: string) => {
if (!env.CLERK_SECRET_KEY) {
throw new Error("CLERK_SECRET_KEY must be set for email authentication");
Expand All @@ -293,7 +296,7 @@ const handleNewChain = async ({
clusterId,
messageId,
subject,
source
source,
}: {
userId: string;
body: string;
Expand All @@ -302,7 +305,7 @@ const handleNewChain = async ({
subject: string;
source: string;
}) => {
logger.info("Creating new run from email")
logger.info("Creating new run from email");
await createRunWithMessage({
userId,
clusterId,
Expand All @@ -317,23 +320,23 @@ const handleNewChain = async ({
},
message: body,
type: "human",
})
}
});
};

const handleExistingChain = async ({
userId,
body,
clusterId,
messageId,
runId
runId,
}: {
userId: string;
body: string;
clusterId: string;
messageId: string;
runId: string;
}) => {
logger.info("Continuing existing run from email")
logger.info("Continuing existing run from email");
await addMessageAndResume({
id: ulid(),
clusterId,
Expand All @@ -345,5 +348,5 @@ const handleExistingChain = async ({
},
message: body,
type: "human",
})
}
});
};
Loading
Loading