Skip to content

Commit

Permalink
chore(dotnet): Get clusterId from API
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjcsmith committed Nov 3, 2024
1 parent 12323ae commit 48ea931
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 38 deletions.
14 changes: 10 additions & 4 deletions sdk-dotnet/src/API/Models.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,16 @@ public override void Write(Utf8JsonWriter writer, JsonSchema value, JsonSerializ

public struct CreateMachineInput
{
[JsonPropertyName("service")]
public required string Service { get; set; }
[JsonPropertyName("functions")]
public required List<Function> Functions { get; set; }
[
JsonPropertyName("service"),
JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)
]
public string Service { get; set; }
[
JsonPropertyName("functions"),
JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)
]
public List<Function> Functions { get; set; }
}

public struct CreateMachineResult
Expand Down
27 changes: 15 additions & 12 deletions sdk-dotnet/src/Inferable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ public class InferableOptions
public string? BaseUrl { get; set; }
public string? ApiSecret { get; set; }
public string? MachineId { get; set; }
public string? ClusterId { get; set; }
}

public struct PollRunOptions
Expand Down Expand Up @@ -68,7 +67,7 @@ public class InferableClient

private readonly ApiClient _client;
private readonly ILogger<InferableClient> _logger;
private readonly string? _clusterId;
private string? _clusterId;

// Dictionary of service name to list of functions
private Dictionary<string, List<IFunctionRegistration>> _functionRegistry = new Dictionary<string, List<IFunctionRegistration>>();
Expand Down Expand Up @@ -128,7 +127,6 @@ public InferableClient(InferableOptions? options = null, ILogger<InferableClient
string? apiSecret = options?.ApiSecret ?? Environment.GetEnvironmentVariable("INFERABLE_API_SECRET");
string baseUrl = options?.BaseUrl ?? Environment.GetEnvironmentVariable("INFERABLE_API_ENDPOINT") ?? DefaultBaseUrl;
string machineId = options?.MachineId ?? Machine.GenerateMachineId();
this._clusterId = options?.ClusterId ?? Environment.GetEnvironmentVariable("INFERABLE_CLUSTER_ID");

if (apiSecret == null)
{
Expand Down Expand Up @@ -202,11 +200,8 @@ public RegisteredService RegisterService(string name)
/// </summary>
async public Task<RunReference> CreateRunAsync(CreateRunInput input)
{
if (this._clusterId == null) {
throw new ArgumentException("Cluster ID must be provided to manage runs");
}

var result = await this._client.CreateRunAsync(this._clusterId, input);
var clusterId = await this.GetClusterId();
var result = await this._client.CreateRunAsync(clusterId, input);

return new RunReference {
ID = result.ID,
Expand All @@ -217,7 +212,7 @@ async public Task<RunReference> CreateRunAsync(CreateRunInput input)
var start = DateTime.Now;
var end = start + MaxWaitTime;
while (DateTime.Now < end) {
var pollResult = await this._client.GetRun(this._clusterId, result.ID);
var pollResult = await this._client.GetRun(clusterId, result.ID);

var transientStates = new List<string> { "paused", "pending", "running" };
if (transientStates.Contains(pollResult.Status)) {
Expand Down Expand Up @@ -295,7 +290,7 @@ internal async Task StartServiceAsync(string name) {

var functions = this._functionRegistry[name];

var service = new Service(name, this._client, this._logger, functions);
var service = new Service(name, await this.GetClusterId(), this._client, this._logger, functions);

this._services.Add(service);
await service.Start();
Expand All @@ -308,6 +303,16 @@ internal async Task StopServiceAsync(string name) {
}
await existing.Stop();
}

internal async Task<string> GetClusterId() {
if (this._clusterId == null) {
// Call register machine without any services to test API key and get clusterId
var registerResult = await _client.CreateMachine(new CreateMachineInput {});
this._clusterId = registerResult.ClusterId;
}

return this._clusterId;
}
}

public struct RegisteredService
Expand Down Expand Up @@ -357,6 +362,4 @@ async public Task StopAsync() {
await this._inferable.StopServiceAsync(this._name);
}
}


}
11 changes: 5 additions & 6 deletions sdk-dotnet/src/Service.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ internal class Service
static int DEFAULT_RETRY_AFTER_SECONDS = 10;

private string _name;
private string? _clusterId;
private string _clusterId;
private bool _polling = false;

private int _retryAfter = DEFAULT_RETRY_AFTER_SECONDS;
Expand All @@ -23,10 +23,11 @@ internal class Service

private List<IFunctionRegistration> _functions = new List<IFunctionRegistration>();

internal Service(string name, ApiClient client, ILogger? logger, List<IFunctionRegistration> functions)
internal Service(string name, string clusterId, ApiClient client, ILogger? logger, List<IFunctionRegistration> functions)
{
this._name = name;
this._functions = functions;
this._clusterId = clusterId;

this._client = client;
this._logger = logger ?? NullLogger.Instance;
Expand All @@ -51,7 +52,7 @@ internal bool Polling
async internal Task<string> Start()
{
this._logger.LogDebug("Starting service '{name}'", this._name);
this._clusterId = await RegisterMachine();
await RegisterMachine();

// Purposely not awaiting
_ = this.runLoop();
Expand Down Expand Up @@ -131,7 +132,7 @@ async private Task pollIteration()
_logger.LogDebug($"Polling service {this._name}");
}

async private Task<string> RegisterMachine()
async private Task RegisterMachine()
{
this._logger.LogDebug("Registering machine");
var functions = new List<Function>();
Expand All @@ -152,8 +153,6 @@ async private Task<string> RegisterMachine()
Service = this._name,
Functions = functions
});

return registerResult.ClusterId;
}
}
}
1 change: 0 additions & 1 deletion sdk-dotnet/tests/Inferable.Tests/InferableTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ static InferableClient CreateInferableClient()
return new InferableClient(new InferableOptions {
ApiSecret = System.Environment.GetEnvironmentVariable("INFERABLE_TEST_API_SECRET")!,
BaseUrl = System.Environment.GetEnvironmentVariable("INFERABLE_TEST_API_ENDPOINT")!,
ClusterId = System.Environment.GetEnvironmentVariable("INFERABLE_TEST_CLUSTER_ID")!,
}, logger);
}

