From f5a5d981a3c4b1676c3fc2868f259d09c0e048d8 Mon Sep 17 00:00:00 2001 From: John Smith Date: Sun, 1 Dec 2024 14:46:44 +1030 Subject: [PATCH] chore: Remove learning extraction (#174) --- control-plane/src/index.ts | 9 +- control-plane/src/modules/data.ts | 60 ------ .../modules/knowledge/learnings.ai.test.ts | 190 ------------------ .../src/modules/knowledge/learnings.ts | 170 ---------------- control-plane/src/modules/knowledge/queues.ts | 79 -------- .../workflows/agent/nodes/model-call.ts | 89 +------- control-plane/src/utilities/env.ts | 1 - 7 files changed, 4 insertions(+), 594 deletions(-) delete mode 100644 control-plane/src/modules/knowledge/learnings.ai.test.ts delete mode 100644 control-plane/src/modules/knowledge/learnings.ts delete mode 100644 control-plane/src/modules/knowledge/queues.ts diff --git a/control-plane/src/index.ts b/control-plane/src/index.ts index 6d78abf7..40b2b29d 100644 --- a/control-plane/src/index.ts +++ b/control-plane/src/index.ts @@ -10,7 +10,6 @@ import * as serviceDefinitions from "./modules/service-definitions"; import * as events from "./modules/observability/events"; import * as router from "./modules/router"; import * as redis from "./modules/redis"; -import * as knowledge from "./modules/knowledge/queues"; import * as toolhouse from "./modules/integrations/toolhouse"; import * as externalCalls from "./modules/jobs/external"; import * as models from "./modules/models/routing"; @@ -128,16 +127,15 @@ const startTime = Date.now(); jobs.start(), serviceDefinitions.start(), workflows.start(), - knowledge.start(), models.start(), redis.start(), + customerTelemetry.start(), + toolhouse.start(), + externalCalls.start(), ...(env.EE_DEPLOYMENT ? [ flagsmith?.getEnvironmentFlags(), - customerTelemetry.start(), analytics.start(), - toolhouse.start(), - externalCalls.start(), ] : []), ]) @@ -176,7 +174,6 @@ process.on("SIGTERM", async () => { flagsmith?.close(), hdx?.shutdown(), redis.stop(), - knowledge.stop(), customerTelemetry.stop(), externalCalls.stop(), ]); diff --git a/control-plane/src/modules/data.ts b/control-plane/src/modules/data.ts index 70c1abe9..c1cd8e45 100644 --- a/control-plane/src/modules/data.ts +++ b/control-plane/src/modules/data.ts @@ -435,62 +435,6 @@ export const embeddings = pgTable( }), ); -export const knowledgeLearnings = pgTable( - "knowledge_learnings", - { - id: varchar("id", { length: 1024 }).notNull(), - cluster_id: varchar("cluster_id").notNull(), - summary: text("summary").notNull(), - accepted: boolean("accepted").notNull().default(false), - }, - (table) => ({ - pk: primaryKey({ - columns: [table.cluster_id, table.id], - }), - }), -); - -export const knowledgeLearningsRelations = relations( - knowledgeLearnings, - ({ many }) => ({ - entities: many(knowledgeEntities, { - relationName: "knowledgeLearnings", - }), - }), -); - -export const knowledgeEntities = pgTable( - "knowledge_entities", - { - cluster_id: varchar("cluster_id").notNull(), - learning_id: varchar("learning_id", { length: 1024 }), - type: text("type", { - enum: ["tool"], - }), - name: varchar("name", { length: 1024 }), - }, - (table) => ({ - pk: primaryKey({ - columns: [table.cluster_id, table.name, table.learning_id], - }), - learningReference: foreignKey({ - columns: [table.cluster_id, table.learning_id], - foreignColumns: [knowledgeLearnings.cluster_id, knowledgeLearnings.id], - }).onDelete("cascade"), - }), -); - -export const knowledgeEntitiesRelations = relations( - knowledgeEntities, - ({ one }) => ({ - learning: one(knowledgeLearnings, { - relationName: "knowledgeLearnings", - fields: [knowledgeEntities.cluster_id, knowledgeEntities.learning_id], - references: [knowledgeLearnings.cluster_id, knowledgeLearnings.id], - }), - }), -); - export const apiKeys = pgTable( "api_keys", { @@ -684,10 +628,6 @@ export const analyticsSnapshots = pgTable( export const db = drizzle(pool, { schema: { workflows, - knowledgeLearnings, - knowledgeLearningsRelations, - knowledgeEntities, - knowledgeEntitiesRelations, toolMetadata, promptTemplates, events, diff --git a/control-plane/src/modules/knowledge/learnings.ai.test.ts b/control-plane/src/modules/knowledge/learnings.ai.test.ts deleted file mode 100644 index edac1201..00000000 --- a/control-plane/src/modules/knowledge/learnings.ai.test.ts +++ /dev/null @@ -1,190 +0,0 @@ -import { ulid } from "ulid"; -import { mergeLearnings } from "./learnings"; - -describe("mergeLearnings", () => { - it("should dedeuplicate learnings", async () => { - const existingLearnings = [ - { - id: ulid(), - summary: "Requires authentication before use", - entities: [ - { - name: "loadWebpage", - type: "tool" as const, - }, - ], - relevance: { - temporality: "persistent", - }, - }, - ]; - - const newLearnings = [ - { - id: ulid(), - summary: "Needs to be authenticated when used", - entities: [ - { - name: "loadWebpage", - type: "tool" as const, - }, - ], - relevance: { - temporality: "persistent", - }, - }, - { - id: ulid(), - summary: "Can not be called without authentication", - entities: [ - { - name: "loadWebpage", - type: "tool" as const, - }, - ], - relevance: { - temporality: "persistent", - }, - }, - { - id: ulid(), - summary: "Call authenticate before use", - entities: [ - { - name: "loadWebpage", - type: "tool" as const, - }, - ], - relevance: { - temporality: "persistent", - }, - }, - ]; - - const result = await mergeLearnings({ - newLearnings, - existingLearnings, - clusterId: "test", - }); - - expect(result).toHaveLength(1); - expect(result).toEqual(existingLearnings); - }); - - it("should add new learnings", async () => { - const existingLearnings = [ - { - id: ulid(), - summary: "Requires authentication before use", - entities: [ - { - name: "loadWebpage", - type: "tool" as const, - }, - ], - relevance: { - temporality: "persistent", - }, - }, - ]; - - const newLearnings = [ - { - id: ulid(), - summary: "Can only be used by administrator users", - entities: [ - { - name: "loadWebpage", - type: "tool" as const, - }, - ], - relevance: { - temporality: "persistent", - }, - }, - { - id: ulid(), - summary: "Can only be called on business days", - entities: [ - { - name: "loadWebpage", - type: "tool" as const, - }, - ], - relevance: { - temporality: "persistent", - }, - }, - ]; - - const result = await mergeLearnings({ - newLearnings, - existingLearnings, - clusterId: "test", - }); - - expect(result).toHaveLength(3); - expect(result).toEqual(existingLearnings.concat(newLearnings)); - }); - - it("should merge entities of existing and new learnings", async () => { - const existingLearnings = [ - { - id: ulid(), - summary: "Requires authentication before use", - entities: [ - { - name: "loadWebpage", - type: "tool" as const, - }, - ], - relevance: { - temporality: "persistent", - }, - }, - ]; - - const newLearnings = [ - { - id: ulid(), - summary: "Needs to be logged in", - entities: [ - { - name: "sendEmail", - type: "tool" as const, - }, - ], - relevance: { - temporality: "persistent", - }, - }, - ]; - - const result = await mergeLearnings({ - newLearnings, - existingLearnings, - clusterId: "test", - }); - - expect(result).toHaveLength(1); - expect(result).toEqual([ - { - id: existingLearnings[0].id, - summary: "Requires authentication before use", - entities: [ - { - name: "loadWebpage", - type: "tool" as const, - }, - { - name: "sendEmail", - type: "tool" as const, - }, - ], - relevance: { - temporality: "persistent", - }, - }, - ]); - }); -}); diff --git a/control-plane/src/modules/knowledge/learnings.ts b/control-plane/src/modules/knowledge/learnings.ts deleted file mode 100644 index 75a081ee..00000000 --- a/control-plane/src/modules/knowledge/learnings.ts +++ /dev/null @@ -1,170 +0,0 @@ -import { logger } from "../observability/logger"; -import { db, knowledgeEntities, knowledgeLearnings } from "../data"; -import { and, eq } from "drizzle-orm"; -import { learningSchema } from "../contract"; -import { z } from "zod"; -import { buildModel } from "../models"; - -export type Learning = Omit, "relevance"> & { - id: string; -}; - -export const getLearnings = async (clusterId: string) => { - return (await db.query.knowledgeLearnings.findMany({ - where: eq(knowledgeLearnings.cluster_id, clusterId), - with: { - entities: true, - }, - })) as Learning[]; -}; - -export const upsertLearning = async ( - clusterId: string, - learning: Learning & { accepted?: boolean }, -) => { - await db.transaction(async (tx) => { - await tx - .insert(knowledgeLearnings) - .values({ - id: learning.id, - summary: learning.summary, - accepted: learning.accepted ?? false, - cluster_id: clusterId, - }) - .onConflictDoUpdate({ - where: and( - eq(knowledgeLearnings.cluster_id, clusterId), - eq(knowledgeLearnings.id, learning.id), - ), - set: { - accepted: learning.accepted ?? false, - }, - target: [knowledgeLearnings.cluster_id, knowledgeLearnings.id], - }); - - await tx - .insert(knowledgeEntities) - .values( - learning.entities.map((entity) => ({ - ...entity, - cluster_id: clusterId, - learning_id: learning.id, - })), - ) - .onConflictDoNothing(); - }); -}; - -/** - * Merge two sets of learnings. - * Duplicates are discarded. - * If a duplicate specifies a new entity, the new entity is appended to the existing learning's entity list. - */ -export const mergeLearnings = async ({ - newLearnings, - existingLearnings, - clusterId, - attempts = 0, -}: { - newLearnings: Learning[]; - existingLearnings: Learning[]; - clusterId: string; - attempts?: number; -}): Promise => { - const system = [ - `A learning is a piece of information about a tool that is relevant to the system.`, - `Evaluate the existing and new learnings in the system and identify which are duplicates.`, - `A duplicate is defined as a learning describing the same information.`, - ].join("\n"); - - const schema = z.object({ - duplicates: z - .record( - z - // @ts-expect-error: We don't care about the type information here, but we want to constrain the choices - .enum(existingLearnings.map((l) => l.id) as string[] as const) - .describe("The existing learning ID"), - z - .array(z.string()) - .describe( - "The IDs of all the learnings that are duplicates of the existing learning.", - ), - ) - .optional(), - }); - - const model = buildModel({ - identifier: "claude-3-5-sonnet", - trackingOptions: { - clusterId, - }, - purpose: "learnings.merge", - }); - - // Strip out other fields from the learnings (entities, etc) - const prepared = { - existing: existingLearnings.map((l) => ({ id: l.id, summary: l.summary })), - new: newLearnings.map((l) => ({ id: l.id, summary: l.summary })), - }; - - const result = await model.structured({ - system, - schema, - messages: [ - { - role: "user", - content: ` -${prepared.existing.map((learning) => JSON.stringify(learning, null, 2)).join("\n")} - - -${prepared.new.map((learning) => JSON.stringify(learning, null, 2)).join("\n")} -`, - }, - ], - }); - - if (!result.parsed.success) { - if (attempts >= 5) { - throw new Error("Failed to parse mergeLearnings output after 5 attempts"); - } - - logger.info("Failed to parse mergeLearnings output, retrying", { - attempts, - }); - - return mergeLearnings({ - newLearnings, - existingLearnings, - clusterId, - attempts: attempts + 1, - }); - } - - const duplicateLookup = result.parsed.data.duplicates ?? {}; - - return [ - // Attach any new entities to existing learnings - ...existingLearnings.map((existing) => ({ - ...existing, - entities: [ - ...existing.entities, - ...newLearnings - .filter((newLearning) => - duplicateLookup[existing.id]?.includes(newLearning.id), - ) - .flatMap((newLearning) => newLearning.entities) - .filter( - (entity) => - !existing.entities.some( - (existingEntity) => existingEntity.name === entity.name, - ), - ), - ], - })), - // Add new learnings, filtering out any duplicates - ...newLearnings.filter( - (newLearning) => - !Object.values(duplicateLookup).flat().includes(newLearning.id), - ), - ]; -}; diff --git a/control-plane/src/modules/knowledge/queues.ts b/control-plane/src/modules/knowledge/queues.ts deleted file mode 100644 index be6cf105..00000000 --- a/control-plane/src/modules/knowledge/queues.ts +++ /dev/null @@ -1,79 +0,0 @@ -import { env } from "../../utilities/env"; - -import { Consumer } from "sqs-consumer"; -import { BaseMessage, baseMessageSchema, sqs, withObservability } from "../sqs"; -import { getLearnings, mergeLearnings, upsertLearning } from "./learnings"; -import { z } from "zod"; -import { learningSchema } from "../contract"; -import { logger } from "../observability/logger"; -import { ulid } from "ulid"; - -const learningProcessConsumer = env.SQS_LEARNING_INGEST_QUEUE_URL - ? Consumer.create({ - queueUrl: env.SQS_LEARNING_INGEST_QUEUE_URL, - batchSize: 5, - visibilityTimeout: 60, - heartbeatInterval: 30, - handleMessage: withObservability( - env.SQS_LEARNING_INGEST_QUEUE_URL, - handleLearningIngest, - ), - sqs, - }) - : undefined; - -export const start = async () => { - await Promise.all([learningProcessConsumer?.start()]); -}; - -export const stop = async () => { - learningProcessConsumer?.stop(); -}; - -async function handleLearningIngest(message: BaseMessage) { - const zodResult = baseMessageSchema - .extend({ - learnings: z.array(learningSchema), - }) - .safeParse(message); - - if (!zodResult.success) { - logger.error("Message does not conform to learning ingestion schema", { - error: zodResult.error, - body: message, - }); - return; - } - - const { clusterId, runId, learnings } = zodResult.data; - - logger.info("Evaluating new learnings", { - learnings, - }); - - const existing = await getLearnings(clusterId); - - const merged = await mergeLearnings({ - clusterId, - newLearnings: learnings.map((l) => ({ - ...l, - id: ulid(), - })), - existingLearnings: existing, - }); - - const newLearnings = merged.filter( - (l) => !existing.find((e) => e.id === l.id), - ); - if (!newLearnings.length) { - return; - } - - logger.info("New learnings found", { - learnings: newLearnings, - }); - - for (const learning of newLearnings) { - await upsertLearning(clusterId, learning); - } -} diff --git a/control-plane/src/modules/workflows/agent/nodes/model-call.ts b/control-plane/src/modules/workflows/agent/nodes/model-call.ts index 6a9d2b79..cc8c1672 100644 --- a/control-plane/src/modules/workflows/agent/nodes/model-call.ts +++ b/control-plane/src/modules/workflows/agent/nodes/model-call.ts @@ -4,20 +4,16 @@ import { logger } from "../../../observability/logger"; import { WorkflowAgentState, WorkflowAgentStateMessage } from "../state"; import { addAttributes, - injectTraceContext, withSpan, } from "../../../observability/tracer"; import { AgentError } from "../../../../utilities/errors"; import { z } from "zod"; import { ulid } from "ulid"; -import { learningSchema } from "../../../contract"; import { deserializeFunctionSchema } from "../../../service-definitions"; import { validateFunctionSchema } from "inferable"; import { JsonSchemaInput } from "inferable/bin/types"; import { toolSchema } from "./tool-parser"; -import { sqs } from "../../../sqs"; -import { env } from "../../../../utilities/env"; import { Model } from "../../../models"; import { ToolUseBlock } from "@anthropic-ai/sdk/resources"; @@ -90,18 +86,13 @@ const _handleModelCall = async ( } : {}), - learnings: z - .array(learningSchema) - .describe( - "Any information you have learned about the tools as a result of this step, do not repeat.", - ) - .optional(), issue: z .string() .describe( "Describe any issues you have encountered in this step. Specifically related to the tools you are using.", ) .optional(), + invocations: z .array( z.object({ @@ -136,33 +127,6 @@ const _handleModelCall = async ( "Once all tasks have been completed, return the final result as a structured object.", "Provide concise and clear responses. Use **bold** to highlight important words.", state.additionalContext, - "If you learn details about an entity, include them in the 'learnings' field.", - "", - JSON.stringify({ - entities: [ - { - name: "loadWebpage", - type: "tool", - }, - ], - summary: "Requires a fully qualified URL", - relevance: { - temporality: "transient", - }, - }), - JSON.stringify({ - summary: "Is currently impacted by network issues", - entities: [ - { - name: "sendEmail", - type: "tool", - }, - ], - relevance: { - temporality: "transient", - }, - }), - "", "", schemaString, "", @@ -349,56 +313,6 @@ const _handleModelCall = async ( }; } - if (data.learnings && data.learnings.length > 0) { - data.learnings = data.learnings.filter((learning) => { - const missing = learning.entities?.filter((entity) => { - return !state.allAvailableTools.includes(entity.name); - }); - - if (missing && missing.length > 0) { - logger.info("Filtering out learning as entities could not be found", { - learning, - }); - return false; - } - - const selfReference = learning.entities.find((entity) => { - return learning.summary.includes(entity.name); - }); - - if (!!selfReference) { - logger.info( - "Filtering out learning as it references entity in summary", - { - learning, - }, - ); - return false; - } - - return true; - }); - } - - if ( - env.SQS_LEARNING_INGEST_QUEUE_URL && - data.learnings && - data.learnings.length > 0 - ) { - await sqs - .sendMessage({ - QueueUrl: env.SQS_LEARNING_INGEST_QUEUE_URL, - MessageBody: JSON.stringify({ - clusterId: state.workflow.clusterId, - runId: state.workflow.id, - learnings: data.learnings, - ...injectTraceContext(), - }), - }) - .catch((e) => { - logger.error("Failed to send learning to SQS", { error: e }); - }); - } return { messages: [ @@ -411,7 +325,6 @@ const _handleModelCall = async ( id: ulid(), reasoning: invocation.reasoning as string | undefined, })), - learnings: data.learnings, issue: data.issue, result: data.result, message: typeof data.message === "string" ? data.message : undefined, diff --git a/control-plane/src/utilities/env.ts b/control-plane/src/utilities/env.ts index 25141a06..5286d62d 100644 --- a/control-plane/src/utilities/env.ts +++ b/control-plane/src/utilities/env.ts @@ -43,7 +43,6 @@ const envSchema = z SQS_RUN_PROCESS_QUEUE_URL: z.string(), SQS_RUN_GENERATE_NAME_QUEUE_URL: z.string(), - SQS_LEARNING_INGEST_QUEUE_URL: z.string().optional(), SQS_CUSTOMER_TELEMETRY_QUEUE_URL: z.string(), SQS_EXTERNAL_TOOL_CALL_QUEUE_URL: z.string(),