Skip to content

Commit

Permalink
feat: Add run creation to .NET SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjcsmith committed Oct 31, 2024
1 parent b1b204e commit e97f0be
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 30 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,9 @@ jobs:
- name: Test
run: dotnet test --no-restore
env:
INFERABLE_API_ENDPOINT: "https://api.inferable.ai"
INFERABLE_CLUSTER_ID: ${{ secrets.INFERABLE_CLUSTER_ID }}
INFERABLE_MACHINE_SECRET: ${{ secrets.INFERABLE_MACHINE_SECRET }}
INFERABLE_CONSUME_SECRET: ${{ secrets.INFERABLE_CONSUME_SECRET }}
INFERABLE_TEST_API_ENDPOINT: "https://api.inferable.ai"
INFERABLE_TEST_CLUSTER_ID: ${{ secrets.INFERABLE_CLUSTER_ID }}
INFERABLE_TEST_API_SECRET: ${{ secrets.INFERABLE_MACHINE_SECRET }}

build-go:
needs: check_changes
Expand Down
38 changes: 25 additions & 13 deletions sdk-dotnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,33 +56,45 @@ client.Default.RegisterFunction(new FunctionRegistration<MyInput>
Function = new Func<MyInput, MyResult>>((input) => {
// Your code here
}),
Name = "MyFunction",
Name = "SayHello",
Description = "A simple greeting function",
});

