Skip to content

Commit

Permalink
feat(sdk-node): Auth context for functions (#72)
Browse files Browse the repository at this point in the history
* chore(sdk-node): Remove authenticate from registrations

* chore(sdk-node): Customer auth context support
  • Loading branch information
johnjcsmith authored Nov 7, 2024
1 parent 7534f50 commit 3609beb
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 88 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ For language-specific quick start guides, please refer to the README in each SDK
| Call [Timeouts](https://docs.inferable.ai/pages/functions#config-timeoutseconds) ||||
| Call [Retries](https://docs.inferable.ai/pages/functions#config-retrycountonstall) ||||
| Call [Approval](https://docs.inferable.ai/pages/functions#config-requiresapproval) (Human in the loop) ||||
| Auth Context ||||

## Documentation

Expand Down
6 changes: 0 additions & 6 deletions sdk-node/src/Inferable.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ describe("Inferable", () => {
}),
},
description: "echoes the input",
authenticate: (ctx, args) => {
return args.foo === ctx ? Promise.resolve() : Promise.reject();
},
});

expect(d.registeredFunctions).toEqual(["echo"]);
Expand All @@ -116,9 +113,6 @@ describe("Inferable", () => {
}),
},
description: "echoes the input",
authenticate: (ctx, args) => {
return args.foo === ctx ? Promise.resolve() : Promise.reject();
},
});

