From fdde30630ffdfeddb08e742d47d3982519ea698c Mon Sep 17 00:00:00 2001 From: John Smith Date: Mon, 28 Oct 2024 15:01:36 +1030 Subject: [PATCH] feat: Add run creation to .NET SDK --- sdk-dotnet/src/API/APIClient.cs | 31 +++++++ sdk-dotnet/src/API/Models.cs | 86 +++++++++++++++++++ sdk-dotnet/src/Inferable.cs | 62 +++++++++++-- .../tests/Inferable.Tests/InferableTest.cs | 40 +++++++++ sdk-node/README.md | 18 ++-- sdk-node/src/Inferable.ts | 68 --------------- 6 files changed, 220 insertions(+), 85 deletions(-) diff --git a/sdk-dotnet/src/API/APIClient.cs b/sdk-dotnet/src/API/APIClient.cs index 03a68d3b..47c24e4c 100644 --- a/sdk-dotnet/src/API/APIClient.cs +++ b/sdk-dotnet/src/API/APIClient.cs @@ -63,6 +63,37 @@ async public Task CreateCallResult(string clusterId, string callId, CreateResult response.EnsureSuccessStatusCode(); } + async public Task CreateRun(string clusterId, CreateRunInput input) + { + string jsonData = JsonSerializer.Serialize(input); + + HttpResponseMessage response = await _client.PostAsync( + $"/clusters/{clusterId}/runs", + new StringContent(jsonData, Encoding.UTF8, "application/json") + ); + + response.EnsureSuccessStatusCode(); + + string responseBody = await response.Content.ReadAsStringAsync(); + var result = JsonSerializer.Deserialize(responseBody); + + return result; + } + + async public Task GetRun(string clusterId, string runId) + { + HttpResponseMessage response = await _client.GetAsync( + $"/clusters/{clusterId}/runs/{runId}" + ); + + response.EnsureSuccessStatusCode(); + + string responseBody = await response.Content.ReadAsStringAsync(); + var result = JsonSerializer.Deserialize(responseBody); + + return result; + } + async public Task<(List, int?)> ListCalls(string clusterId, string service) { HttpResponseMessage response = await _client.GetAsync( diff --git a/sdk-dotnet/src/API/Models.cs b/sdk-dotnet/src/API/Models.cs index 9256ea8c..aebd2573 100644 --- a/sdk-dotnet/src/API/Models.cs +++ b/sdk-dotnet/src/API/Models.cs @@ -112,4 +112,90 @@ public struct FunctionConfig public int? TimeoutSeconds { get; set; } } + public struct CreateRunInput + { + [JsonPropertyName("message")] + public string? Message { get; set; } + + [ + JsonPropertyName("attachedFunctions"), + JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull) + ] + public List? AttachedFunctions { get; set; } + + [ + JsonPropertyName("metadata"), + JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull) + ] + public Dictionary? Metadata { get; set; } + + [ + JsonPropertyName("result"), + JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull) + ] + public CreateRunResultInput? Result { get; set; } + + [ + JsonPropertyName("template"), + JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull) + ] + public CreateRunTemplateInput? Template { get; set; } + } + + public struct CreateRunTemplateInput + { + [JsonPropertyName("input")] + public Dictionary Input { get; set; } + + [JsonPropertyName("id")] + public string Id { get; set; } + } + + public struct CreateRunResultInput + { + [JsonPropertyName("handler")] + public CreateRunResultHandlerInput? Handler { get; set; } + + [JsonPropertyName("schema")] + public object? Schema { get; set; } + } + + public struct CreateRunResultHandlerInput + { + [JsonPropertyName("service")] + public string? Service { get; set; } + + [JsonPropertyName("function")] + public string? Function { get; set; } + } + + public struct CreateRunResult + { + [JsonPropertyName("id")] + public string ID { get; set; } + } + + public struct GetRunResult + { + [JsonPropertyName("id")] + public string ID { get; set; } + + [JsonPropertyName("status")] + public string Status { get; set; } + + [JsonPropertyName("failureReason")] + public string FailureReason { get; set; } + + [JsonPropertyName("summary")] + public string Summary { get; set; } + + [JsonPropertyName("result")] + public string Result { get; set; } + + [JsonPropertyName("attachedFunctions")] + public List AttachedFunctions { get; set; } + + [JsonPropertyName("metadata")] + public Dictionary Metadata { get; set; } + } } diff --git a/sdk-dotnet/src/Inferable.cs b/sdk-dotnet/src/Inferable.cs index 570bb504..308f36c4 100644 --- a/sdk-dotnet/src/Inferable.cs +++ b/sdk-dotnet/src/Inferable.cs @@ -13,20 +13,30 @@ public class InferableOptions { public string? BaseUrl { get; set; } public string? ApiSecret { get; set; } - /// - /// PingInterval in seconds - /// - public int? PingInterval { get; set; } public string? MachineId { get; set; } + public string? ClusterId { get; set; } } + public struct PollRunOptions + { + public required TimeSpan MaxWaitTime { get; set; } + public required TimeSpan Delay { get; set; } + } + + public class RunHandle + { + public required string ID { get; set; } + public required Func> Poll { get; set; } + } + public class InferableClient { public static string DefaultBaseUrl = "https://api.inferable.ai/"; private readonly ApiClient _client; private readonly ILogger _logger; + private readonly string? _clusterId; // Dictionary of service name to list of functions private Dictionary> _functionRegistry = new Dictionary>(); @@ -46,6 +56,7 @@ public InferableClient(InferableOptions? options = null, ILogger CreateRun(CreateRunInput input) + { + if (this._clusterId == null) { + throw new ArgumentException("Cluster ID must be provided to manage runs"); + } + + var result = await this._client.CreateRun(this._clusterId, input); + + return new RunHandle { + ID = result.ID, + Poll = async (PollRunOptions? options) => { + var MaxWaitTime = options?.MaxWaitTime ?? TimeSpan.FromSeconds(60); + var Delay = options?.Delay ?? TimeSpan.FromMilliseconds(500); + + var start = DateTime.Now; + while (DateTime.Now - start < MaxWaitTime) { + var pollResult = await this._client.GetRun(this._clusterId, result.ID); + + var transientStates = new List { "pending", "running" }; + if (transientStates.Contains(pollResult.Status)) { + await Task.Delay(Delay); + } + + return pollResult; + } + return null; + } + }; + } + public IEnumerable ActiveServices { get @@ -138,6 +179,12 @@ internal async Task StopService(string name) { } } + public struct FunctionHandle + { + public required string Service { get; set; } + public required string Function { get; set; } + } + public struct RegisteredService { private string _name; @@ -148,8 +195,13 @@ internal RegisteredService(string name, InferableClient inferable) { this._inferable = inferable; } - public void RegisterFunction(FunctionRegistration function) where T : struct { + public FunctionHandle RegisterFunction(FunctionRegistration function) where T : struct { this._inferable.RegisterFunction(this._name, function); + + return new FunctionHandle { + Service = this._name, + Function = function.Name + }; } async public Task Start() { diff --git a/sdk-dotnet/tests/Inferable.Tests/InferableTest.cs b/sdk-dotnet/tests/Inferable.Tests/InferableTest.cs index 9b7538d4..d49f1ec4 100644 --- a/sdk-dotnet/tests/Inferable.Tests/InferableTest.cs +++ b/sdk-dotnet/tests/Inferable.Tests/InferableTest.cs @@ -214,6 +214,46 @@ async public void Inferable_Can_Handle_Functions_Failure() await inferable.Default.Stop(); } } + + [Fact] + async public void Inferable_Can_Trigger_Runs() + { + var inferable = CreateInferableClient(); + + var registration = new FunctionRegistration + { + Name = "successFunction", + Func = new Func((input) => + { + Console.WriteLine("Executing successFunction"); + return "This is a test response"; + }) + }; + + inferable.Default.RegisterFunction(registration); + + try + { + await inferable.Default.Start(); + + var run = await inferable.CreateRun(new CreateRunInput + { + Message = "Call the successFunction", + AttachedFunctions = new List + { + "default_successFunction" + } + }); + + var result = await run.Poll(null); + + Assert.NotNull(result); + } + finally + { + await inferable.Default.Stop(); + } + } } //TODO: Test transient /call failures diff --git a/sdk-node/README.md b/sdk-node/README.md index c5e2d885..62fd843f 100644 --- a/sdk-node/README.md +++ b/sdk-node/README.md @@ -35,36 +35,30 @@ pnpm add inferable ### 1. Initializing Inferable -Create a file named i.ts which will be used to initialize Inferable. This file will export the Inferable instance. +Initialize the Inferable client with your API secret. ```typescript -// d.ts - import { Inferable } from "inferable"; // Initialize the Inferable client with your API secret. // Get yours at https://console.inferable.ai. -export const d = new Inferable({ +export const client = new Inferable({ apiSecret: "YOUR_API_SECRET", }); ``` ### 2. Hello World Function -In a separate file, register a "sayHello" [function](https://docs.inferable.ai/pages/functions). This file will import the Inferable instance from `i.ts` and register the [function](https://docs.inferable.ai/pages/functions) with the [control-plane](https://docs.inferable.ai/pages/control-plane). +Register a "sayHello" [function](https://docs.inferable.ai/pages/functions). This file will import the Inferable instance from `i.ts` and register the [function](https://docs.inferable.ai/pages/functions) with the [control-plane](https://docs.inferable.ai/pages/control-plane). ```typescript -// service.ts - -import { i } from "./i"; - // Define a simple function that returns "Hello, World!" const sayHello = async ({ to }: { to: string }) => { return `Hello, ${to}!`; }; // Register the service (using the 'default' service) -const sayHello = i.default.register({ +const sayHello = client.default.register({ name: "sayHello", func: sayHello, schema: { @@ -75,7 +69,7 @@ const sayHello = i.default.register({ }); // Start the 'default' service -i.default.start(); +client.default.start(); ``` ### 3. Running the Service @@ -96,7 +90,7 @@ The following code will create an [Inferable run](https://docs.inferable.ai/page > - in the [CLI](https://www.npmjs.com/package/@inferable/cli) via `inf runs list` ```typescript -const run = await i.run({ +const run = await client.run({ message: "Say hello to John", functions: [sayHello], // Alternatively, subscribe an Inferable function as a result handler which will be called when the run is complete. diff --git a/sdk-node/src/Inferable.ts b/sdk-node/src/Inferable.ts index 0759370c..1d10f924 100644 --- a/sdk-node/src/Inferable.ts +++ b/sdk-node/src/Inferable.ts @@ -49,10 +49,6 @@ type TemplateRunInput = Omit & { input: Record; }; -type UpsertTemplateInput = Required< - Parameters["upsertPromptTemplate"]>[0] - >["body"] & { id: string, structuredOutput: z.ZodTypeAny }; - /** * The Inferable client. This is the main entry point for using Inferable. * @@ -229,70 +225,6 @@ export class Inferable { }); } - /** - * Registers or references a template instance. This can be used to trigger runs of a template. - * @param input The template definition or reference. - * @returns A registered template instance. - * @example - * ```ts - * const d = new Inferable({apiSecret: "API_SECRET"}); - * - * const template = await d.template({ - * id: "new-template-id", - * name: "my-template", - * attachedFunctions: ["my-service.hello"], - * prompt: "Hello {{name}}", - * structuredOutput: { greeting: z.string() } - * }); - * - * await template.run({ input: { name: "Jane Doe" } }); - * ``` - */ - public async template(input: UpsertTemplateInput) { - if (!this.clusterId) { - throw new InferableError( - "Cluster ID must be provided to manage templates", - ); - } - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let jsonSchema: any = undefined; - - if (!!input.structuredOutput) { - try { - jsonSchema = zodToJsonSchema(input.structuredOutput); - } catch (e) { - throw new InferableError("structuredOutput must be a valid JSON schema"); - } - } - - const upserted = await this.client.upsertPromptTemplate({ - body: { - ...input, - structuredOutput: jsonSchema, - }, - params: { - clusterId: this.clusterId, - templateId: input.id, - }, - }); - - if (upserted.status != 200) { - throw new InferableError(`Failed to register prompt template`, { - body: upserted.body, - status: upserted.status, - }); - } - - return { - id: input.id, - run: (input: TemplateRunInput) => - this.run({ - ...input, - template: { id: upserted.body.id, input: input.input }, - }), - }; - } /** * Creates a template reference. This can be used to trigger runs of a template that was previously registered via the UI.