await client.Default.Start();
```

### Starting and Stopping a Service
### 3. Trigger a run

The example above used the Default service, you can also register separate named services.
The following code will create an [Inferable run](https://docs.inferable.ai/pages/runs) with the message "Call the testFn" and the `TestFn` function attached.

> You can inspect the progress of the run:
>
> - in the [playground UI](https://app.inferable.ai/) via `inf app`
> - in the [CLI](https://www.npmjs.com/package/@inferable/cli) via `inf runs list`
```csharp
var userService = client.RegisterService(new ServiceRegistration
var run = await inferable.CreateRun(new CreateRunInput
{
Name = "UserService",
Message = "Call the testFn",
AttachedFunctions = new List<FunctionReference>
{
new FunctionReference {
Function = "TestFn",
Service = "default"
}
},
// Optional: Subscribe an Inferable function to receive notifications when the run status changes
//OnStatusChange = new CreateOnStatusChangeInput
//{
// Function = OnStatusChangeFunction
//}
});

userService.RegisterFunction(....)

await userService.Start();
// Wait for the run to complete and log.
var result = await run.Poll(null);
```

To stop the service, use:

```csharp
await userService.StopAsync();
```
> Runs can also be triggered via the [API](https://docs.inferable.ai/pages/invoking-a-run-api), [CLI](https://www.npmjs.com/package/@inferable/cli) or [playground UI](https://app.inferable.ai/).
## Contributing

Expand Down
74 changes: 70 additions & 4 deletions sdk-dotnet/src/API/APIClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ public ApiClient(ApiClientOptions options)
_client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("bearer", options.ApiSecret);
}

async private Task RethrowWithContext(HttpRequestException e, HttpResponseMessage response)
{
throw new Exception($"Failed to get run. Response: {await response.Content.ReadAsStringAsync()}", e);
}



async public Task<CreateMachineResult> CreateMachine(CreateMachineInput input)
{
Expand All @@ -45,7 +51,11 @@ async public Task<CreateMachineResult> CreateMachine(CreateMachineInput input)
new StringContent(jsonData, Encoding.UTF8, "application/json")
);

response.EnsureSuccessStatusCode();
try {
response.EnsureSuccessStatusCode();
} catch (HttpRequestException e) {
await RethrowWithContext(e, response);
}

string responseBody = await response.Content.ReadAsStringAsync();
return JsonSerializer.Deserialize<CreateMachineResult>(responseBody);
Expand All @@ -60,7 +70,55 @@ async public Task CreateCallResult(string clusterId, string callId, CreateResult
new StringContent(jsonData, Encoding.UTF8, "application/json")
);

response.EnsureSuccessStatusCode();
try {
response.EnsureSuccessStatusCode();
} catch (HttpRequestException e) {
await RethrowWithContext(e, response);
}
}

async public Task<CreateRunResult> CreateRun(string clusterId, CreateRunInput input)
{
string jsonData = JsonSerializer.Serialize(input);

HttpResponseMessage response = await _client.PostAsync(
$"/clusters/{clusterId}/runs",
new StringContent(jsonData, Encoding.UTF8, "application/json")
);

try {
response.EnsureSuccessStatusCode();
} catch (HttpRequestException e) {
await RethrowWithContext(e, response);
}

string responseBody = await response.Content.ReadAsStringAsync();
var result = JsonSerializer.Deserialize<CreateRunResult>(responseBody);

return result;
}

async public Task<GetRunResult> GetRun(string clusterId, string runId)
{
HttpResponseMessage response = await _client.GetAsync(
$"/clusters/{clusterId}/runs/{runId}"
);

try {
try {
response.EnsureSuccessStatusCode();
} catch (HttpRequestException e) {
await RethrowWithContext(e, response);
}

} catch (HttpRequestException e) {
throw new Exception($"Failed to get run. Status Code: {response.StatusCode}, Response: {await response.Content.ReadAsStringAsync()}", e);
}

string responseBody = await response.Content.ReadAsStringAsync();
var result = JsonSerializer.Deserialize<GetRunResult>(responseBody);

return result;
}

async public Task<(List<CallMessage>, int?)> ListCalls(string clusterId, string service)
Expand All @@ -69,7 +127,11 @@ async public Task CreateCallResult(string clusterId, string callId, CreateResult
$"/clusters/{clusterId}/calls?service={service}&acknowledge=true"
);

response.EnsureSuccessStatusCode();
try {
response.EnsureSuccessStatusCode();
} catch (HttpRequestException e) {
await RethrowWithContext(e, response);
}

string responseBody = await response.Content.ReadAsStringAsync();
var result = JsonSerializer.Deserialize<List<CallMessage>>(responseBody) ?? new List<CallMessage>();
Expand Down Expand Up @@ -97,7 +159,11 @@ async public Task<CreateCallResult> CreateCall(string clusterId, CreateCallInput
new StringContent(jsonData, Encoding.UTF8, "application/json")
);

response.EnsureSuccessStatusCode();
try {
response.EnsureSuccessStatusCode();
} catch (HttpRequestException e) {
await RethrowWithContext(e, response);
}

string responseBody = await response.Content.ReadAsStringAsync();
return JsonSerializer.Deserialize<CreateCallResult>(responseBody);
Expand Down
71 changes: 71 additions & 0 deletions sdk-dotnet/src/API/Models.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

namespace Inferable.API
{

public struct CreateMachineInput
{
[JsonPropertyName("service")]
Expand Down Expand Up @@ -112,4 +113,74 @@ public struct FunctionConfig
public int? TimeoutSeconds { get; set; }
}

public struct CreateRunInput
{
[JsonPropertyName("message")]
public string? Message { get; set; }

[
JsonPropertyName("attachedFunctions"),
JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull),
]
public List<FunctionReference>? AttachedFunctions { get; set; }

[
JsonPropertyName("metadata"),
JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)
]
public Dictionary<string, string>? Metadata { get; set; }

[
JsonPropertyName("onStatusChange"),
JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)
]
public CreateOnStatusChangeInput? OnStatusChange { get; set; }
}

public struct CreateOnStatusChangeInput
{
[
JsonPropertyName("function"),
JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull),
]
public FunctionReference? Function { get; set; }
}

public struct FunctionReference
{
[JsonPropertyName("service")]
public required string Service { get; set; }
[JsonPropertyName("function")]
public required string Function { get; set; }
}

public struct CreateRunResult
{
[JsonPropertyName("id")]
public string ID { get; set; }
}

public struct GetRunResult
{
[JsonPropertyName("id")]
public string ID { get; set; }

[JsonPropertyName("status")]
public string Status { get; set; }

[JsonPropertyName("failureReason")]
public string FailureReason { get; set; }

[JsonPropertyName("summary")]
public string Summary { get; set; }

[JsonPropertyName("result")]
public object? Result { get; set; }

[JsonPropertyName("attachedFunctions")]
public List<string> AttachedFunctions { get; set; }

[JsonPropertyName("metadata")]
public Dictionary<string, string> Metadata { get; set; }
}
}
79 changes: 74 additions & 5 deletions sdk-dotnet/src/Inferable.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Text.Json.Serialization;
using Inferable.API;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
Expand All @@ -9,24 +10,54 @@ public class Links
public static string DOCS_AUTH = "https://docs.inferable.ai/pages/auth";
}

/// <summary>
/// Object type that will be returned to a Run's OnStatusChange Function
/// </summary>
public struct OnStatusChangeInput<T>
{
[JsonPropertyName("runId")]
public string RunId { get; set; }

[JsonPropertyName("status")]
public string Status { get; set; }

[JsonPropertyName("summary")]
public string? Summary { get; set; }

[JsonPropertyName("result")]
public T? Result { get; set; }

[JsonPropertyName("metadata")]
public Dictionary<string, string> Metadata { get; set; }
}

public class InferableOptions
{
public string? BaseUrl { get; set; }
public string? ApiSecret { get; set; }
/// <summary>
/// PingInterval in seconds
/// </summary>
public int? PingInterval { get; set; }
public string? MachineId { get; set; }
public string? ClusterId { get; set; }
}

public struct PollRunOptions
{
public required TimeSpan MaxWaitTime { get; set; }
public required TimeSpan Interval { get; set; }
}

public class RunReference
{
public required string ID { get; set; }
public required Func<PollRunOptions?, Task<GetRunResult?>> Poll { get; set; }
}

public class InferableClient
{
public static string DefaultBaseUrl = "https://api.inferable.ai/";

private readonly ApiClient _client;
private readonly ILogger<InferableClient> _logger;
private readonly string? _clusterId;

// Dictionary of service name to list of functions
private Dictionary<string, List<IFunctionRegistration>> _functionRegistry = new Dictionary<string, List<IFunctionRegistration>>();
Expand All @@ -46,6 +77,7 @@ public InferableClient(InferableOptions? options = null, ILogger<InferableClient
string? apiSecret = options?.ApiSecret ?? Environment.GetEnvironmentVariable("INFERABLE_API_SECRET");
string baseUrl = options?.BaseUrl ?? Environment.GetEnvironmentVariable("INFERABLE_API_ENDPOINT") ?? DefaultBaseUrl;
string machineId = options?.MachineId ?? Machine.GenerateMachineId();
this._clusterId = options?.ClusterId ?? Environment.GetEnvironmentVariable("INFERABLE_CLUSTER_ID");

if (apiSecret == null)
{
Expand All @@ -72,6 +104,38 @@ public RegisteredService RegisterService(string name)
return new RegisteredService(name, this);
}

async public Task<RunReference> CreateRun(CreateRunInput input)
{
if (this._clusterId == null) {
throw new ArgumentException("Cluster ID must be provided to manage runs");
}

var result = await this._client.CreateRun(this._clusterId, input);

return new RunReference {
ID = result.ID,
Poll = async (PollRunOptions? options) => {
var MaxWaitTime = options?.MaxWaitTime ?? TimeSpan.FromSeconds(60);
var Interval = options?.Interval ?? TimeSpan.FromMilliseconds(500);

var start = DateTime.Now;
var end = start + MaxWaitTime;
while (DateTime.Now < end) {
var pollResult = await this._client.GetRun(this._clusterId, result.ID);

var transientStates = new List<string> { "pending", "running" };
if (transientStates.Contains(pollResult.Status)) {
await Task.Delay(Interval);
continue;
}

return pollResult;
}
return null;
}
};
}

public IEnumerable<string> ActiveServices
{
get
Expand Down Expand Up @@ -148,8 +212,13 @@ internal RegisteredService(string name, InferableClient inferable) {
this._inferable = inferable;
}

public void RegisterFunction<T>(FunctionRegistration<T> function) where T : struct {
public FunctionReference RegisterFunction<T>(FunctionRegistration<T> function) where T : struct {
this._inferable.RegisterFunction<T>(this._name, function);

return new FunctionReference {
Service = this._name,
Function = function.Name
};
}

async public Task Start() {
Expand Down
Loading

0 comments on commit e97f0be

Please sign in to comment.