expect(d.activeServices).toEqual([]);
Expand Down
11 changes: 2 additions & 9 deletions sdk-node/src/Inferable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import * as links from "./links";
import { machineId } from "./machine-id";
import { Service, registerMachine } from "./service";
import {
ContextInput,
FunctionConfig,
FunctionInput,
FunctionRegistration,
Expand Down Expand Up @@ -346,11 +347,9 @@ export class Inferable {
schema,
config,
description,
authenticate,
}) => {
this.registerFunction({
name,
authenticate,
serviceName: input.name,
func,
inputSchema: schema.input,
Expand Down Expand Up @@ -413,21 +412,16 @@ export class Inferable {

private registerFunction<T extends z.ZodTypeAny | JsonSchemaInput>({
name,
authenticate,
serviceName,
func,
inputSchema,
config,
description,
}: {
authenticate?: (
authContext: string,
args: FunctionInput<T>,
) => Promise<void>;
name: string;
serviceName: string;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
func: (input: FunctionInput<T>) => any;
func: (input: FunctionInput<T>, context: ContextInput) => any;
inputSchema: T;
config?: FunctionConfig;
description?: string;
Expand Down Expand Up @@ -464,7 +458,6 @@ export class Inferable {

const registration: FunctionRegistration<T> = {
name,
authenticate,
serviceName,
func,
schema: {
Expand Down
17 changes: 11 additions & 6 deletions sdk-node/src/contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const machineHeaders = {
};

// Alphanumeric, underscore, hyphen, no whitespace. From 6 to 128 characters.
const userDefinedIdRegex = /^[a-zA-Z0-9-]{6,128}$/;
const userDefinedIdRegex = /^[a-zA-Z0-9-_]{6,128}$/;

const functionReference = z.object({
service: z.string(),
Expand Down Expand Up @@ -426,7 +426,7 @@ export const definition = {
),
})
.optional()
.describe("A prompt template which the run should be created from"),
.describe("DEPRECATED"),
reasoningTraces: z
.boolean()
.default(true)
Expand Down Expand Up @@ -582,7 +582,6 @@ export const definition = {
responses: {
200: z.object({
id: z.string(),
jobHandle: z.string().nullable(),
userId: z.string().nullable(),
status: z
.enum(["pending", "running", "paused", "done", "failed"])
Expand Down Expand Up @@ -900,6 +899,7 @@ export const definition = {
runId: z.string(),
}),
responses: {
404: z.undefined(),
200: z.object({
messages: z.array(
z.object({
Expand Down Expand Up @@ -955,7 +955,6 @@ export const definition = {
),
run: z.object({
id: z.string(),
jobHandle: z.string().nullable(),
userId: z.string().nullable(),
status: z
.enum(["pending", "running", "paused", "done", "failed"])
Expand Down Expand Up @@ -1232,7 +1231,9 @@ export const definition = {
z.object({
id: z.string(),
data: z.string(),
tags: z.array(z.string()),
tags: z
.array(z.string())
.transform((tags) => tags.map((tag) => tag.toLowerCase().trim())),
title: z.string(),
}),
),
Expand All @@ -1252,6 +1253,7 @@ export const definition = {
query: z.object({
query: z.string(),
limit: z.coerce.number().min(1).max(50).default(5),
tag: z.string().optional(),
}),
responses: {
200: z.array(
Expand All @@ -1275,7 +1277,9 @@ export const definition = {
headers: z.object({ authorization: z.string() }),
body: z.object({
data: z.string(),
tags: z.array(z.string()),
tags: z
.array(z.string())
.transform((tags) => tags.map((tag) => tag.toLowerCase().trim())),
title: z.string(),
}),
responses: {
Expand Down Expand Up @@ -1396,6 +1400,7 @@ export const definition = {
id: z.string(),
function: z.string(),
input: z.any(),
customerAuthContext: z.any().nullable(),
}),
),
},
Expand Down
47 changes: 0 additions & 47 deletions sdk-node/src/execute-fn.test.ts
Original file line number Diff line number Diff line change
@@ -1,54 +1,7 @@
import { executeFn } from "./execute-fn";

describe("executeFn", () => {
it("should run a function with arguments", async () => {
const fn = (val: { [key: string]: string }) => Promise.resolve(val.foo);
const result = await fn({ foo: "bar" });
expect(result).toBe("bar");
});

it("should authenticate a function with valid context", async () => {
const fn = (val: { [key: string]: string }) => Promise.resolve(val.foo);
const args = { foo: "bar" };
const authenticate = (
authContext: string,
args: { [key: string]: string },
) => {
return args.foo === authContext
? Promise.resolve()
: Promise.reject(new Error("Unauthorized"));
};

const result = executeFn(fn, [args], authenticate, "bar");

await expect(result).resolves.toEqual({
content: "bar",
functionExecutionTime: expect.any(Number),
type: "resolution",
});
});

it("should authenticate a function with invalid context", async () => {
const fn = (val: { [key: string]: string }) => Promise.resolve(val.foo);
const args = { foo: "bar" };
const authenticate = (
authContext: string,
args: { [key: string]: string },
) => {
return args.foo === authContext
? Promise.resolve()
: Promise.reject(new Error("Unauthorized"));
};

const result = executeFn(fn, [args], authenticate, "not-bar");

await expect(result).resolves.toEqual(
expect.objectContaining({
type: "rejection",
content: expect.objectContaining({
message: "Unauthorized",
}),
}),
);
});
});
14 changes: 0 additions & 14 deletions sdk-node/src/execute-fn.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { InferableError } from "./errors";
import { serializeError } from "./serialize-error";
import { FunctionRegistration } from "./types";

Expand All @@ -11,22 +10,9 @@ export type Result<T = unknown> = {
export const executeFn = async (
fn: FunctionRegistration["func"],
args: Parameters<FunctionRegistration["func"]>,
authenticate?: (
authContext: string,
args: Parameters<FunctionRegistration["func"]>["0"],
) => Promise<void>,
authContext?: string,
): Promise<Result> => {
const start = Date.now();
try {
if (authenticate) {
if (!authContext) {
throw new InferableError(InferableError.JOB_AUTHCONTEXT_INVALID);
}

await authenticate(authContext, args[0]);
}

const result = await fn(...args);

return {
Expand Down
6 changes: 4 additions & 2 deletions sdk-node/src/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type CallMessage = {
id: string;
function: string;
input?: unknown;
customerAuthContext?: unknown;
};

export class Service {
Expand Down Expand Up @@ -268,8 +269,9 @@ export class Service {

const result = await executeFn(
registration.func,
[args],
registration.authenticate,
[args, {
customerAuthContext: call.customerAuthContext,
}],
);

await onComplete(result);
Expand Down
13 changes: 9 additions & 4 deletions sdk-node/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import { z } from "zod";
import { FunctionConfigSchema } from "./contract";

/**
* Context object which is passed to function calls
*/
export type ContextInput = {
customerAuthContext?: unknown;
}

export type FunctionConfig = z.infer<typeof FunctionConfigSchema>;

export type FunctionInput<T extends z.ZodTypeAny | JsonSchemaInput> =
Expand Down Expand Up @@ -53,9 +60,8 @@ export type FunctionRegistrationInput<
T extends z.ZodTypeAny | JsonSchemaInput,
> = {
name: string;
authenticate?: (authContext: string, args: FunctionInput<T>) => Promise<void>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
func: (input: FunctionInput<T>) => any;
func: (input: FunctionInput<T>, context: ContextInput) => any;
schema: FunctionSchema<T>;
config?: FunctionConfig;
description?: string;
Expand Down Expand Up @@ -96,13 +102,12 @@ export interface FunctionRegistration<
T extends JsonSchemaInput | z.ZodTypeAny = any,
> {
name: string;
authenticate?: (authContext: string, args: FunctionInput<T>) => Promise<void>;
serviceName: string;
description?: string;
schema: {
input: T;
inputJson: string;
};
func: (args: FunctionInput<T>) => any;
func: (args: FunctionInput<T>, context: ContextInput) => any;
config?: FunctionConfig;
}

0 comments on commit 3609beb

Please sign in to comment.