Expand Down
38 changes: 26 additions & 12 deletions sdk-go/inferable.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type Inferable struct {
apiSecret string
functionRegistry functionRegistry
machineID string
clusterID string
_clusterID string
// Convenience reference to a service with the name 'default'.
//
// Returns:
Expand Down Expand Up @@ -137,15 +137,10 @@ func New(options InferableOptions) (*Inferable, error) {
return nil, fmt.Errorf("error registering default service: %v", err)
}

// Call register machine without any services to test API key and get clusterId
clusterId, err := inferable.registerMachine(nil);

if err != nil {
return nil, fmt.Errorf("error registering machine: %v", err)
}

inferable.clusterID = clusterId

return inferable, nil
}

Expand Down Expand Up @@ -185,10 +180,10 @@ func (i *Inferable) RegisterService(serviceName string) (*service, error) {
if _, exists := i.functionRegistry.services[serviceName]; exists {
return nil, fmt.Errorf("service with name '%s' already registered", serviceName)
}

service := &service{
Name: serviceName,
Functions: make(map[string]Function),
ClusterID: i.clusterID,
inferable: i, // Set the reference to the Inferable instance
}
i.functionRegistry.services[serviceName] = service
Expand All @@ -204,8 +199,13 @@ func (i *Inferable) getRun(runID string) (*runResult, error) {
"X-Machine-SDK-Language": "go",
}

clusterId, err := i.GetClusterId()
if err != nil {
return nil, fmt.Errorf("failed to get cluster id: %v", err)
}

options := client.FetchDataOptions{
Path: fmt.Sprintf("/clusters/%s/runs/%s", i.clusterID, runID),
Path: fmt.Sprintf("/clusters/%s/runs/%s", clusterId, runID),
Method: "GET",
Headers: headers,
}
Expand Down Expand Up @@ -254,9 +254,10 @@ func (i *Inferable) getRun(runID string) (*runResult, error) {
//
// fmt.Println("Run result:", result)
func (i *Inferable) CreateRun(input CreateRunInput) (*runReference, error) {
if i.clusterID == "" {
return nil, fmt.Errorf("cluster ID must be provided to manage runs")
}
clusterId, err := i.GetClusterId()
if err != nil {
return nil, fmt.Errorf("failed to get cluster id: %v", err)
}

// Marshal the payload to JSON
jsonPayload, err := json.Marshal(input)
Expand All @@ -273,7 +274,7 @@ func (i *Inferable) CreateRun(input CreateRunInput) (*runReference, error) {
}

options := client.FetchDataOptions{
Path: fmt.Sprintf("/clusters/%s/runs", i.clusterID),
Path: fmt.Sprintf("/clusters/%s/runs", clusterId),
Method: "POST",
Headers: headers,
Body: string(jsonPayload),
Expand Down Expand Up @@ -424,6 +425,19 @@ func (i *Inferable) serverOk() error {
return nil
}

func (i *Inferable) GetClusterId() (string, error) {
if i._clusterID == "" {
clusterId, err := i.registerMachine(nil)
if err != nil {
return "", fmt.Errorf("failed to register machine: %v", err)
}

i._clusterID = clusterId
}

return i._clusterID, nil
}

func (i *Inferable) registerMachine(s *service) (string, error) {

// Prepare the payload for registration
Expand Down
15 changes: 12 additions & 3 deletions sdk-go/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ type service struct {
Name string
Functions map[string]Function
inferable *Inferable
ClusterID string
ctx context.Context
cancel context.CancelFunc
retryAfter int
Expand Down Expand Up @@ -199,8 +198,13 @@ func (s *service) poll() error {
"X-Machine-SDK-Language": "go",
}

clusterId, err := s.inferable.GetClusterId()
if err != nil {
return fmt.Errorf("failed to get cluster id: %v", err)
}

options := client.FetchDataOptions{
Path: fmt.Sprintf("/clusters/%s/calls?acknowledge=true&service=%s&status=pending&limit=10", s.ClusterID, s.Name),
Path: fmt.Sprintf("/clusters/%s/calls?acknowledge=true&service=%s&status=pending&limit=10", clusterId, s.Name),
Method: "GET",
Headers: headers,
}
Expand Down Expand Up @@ -316,8 +320,13 @@ func (s *service) persistJobResult(jobID string, result callResult) error {
"X-Machine-SDK-Language": "go",
}

clusterId, err := s.inferable.GetClusterId()
if err != nil {
return fmt.Errorf("failed to get cluster id: %v", err)
}

options := client.FetchDataOptions{
Path: fmt.Sprintf("/clusters/%s/calls/%s/result", s.ClusterID, jobID),
Path: fmt.Sprintf("/clusters/%s/calls/%s/result", clusterId, jobID),
Method: "POST",
Headers: headers,
Body: string(payloadJSON),
Expand Down

0 comments on commit 48ea931

Please sign in to comment.