Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Evaluate clusterId lazily #34

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions sdk-dotnet/src/API/Models.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,16 @@ public override void Write(Utf8JsonWriter writer, JsonSchema value, JsonSerializ

public struct CreateMachineInput
{
[JsonPropertyName("service")]
public required string Service { get; set; }
[JsonPropertyName("functions")]
public required List<Function> Functions { get; set; }
[
JsonPropertyName("service"),
JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)
]
public string Service { get; set; }
[
JsonPropertyName("functions"),
JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)
]
public List<Function> Functions { get; set; }
}

public struct CreateMachineResult
Expand Down
27 changes: 15 additions & 12 deletions sdk-dotnet/src/Inferable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ public class InferableOptions
public string? BaseUrl { get; set; }
public string? ApiSecret { get; set; }
public string? MachineId { get; set; }
public string? ClusterId { get; set; }
}

public struct PollRunOptions
Expand Down Expand Up @@ -68,7 +67,7 @@ public class InferableClient

private readonly ApiClient _client;
private readonly ILogger<InferableClient> _logger;
private readonly string? _clusterId;
private string? _clusterId;

// Dictionary of service name to list of functions
private Dictionary<string, List<IFunctionRegistration>> _functionRegistry = new Dictionary<string, List<IFunctionRegistration>>();
Expand Down Expand Up @@ -128,7 +127,6 @@ public InferableClient(InferableOptions? options = null, ILogger<InferableClient
string? apiSecret = options?.ApiSecret ?? Environment.GetEnvironmentVariable("INFERABLE_API_SECRET");
string baseUrl = options?.BaseUrl ?? Environment.GetEnvironmentVariable("INFERABLE_API_ENDPOINT") ?? DefaultBaseUrl;
string machineId = options?.MachineId ?? Machine.GenerateMachineId();
this._clusterId = options?.ClusterId ?? Environment.GetEnvironmentVariable("INFERABLE_CLUSTER_ID");

if (apiSecret == null)
{
Expand Down Expand Up @@ -202,11 +200,8 @@ public RegisteredService RegisterService(string name)
/// </summary>
async public Task<RunReference> CreateRunAsync(CreateRunInput input)
{
if (this._clusterId == null) {
throw new ArgumentException("Cluster ID must be provided to manage runs");
}

var result = await this._client.CreateRunAsync(this._clusterId, input);
var clusterId = await this.GetClusterId();
var result = await this._client.CreateRunAsync(clusterId, input);

return new RunReference {
ID = result.ID,
Expand All @@ -217,7 +212,7 @@ async public Task<RunReference> CreateRunAsync(CreateRunInput input)
var start = DateTime.Now;
var end = start + MaxWaitTime;
while (DateTime.Now < end) {
var pollResult = await this._client.GetRun(this._clusterId, result.ID);
var pollResult = await this._client.GetRun(clusterId, result.ID);

var transientStates = new List<string> { "paused", "pending", "running" };
if (transientStates.Contains(pollResult.Status)) {
Expand Down Expand Up @@ -295,7 +290,7 @@ internal async Task StartServiceAsync(string name) {

var functions = this._functionRegistry[name];

var service = new Service(name, this._client, this._logger, functions);
var service = new Service(name, await this.GetClusterId(), this._client, this._logger, functions);

this._services.Add(service);
await service.Start();
Expand All @@ -308,6 +303,16 @@ internal async Task StopServiceAsync(string name) {
}
await existing.Stop();
}

internal async Task<string> GetClusterId() {
if (this._clusterId == null) {
// Call register machine without any services to test API key and get clusterId
var registerResult = await _client.CreateMachine(new CreateMachineInput {});
this._clusterId = registerResult.ClusterId;
}

return this._clusterId;
}
}

public struct RegisteredService
Expand Down Expand Up @@ -357,6 +362,4 @@ async public Task StopAsync() {
await this._inferable.StopServiceAsync(this._name);
}
}


}
11 changes: 5 additions & 6 deletions sdk-dotnet/src/Service.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ internal class Service
static int DEFAULT_RETRY_AFTER_SECONDS = 10;

private string _name;
private string? _clusterId;
private string _clusterId;
private bool _polling = false;

private int _retryAfter = DEFAULT_RETRY_AFTER_SECONDS;
Expand All @@ -23,10 +23,11 @@ internal class Service

private List<IFunctionRegistration> _functions = new List<IFunctionRegistration>();

internal Service(string name, ApiClient client, ILogger? logger, List<IFunctionRegistration> functions)
internal Service(string name, string clusterId, ApiClient client, ILogger? logger, List<IFunctionRegistration> functions)
{
this._name = name;
this._functions = functions;
this._clusterId = clusterId;

this._client = client;
this._logger = logger ?? NullLogger.Instance;
Expand All @@ -51,7 +52,7 @@ internal bool Polling
async internal Task<string> Start()
{
this._logger.LogDebug("Starting service '{name}'", this._name);
this._clusterId = await RegisterMachine();
await RegisterMachine();

// Purposely not awaiting
_ = this.runLoop();
Expand Down Expand Up @@ -131,7 +132,7 @@ async private Task pollIteration()
_logger.LogDebug($"Polling service {this._name}");
}

async private Task<string> RegisterMachine()
async private Task RegisterMachine()
{
this._logger.LogDebug("Registering machine");
var functions = new List<Function>();
Expand All @@ -152,8 +153,6 @@ async private Task<string> RegisterMachine()
Service = this._name,
Functions = functions
});

return registerResult.ClusterId;
}
}
}
1 change: 0 additions & 1 deletion sdk-dotnet/tests/Inferable.Tests/InferableTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
return new InferableClient(new InferableOptions {
ApiSecret = System.Environment.GetEnvironmentVariable("INFERABLE_TEST_API_SECRET")!,
BaseUrl = System.Environment.GetEnvironmentVariable("INFERABLE_TEST_API_ENDPOINT")!,
ClusterId = System.Environment.GetEnvironmentVariable("INFERABLE_TEST_CLUSTER_ID")!,
}, logger);
}

