Skip to content

Commit

Permalink
feat: Cache customer auth context (#278)
Browse files Browse the repository at this point in the history
* chore: Cache customer auth context

* chore: Add docLinks to customer auth errors

* feat: Cache customer auth failures

* chore: Deliniate cache key via clusterId
  • Loading branch information
johnjcsmith authored Dec 11, 2024
1 parent 99776d8 commit f8de989
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 21 deletions.
16 changes: 6 additions & 10 deletions control-plane/src/modules/auth/api-secret.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
import * as data from "../data";
import { eq, and, isNull } from "drizzle-orm";
import { createHash, randomBytes } from "crypto";
import { randomBytes } from "crypto";
import { logger } from "../observability/logger";
import { createCache } from "../../utilities/cache";
import { createCache, hashFromSecret } from "../../utilities/cache";

const authContextCache = createCache<{
const apiKeyContextCache = createCache<{
clusterId: string;
id: string;
organizationId: string;
}>(
Symbol("authContextCach"),
Symbol("apiKeyContextCache"),
);

const hashFromSecret = (secret: string): string => {
return createHash("sha256").update(secret).digest("hex");
};

export const isApiSecret = (authorization: string): boolean =>
authorization.startsWith("sk_");

Expand All @@ -26,7 +22,7 @@ export const verifyApiKey = async (
> => {
const secretHash = hashFromSecret(secret);

const cached = await authContextCache.get(secretHash);
const cached = await apiKeyContextCache.get(secretHash);

if (cached) {
return cached;
Expand Down Expand Up @@ -60,7 +56,7 @@ export const verifyApiKey = async (
return undefined;
}

await authContextCache.set(secretHash, {
await apiKeyContextCache.set(secretHash, {
clusterId: result.clusterId,
id: result.id,
organizationId: result.organizationId,
Expand Down
1 change: 1 addition & 0 deletions control-plane/src/modules/auth/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ export const extractCustomerAuthState = async (
if (!cluster.enable_customer_auth) {
throw new AuthenticationError(
"Customer auth is not enabled for this cluster",
"https://docs.inferable.ai/pages/auth#customer-provided-secrets"
);
}

Expand Down
59 changes: 48 additions & 11 deletions control-plane/src/modules/auth/customer-auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@ import { packer } from "../packer";
import * as jobs from "../jobs/jobs";
import { getJobStatusSync } from "../jobs/jobs";
import { getServiceDefinition } from "../service-definitions";
import { createCache, hashFromSecret } from "../../utilities/cache";
import { logger } from "../observability/logger";

export const VERIFY_FUNCTION_NAME = "handleCustomerAuth";
export const VERIFY_FUNCTION_SERVICE = "default";
const VERIFY_FUNCTION_ID = `${VERIFY_FUNCTION_SERVICE}_${VERIFY_FUNCTION_NAME}`;

const customerAuthContextCache = createCache<unknown>(
Symbol("customerAuthContextCache"),
);

/**
* Calls the customer provided verify function and returns the result
*/
Expand All @@ -22,6 +28,20 @@ export const verifyCustomerProvidedAuth = async ({
token: string;
clusterId: string;
}): Promise<unknown> => {

const secretHash = hashFromSecret(`${clusterId}:${token}`);

const cached = await customerAuthContextCache.get(secretHash);
if (cached) {
if (typeof cached === "object" && 'error' in cached && typeof cached.error === "string") {
throw new AuthenticationError(
cached.error,
"https://docs.inferable.ai/pages/auth#handlecustomerauth"
);
}
return cached;
}

try {
const serviceDefinition = await getServiceDefinition({
service: VERIFY_FUNCTION_SERVICE,
Expand Down Expand Up @@ -56,32 +76,49 @@ export const verifyCustomerProvidedAuth = async ({
const result = await getJobStatusSync({
jobId: id,
owner: { clusterId },
ttl: 5_000,
ttl: 15_000,
});

if (
result.status !== "success" ||
result.resultType !== "resolution" ||
!result.result
) {
if (result.status == "success" && result.resultType !== "resolution") {
throw new AuthenticationError(
`Call to ${VERIFY_FUNCTION_ID} failed. Result: ${result.result}`,
"Customer provided token is not valid",
"https://docs.inferable.ai/pages/auth#handlecustomerauth"
);
}

// This isn't expected
if (result.status != "success") {
throw new Error(
`Failed to call ${VERIFY_FUNCTION_ID}: ${result.result}`,
);
}

if (!result.result) {
throw new AuthenticationError(
`${VERIFY_FUNCTION_ID} did not return a result`,
"https://docs.inferable.ai/pages/auth#handlecustomerauth"
);
}

await customerAuthContextCache.set(secretHash, result, 300);

return packer.unpack(result.result);
} catch (e) {
if (e instanceof JobPollTimeoutError) {
throw new AuthenticationError(
`Call to ${VERIFY_FUNCTION_ID} did not complete in time`,
"https://docs.inferable.ai/pages/auth#handlecustomerauth"
);
}

if (e instanceof InvalidJobArgumentsError) {
throw new AuthenticationError(
`Could not find ${VERIFY_FUNCTION_ID} registration`,
);
// Cache the auth error for 1 minutes
if (e instanceof AuthenticationError) {
await customerAuthContextCache.set(secretHash, {
error: e.message
}, 60);
throw e;
}

throw e;
}
};
6 changes: 6 additions & 0 deletions control-plane/src/utilities/cache.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { createHash } from "crypto";
import NodeCache from "node-cache";
import { redisClient } from "../modules/redis";

Expand Down Expand Up @@ -37,3 +38,8 @@ export const createCache = <T>(namespace: symbol) => {
// const cache = createCache(Symbol("cache"));
// cache.set("key", "value");
// const value = cache.get("key");

export const hashFromSecret = (secret: string): string => {
return createHash("sha256").update(secret).digest("hex");
};

0 comments on commit f8de989

Please sign in to comment.