Skip to content

Commit

Permalink
chore(sdk-dotnet): Evaluate clusterId lazily (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjcsmith authored Nov 3, 2024
1 parent 0a4a0c2 commit 1117500
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 23 deletions.
14 changes: 10 additions & 4 deletions sdk-dotnet/src/API/Models.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,16 @@ public override void Write(Utf8JsonWriter writer, JsonSchema value, JsonSerializ

public struct CreateMachineInput
{
[JsonPropertyName("service")]
public required string Service { get; set; }
[JsonPropertyName("functions")]
public required List<Function> Functions { get; set; }
[
JsonPropertyName("service"),
JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)
]
public string Service { get; set; }
[
JsonPropertyName("functions"),
JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)
]
public List<Function> Functions { get; set; }
}

public struct CreateMachineResult
Expand Down
27 changes: 15 additions & 12 deletions sdk-dotnet/src/Inferable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ public class InferableOptions
public string? BaseUrl { get; set; }
public string? ApiSecret { get; set; }
public string? MachineId { get; set; }
public string? ClusterId { get; set; }
}

public struct PollRunOptions
Expand Down Expand Up @@ -68,7 +67,7 @@ public class InferableClient

private readonly ApiClient _client;
private readonly ILogger<InferableClient> _logger;
private readonly string? _clusterId;
private string? _clusterId;

// Dictionary of service name to list of functions
private Dictionary<string, List<IFunctionRegistration>> _functionRegistry = new Dictionary<string, List<IFunctionRegistration>>();
Expand Down Expand Up @@ -128,7 +127,6 @@ public InferableClient(InferableOptions? options = null, ILogger<InferableClient
string? apiSecret = options?.ApiSecret ?? Environment.GetEnvironmentVariable("INFERABLE_API_SECRET");
string baseUrl = options?.BaseUrl ?? Environment.GetEnvironmentVariable("INFERABLE_API_ENDPOINT") ?? DefaultBaseUrl;
string machineId = options?.MachineId ?? Machine.GenerateMachineId();
this._clusterId = options?.ClusterId ?? Environment.GetEnvironmentVariable("INFERABLE_CLUSTER_ID");

if (apiSecret == null)
{
Expand Down Expand Up @@ -202,11 +200,8 @@ public RegisteredService RegisterService(string name)
/// </summary>
async public Task<RunReference> CreateRunAsync(CreateRunInput input)
{
if (this._clusterId == null) {
throw new ArgumentException("Cluster ID must be provided to manage runs");
}

var result = await this._client.CreateRunAsync(this._clusterId, input);
var clusterId = await this.GetClusterId();
var result = await this._client.CreateRunAsync(clusterId, input);

return new RunReference {
ID = result.ID,
Expand All @@ -217,7 +212,7 @@ async public Task<RunReference> CreateRunAsync(CreateRunInput input)
var start = DateTime.Now;
var end = start + MaxWaitTime;
while (DateTime.Now < end) {
var pollResult = await this._client.GetRun(this._clusterId, result.ID);
var pollResult = await this._client.GetRun(clusterId, result.ID);

var transientStates = new List<string> { "paused", "pending", "running" };
if (transientStates.Contains(pollResult.Status)) {
Expand Down Expand Up @@ -295,7 +290,7 @@ internal async Task StartServiceAsync(string name) {

var functions = this._functionRegistry[name];

var service = new Service(name, this._client, this._logger, functions);
var service = new Service(name, await this.GetClusterId(), this._client, this._logger, functions);

this._services.Add(service);
await service.Start();
Expand All @@ -308,6 +303,16 @@ internal async Task StopServiceAsync(string name) {
}
await existing.Stop();
}

internal async Task<string> GetClusterId() {
if (this._clusterId == null) {
// Call register machine without any services to test API key and get clusterId
var registerResult = await _client.CreateMachine(new CreateMachineInput {});
this._clusterId = registerResult.ClusterId;
}

return this._clusterId;
}
}

public struct RegisteredService
Expand Down Expand Up @@ -357,6 +362,4 @@ async public Task StopAsync() {
await this._inferable.StopServiceAsync(this._name);
}
}


}
11 changes: 5 additions & 6 deletions sdk-dotnet/src/Service.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ internal class Service
static int DEFAULT_RETRY_AFTER_SECONDS = 10;

private string _name;
private string? _clusterId;
private string _clusterId;
private bool _polling = false;

private int _retryAfter = DEFAULT_RETRY_AFTER_SECONDS;
Expand All @@ -23,10 +23,11 @@ internal class Service

private List<IFunctionRegistration> _functions = new List<IFunctionRegistration>();

internal Service(string name, ApiClient client, ILogger? logger, List<IFunctionRegistration> functions)
internal Service(string name, string clusterId, ApiClient client, ILogger? logger, List<IFunctionRegistration> functions)
{
this._name = name;
this._functions = functions;
this._clusterId = clusterId;

this._client = client;
this._logger = logger ?? NullLogger.Instance;
Expand All @@ -51,7 +52,7 @@ internal bool Polling
async internal Task<string> Start()
{
this._logger.LogDebug("Starting service '{name}'", this._name);
this._clusterId = await RegisterMachine();
await RegisterMachine();

// Purposely not awaiting
_ = this.runLoop();
Expand Down Expand Up @@ -131,7 +132,7 @@ async private Task pollIteration()
_logger.LogDebug($"Polling service {this._name}");
}

async private Task<string> RegisterMachine()
async private Task RegisterMachine()
{
this._logger.LogDebug("Registering machine");
var functions = new List<Function>();
Expand All @@ -152,8 +153,6 @@ async private Task<string> RegisterMachine()
Service = this._name,
Functions = functions
});

return registerResult.ClusterId;
}
}
}
1 change: 0 additions & 1 deletion sdk-dotnet/tests/Inferable.Tests/InferableTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ static InferableClient CreateInferableClient()
return new InferableClient(new InferableOptions {
ApiSecret = System.Environment.GetEnvironmentVariable("INFERABLE_TEST_API_SECRET")!,
BaseUrl = System.Environment.GetEnvironmentVariable("INFERABLE_TEST_API_ENDPOINT")!,
ClusterId = System.Environment.GetEnvironmentVariable("INFERABLE_TEST_CLUSTER_ID")!,
}, logger);
}

Expand Down

0 comments on commit 1117500

Please sign in to comment.