diff --git a/control-plane/src/modules/auth/api-secret.ts b/control-plane/src/modules/auth/api-secret.ts index 0c5cbf6c..239f77fa 100644 --- a/control-plane/src/modules/auth/api-secret.ts +++ b/control-plane/src/modules/auth/api-secret.ts @@ -1,21 +1,17 @@ import * as data from "../data"; import { eq, and, isNull } from "drizzle-orm"; -import { createHash, randomBytes } from "crypto"; +import { randomBytes } from "crypto"; import { logger } from "../observability/logger"; -import { createCache } from "../../utilities/cache"; +import { createCache, hashFromSecret } from "../../utilities/cache"; -const authContextCache = createCache<{ +const apiKeyContextCache = createCache<{ clusterId: string; id: string; organizationId: string; }>( - Symbol("authContextCach"), + Symbol("apiKeyContextCache"), ); -const hashFromSecret = (secret: string): string => { - return createHash("sha256").update(secret).digest("hex"); -}; - export const isApiSecret = (authorization: string): boolean => authorization.startsWith("sk_"); @@ -26,7 +22,7 @@ export const verifyApiKey = async ( > => { const secretHash = hashFromSecret(secret); - const cached = await authContextCache.get(secretHash); + const cached = await apiKeyContextCache.get(secretHash); if (cached) { return cached; @@ -60,7 +56,7 @@ export const verifyApiKey = async ( return undefined; } - await authContextCache.set(secretHash, { + await apiKeyContextCache.set(secretHash, { clusterId: result.clusterId, id: result.id, organizationId: result.organizationId, diff --git a/control-plane/src/modules/auth/auth.ts b/control-plane/src/modules/auth/auth.ts index c904223a..c02a8470 100644 --- a/control-plane/src/modules/auth/auth.ts +++ b/control-plane/src/modules/auth/auth.ts @@ -373,6 +373,7 @@ export const extractCustomerAuthState = async ( if (!cluster.enable_customer_auth) { throw new AuthenticationError( "Customer auth is not enabled for this cluster", + "https://docs.inferable.ai/pages/auth#customer-provided-secrets" ); } diff --git a/control-plane/src/modules/auth/customer-auth.ts b/control-plane/src/modules/auth/customer-auth.ts index 98baf325..8dec119d 100644 --- a/control-plane/src/modules/auth/customer-auth.ts +++ b/control-plane/src/modules/auth/customer-auth.ts @@ -7,11 +7,17 @@ import { packer } from "../packer"; import * as jobs from "../jobs/jobs"; import { getJobStatusSync } from "../jobs/jobs"; import { getServiceDefinition } from "../service-definitions"; +import { createCache, hashFromSecret } from "../../utilities/cache"; +import { logger } from "../observability/logger"; export const VERIFY_FUNCTION_NAME = "handleCustomerAuth"; export const VERIFY_FUNCTION_SERVICE = "default"; const VERIFY_FUNCTION_ID = `${VERIFY_FUNCTION_SERVICE}_${VERIFY_FUNCTION_NAME}`; +const customerAuthContextCache = createCache( + Symbol("customerAuthContextCache"), +); + /** * Calls the customer provided verify function and returns the result */ @@ -22,6 +28,20 @@ export const verifyCustomerProvidedAuth = async ({ token: string; clusterId: string; }): Promise => { + + const secretHash = hashFromSecret(`${clusterId}:${token}`); + + const cached = await customerAuthContextCache.get(secretHash); + if (cached) { + if (typeof cached === "object" && 'error' in cached && typeof cached.error === "string") { + throw new AuthenticationError( + cached.error, + "https://docs.inferable.ai/pages/auth#handlecustomerauth" + ); + } + return cached; + } + try { const serviceDefinition = await getServiceDefinition({ service: VERIFY_FUNCTION_SERVICE, @@ -56,32 +76,49 @@ export const verifyCustomerProvidedAuth = async ({ const result = await getJobStatusSync({ jobId: id, owner: { clusterId }, - ttl: 5_000, + ttl: 15_000, }); - if ( - result.status !== "success" || - result.resultType !== "resolution" || - !result.result - ) { + if (result.status == "success" && result.resultType !== "resolution") { throw new AuthenticationError( - `Call to ${VERIFY_FUNCTION_ID} failed. Result: ${result.result}`, + "Customer provided token is not valid", + "https://docs.inferable.ai/pages/auth#handlecustomerauth" ); } + // This isn't expected + if (result.status != "success") { + throw new Error( + `Failed to call ${VERIFY_FUNCTION_ID}: ${result.result}`, + ); + } + + if (!result.result) { + throw new AuthenticationError( + `${VERIFY_FUNCTION_ID} did not return a result`, + "https://docs.inferable.ai/pages/auth#handlecustomerauth" + ); + } + + await customerAuthContextCache.set(secretHash, result, 300); + return packer.unpack(result.result); } catch (e) { if (e instanceof JobPollTimeoutError) { throw new AuthenticationError( `Call to ${VERIFY_FUNCTION_ID} did not complete in time`, + "https://docs.inferable.ai/pages/auth#handlecustomerauth" ); } - if (e instanceof InvalidJobArgumentsError) { - throw new AuthenticationError( - `Could not find ${VERIFY_FUNCTION_ID} registration`, - ); + // Cache the auth error for 1 minutes + if (e instanceof AuthenticationError) { + await customerAuthContextCache.set(secretHash, { + error: e.message + }, 60); + throw e; } + throw e; } }; diff --git a/control-plane/src/utilities/cache.ts b/control-plane/src/utilities/cache.ts index 2ca4bf02..810c9b7d 100644 --- a/control-plane/src/utilities/cache.ts +++ b/control-plane/src/utilities/cache.ts @@ -1,3 +1,4 @@ +import { createHash } from "crypto"; import NodeCache from "node-cache"; import { redisClient } from "../modules/redis"; @@ -37,3 +38,8 @@ export const createCache = (namespace: symbol) => { // const cache = createCache(Symbol("cache")); // cache.set("key", "value"); // const value = cache.get("key"); + +export const hashFromSecret = (secret: string): string => { + return createHash("sha256").update(secret).digest("hex"); +}; +