diff --git a/control-plane/src/modules/jobs/create-job.ts b/control-plane/src/modules/jobs/create-job.ts index e55ff8d6..563a9685 100644 --- a/control-plane/src/modules/jobs/create-job.ts +++ b/control-plane/src/modules/jobs/create-job.ts @@ -11,7 +11,7 @@ import { getServiceDefinition, parseJobArgs, } from "../service-definitions"; -import { extractWithPath } from "../util"; +import { extractWithJsonPath } from "../util"; import { externalServices } from "./external"; import { env } from "../../utilities/env"; import { injectTraceContext } from "../observability/tracer"; @@ -34,12 +34,15 @@ type CreateJobParams = { const DEFAULT_RETRY_COUNT_ON_STALL = 0; -const extractKeyFromPath = (path: string, args: unknown) => { +const extractCacheKeyFromJsonPath = (path: string, args: unknown) => { try { - return extractWithPath(path, args)[0]; + return extractWithJsonPath(path, args)[0]; } catch (error) { if (error instanceof NotFoundError) { - throw new InvalidJobArgumentsError(error.message); + throw new InvalidJobArgumentsError( + `Failed to extract cache key from arguments: ${error.message}`, + "https://docs.inferable.ai/pages/functions#config-cache" + ); } throw error; } @@ -94,7 +97,7 @@ export const createJob = async (params: { }; if (config?.cache?.keyPath && config?.cache?.ttlSeconds) { - const cacheKey = extractKeyFromPath(config.cache.keyPath, args); + const cacheKey = extractCacheKeyFromJsonPath(config.cache.keyPath, args); const { id, created } = await createJobStrategies.cached({ ...jobConfig, diff --git a/control-plane/src/modules/service-definitions.test.ts b/control-plane/src/modules/service-definitions.test.ts index 3109d8f7..793937a3 100644 --- a/control-plane/src/modules/service-definitions.test.ts +++ b/control-plane/src/modules/service-definitions.test.ts @@ -340,7 +340,7 @@ describe("validateServiceRegistration", () => { z.object({ test: z.string(), }) - )) + )), }, ], }, @@ -370,4 +370,47 @@ describe("validateServiceRegistration", () => { }).not.toThrow(); }) + it("should reject invalid cache.keyPath jsonpath", () => { + expect(() => { + validateServiceRegistration({ + service: "default", + definition: { + name: "default", + functions: [ + { + name: "myFn", + config: { + cache: { + keyPath: "$invalid", + ttlSeconds: 10 + } + } + }, + ], + }, + }); + }).toThrow(InvalidServiceRegistrationError); + }) + + it("should accept valid cache.keyPath jsonpath", () => { + expect(() => { + validateServiceRegistration({ + service: "default", + definition: { + name: "default", + functions: [ + { + name: "myFn", + config: { + cache: { + keyPath: "$.someKey", + ttlSeconds: 10 + } + } + }, + ], + }, + }); + }).not.toThrow(); + }) }) diff --git a/control-plane/src/modules/service-definitions.ts b/control-plane/src/modules/service-definitions.ts index 0ad5e2bd..2d252153 100644 --- a/control-plane/src/modules/service-definitions.ts +++ b/control-plane/src/modules/service-definitions.ts @@ -1,6 +1,5 @@ import { and, eq, lte } from "drizzle-orm"; import { - handleCustomerAuthSchema, validateDescription, validateFunctionName, validateFunctionSchema, @@ -20,6 +19,7 @@ import { embeddableEntitiy } from "./embeddings/embeddings"; import { logger } from "./observability/logger"; import { packer } from "./packer"; import { withThrottle } from "./util"; +import jsonpath from "jsonpath"; // The time without a ping before a service is considered expired const SERVICE_LIVE_THRESHOLD_MS = 30 * 60 * 1000; // 30 minutes @@ -400,10 +400,20 @@ export const validateServiceRegistration = ({ } } - const VERIFY_FUNCTION_NAME = "handleCustomerAuth"; - const VERIFY_FUNCTION_SERVICE = "default"; + if (fn.config?.cache) { + try { + jsonpath.parse(fn.config.cache.keyPath); + } catch { + throw new InvalidServiceRegistrationError( + `${fn.name} cache.keyPath is invalid`, + "https://docs.inferable.ai/pages/functions#config-cache" + ) + } + } // Checks for customer auth handler + const VERIFY_FUNCTION_NAME = "handleCustomerAuth"; + const VERIFY_FUNCTION_SERVICE = "default"; if (service === VERIFY_FUNCTION_SERVICE && fn.name === VERIFY_FUNCTION_NAME) { if (!fn.schema) { throw new InvalidServiceRegistrationError( diff --git a/control-plane/src/modules/util.ts b/control-plane/src/modules/util.ts index 4838f462..e8095651 100644 --- a/control-plane/src/modules/util.ts +++ b/control-plane/src/modules/util.ts @@ -2,7 +2,7 @@ import jsonpath from "jsonpath"; import { NotFoundError } from "../utilities/errors"; import { redisClient } from "./redis"; -export const extractWithPath = (path: string, args: unknown) => { +export const extractWithJsonPath = (path: string, args: unknown) => { const result = jsonpath.query(args, path); if (!result || result.length === 0) { throw new NotFoundError(`Path ${path} not found within input`);