Expand All @@ -57,7 +56,7 @@

Assert.NotNull(inferable);

Assert.NotNull(inferable.Default);

Check warning on line 59 in sdk-dotnet/tests/Inferable.Tests/InferableTest.cs

View workflow job for this annotation

GitHub Actions / build-dotnet

Do not use Assert.NotNull() on value type 'RegisteredService'. Remove this assert. (https://xunit.net/xunit.analyzers/rules/xUnit2002)

Check warning on line 59 in sdk-dotnet/tests/Inferable.Tests/InferableTest.cs

View workflow job for this annotation

GitHub Actions / build-dotnet

Do not use Assert.NotNull() on value type 'RegisteredService'. Remove this assert. (https://xunit.net/xunit.analyzers/rules/xUnit2002)

Check warning on line 59 in sdk-dotnet/tests/Inferable.Tests/InferableTest.cs

View workflow job for this annotation

GitHub Actions / test-dotnet

Do not use Assert.NotNull() on value type 'RegisteredService'. Remove this assert. (https://xunit.net/xunit.analyzers/rules/xUnit2002)
}

[Fact]
Expand All @@ -67,7 +66,7 @@

var service = inferable.RegisterService("test");

Assert.NotNull(service);

Check warning on line 69 in sdk-dotnet/tests/Inferable.Tests/InferableTest.cs

View workflow job for this annotation

GitHub Actions / build-dotnet

Do not use Assert.NotNull() on value type 'RegisteredService'. Remove this assert. (https://xunit.net/xunit.analyzers/rules/xUnit2002)

Check warning on line 69 in sdk-dotnet/tests/Inferable.Tests/InferableTest.cs

View workflow job for this annotation

GitHub Actions / build-dotnet

Do not use Assert.NotNull() on value type 'RegisteredService'. Remove this assert. (https://xunit.net/xunit.analyzers/rules/xUnit2002)

Check warning on line 69 in sdk-dotnet/tests/Inferable.Tests/InferableTest.cs

View workflow job for this annotation

GitHub Actions / test-dotnet

Do not use Assert.NotNull() on value type 'RegisteredService'. Remove this assert. (https://xunit.net/xunit.analyzers/rules/xUnit2002)
}

[Fact]
Expand Down
120 changes: 108 additions & 12 deletions sdk-go/inferable.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type Inferable struct {
apiSecret string
functionRegistry functionRegistry
machineID string
clusterID string
_clusterID string
// Convenience reference to a service with the name 'default'.
//
// Returns:
Expand Down Expand Up @@ -63,7 +63,6 @@ type InferableOptions struct {
APIEndpoint string
APISecret string
MachineID string
ClusterID string
}

// Struct type that will be returned to a Run's OnStatusChange Function
Expand Down Expand Up @@ -128,7 +127,6 @@ func New(options InferableOptions) (*Inferable, error) {
client: client,
apiEndpoint: options.APIEndpoint,
apiSecret: options.APISecret,
clusterID: options.ClusterID,
functionRegistry: functionRegistry{services: make(map[string]*service)},
machineID: machineID,
}
Expand All @@ -139,6 +137,10 @@ func New(options InferableOptions) (*Inferable, error) {
return nil, fmt.Errorf("error registering default service: %v", err)
}

if err != nil {
return nil, fmt.Errorf("error registering machine: %v", err)
}

return inferable, nil
}

Expand Down Expand Up @@ -178,6 +180,7 @@ func (i *Inferable) RegisterService(serviceName string) (*service, error) {
if _, exists := i.functionRegistry.services[serviceName]; exists {
return nil, fmt.Errorf("service with name '%s' already registered", serviceName)
}

service := &service{
Name: serviceName,
Functions: make(map[string]Function),
Expand All @@ -188,10 +191,6 @@ func (i *Inferable) RegisterService(serviceName string) (*service, error) {
}

func (i *Inferable) getRun(runID string) (*runResult, error) {
if i.clusterID == "" {
return nil, fmt.Errorf("cluster ID must be provided to manage runs")
}

// Prepare headers
headers := map[string]string{
"Authorization": "Bearer " + i.apiSecret,
Expand All @@ -200,8 +199,13 @@ func (i *Inferable) getRun(runID string) (*runResult, error) {
"X-Machine-SDK-Language": "go",
}

clusterId, err := i.GetClusterId()
if err != nil {
return nil, fmt.Errorf("failed to get cluster id: %v", err)
}

options := client.FetchDataOptions{
Path: fmt.Sprintf("/clusters/%s/runs/%s", i.clusterID, runID),
Path: fmt.Sprintf("/clusters/%s/runs/%s", clusterId, runID),
Method: "GET",
Headers: headers,
}
Expand Down Expand Up @@ -250,8 +254,9 @@ func (i *Inferable) getRun(runID string) (*runResult, error) {
//
// fmt.Println("Run result:", result)
func (i *Inferable) CreateRun(input CreateRunInput) (*runReference, error) {
if i.clusterID == "" {
return nil, fmt.Errorf("cluster ID must be provided to manage runs")
clusterId, err := i.GetClusterId()
if err != nil {
return nil, fmt.Errorf("failed to get cluster id: %v", err)
}

// Marshal the payload to JSON
Expand All @@ -268,9 +273,8 @@ func (i *Inferable) CreateRun(input CreateRunInput) (*runReference, error) {
"X-Machine-SDK-Language": "go",
}

// Call the registerMachine endpoint
options := client.FetchDataOptions{
Path: fmt.Sprintf("/clusters/%s/runs", i.clusterID),
Path: fmt.Sprintf("/clusters/%s/runs", clusterId),
Method: "POST",
Headers: headers,
Body: string(jsonPayload),
Expand Down Expand Up @@ -420,3 +424,95 @@ func (i *Inferable) serverOk() error {

return nil
}

func (i *Inferable) GetClusterId() (string, error) {
if i._clusterID == "" {
clusterId, err := i.registerMachine(nil)
if err != nil {
return "", fmt.Errorf("failed to register machine: %v", err)
}

i._clusterID = clusterId
}

return i._clusterID, nil
}

func (i *Inferable) registerMachine(s *service) (string, error) {

// Prepare the payload for registration
payload := struct {
Service string `json:"service,omitempty"`
Functions []struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema string `json:"schema,omitempty"`
} `json:"functions,omitempty"`
}{}

if s != nil {
payload.Service = s.Name

// Check if there are any registered functions
if len(s.Functions) == 0 {
return "", fmt.Errorf("cannot register service '%s': no functions registered", s.Name)
}

// Add registered functions to the payload
for _, fn := range s.Functions {
schemaJSON, err := json.Marshal(fn.schema)
if err != nil {
return "", fmt.Errorf("failed to marshal schema for function '%s': %v", fn.Name, err)
}

payload.Functions = append(payload.Functions, struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema string `json:"schema,omitempty"`
}{
Name: fn.Name,
Description: fn.Description,
Schema: string(schemaJSON),
})
}
}

// Marshal the payload to JSON
jsonPayload, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %v", err)
}

// Prepare headers
headers := map[string]string{
"Authorization": "Bearer " + i.apiSecret,
"X-Machine-ID": i.machineID,
"X-Machine-SDK-Version": Version,
"X-Machine-SDK-Language": "go",
}

// Call the registerMachine endpoint
options := client.FetchDataOptions{
Path: "/machines",
Method: "POST",
Headers: headers,
Body: string(jsonPayload),
}

responseData, _, err, _ := i.fetchData(options)
if err != nil {
return "", fmt.Errorf("failed to register machine: %v", err)
}

// Parse the response
var response struct {
ClusterId string `json:"clusterId"`
}

err = json.Unmarshal(responseData, &response)
if err != nil {
return "", fmt.Errorf("failed to parse registration response: %v", err)
}

return response.ClusterId, nil
}
3 changes: 1 addition & 2 deletions sdk-go/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,11 @@ func TestInferableFunctions(t *testing.T) {

// This should match the example in the readme
func TestInferableE2E(t *testing.T) {
machineSecret, _, clusterID, apiEndpoint := util.GetTestVars()
machineSecret, _, _, apiEndpoint := util.GetTestVars()

client, err := New(InferableOptions{
APIEndpoint: apiEndpoint,
APISecret: machineSecret,
ClusterID: clusterID,
})

if err != nil {
Expand Down
Loading
Loading