Skip to content

Commit

Permalink
chore(sdk-node): Evaluate clusterId lazily (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjcsmith authored Nov 3, 2024
1 parent a654f97 commit 1998d9e
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 96 deletions.
114 changes: 103 additions & 11 deletions sdk-go/inferable.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ type InferableOptions struct {
APIEndpoint string
APISecret string
MachineID string
ClusterID string
}

// Struct type that will be returned to a Run's OnStatusChange Function
Expand Down Expand Up @@ -128,7 +127,6 @@ func New(options InferableOptions) (*Inferable, error) {
client: client,
apiEndpoint: options.APIEndpoint,
apiSecret: options.APISecret,
clusterID: options.ClusterID,
functionRegistry: functionRegistry{services: make(map[string]*service)},
machineID: machineID,
}
Expand Down Expand Up @@ -178,6 +176,7 @@ func (i *Inferable) RegisterService(serviceName string) (*service, error) {
if _, exists := i.functionRegistry.services[serviceName]; exists {
return nil, fmt.Errorf("service with name '%s' already registered", serviceName)
}

service := &service{
Name: serviceName,
Functions: make(map[string]Function),
Expand All @@ -188,10 +187,6 @@ func (i *Inferable) RegisterService(serviceName string) (*service, error) {
}

func (i *Inferable) getRun(runID string) (*runResult, error) {
if i.clusterID == "" {
return nil, fmt.Errorf("cluster ID must be provided to manage runs")
}

// Prepare headers
headers := map[string]string{
"Authorization": "Bearer " + i.apiSecret,
Expand All @@ -200,8 +195,13 @@ func (i *Inferable) getRun(runID string) (*runResult, error) {
"X-Machine-SDK-Language": "go",
}

clusterId, err := i.getClusterId()
if err != nil {
return nil, fmt.Errorf("failed to get cluster id: %v", err)
}

options := client.FetchDataOptions{
Path: fmt.Sprintf("/clusters/%s/runs/%s", i.clusterID, runID),
Path: fmt.Sprintf("/clusters/%s/runs/%s", clusterId, runID),
Method: "GET",
Headers: headers,
}
Expand Down Expand Up @@ -250,8 +250,9 @@ func (i *Inferable) getRun(runID string) (*runResult, error) {
//
// fmt.Println("Run result:", result)
func (i *Inferable) CreateRun(input CreateRunInput) (*runReference, error) {
if i.clusterID == "" {
return nil, fmt.Errorf("cluster ID must be provided to manage runs")
clusterId, err := i.getClusterId()
if err != nil {
return nil, fmt.Errorf("failed to get cluster id: %v", err)
}

// Marshal the payload to JSON
Expand All @@ -268,9 +269,8 @@ func (i *Inferable) CreateRun(input CreateRunInput) (*runReference, error) {
"X-Machine-SDK-Language": "go",
}

// Call the registerMachine endpoint
options := client.FetchDataOptions{
Path: fmt.Sprintf("/clusters/%s/runs", i.clusterID),
Path: fmt.Sprintf("/clusters/%s/runs", clusterId),
Method: "POST",
Headers: headers,
Body: string(jsonPayload),
Expand Down Expand Up @@ -420,3 +420,95 @@ func (i *Inferable) serverOk() error {

return nil
}

func (i *Inferable) getClusterId() (string, error) {
if i.clusterID == "" {
clusterId, err := i.registerMachine(nil)
if err != nil {
return "", fmt.Errorf("failed to register machine: %v", err)
}

i.clusterID = clusterId
}

return i.clusterID, nil
}

func (i *Inferable) registerMachine(s *service) (string, error) {

// Prepare the payload for registration
payload := struct {
Service string `json:"service,omitempty"`
Functions []struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema string `json:"schema,omitempty"`
} `json:"functions,omitempty"`
}{}

if s != nil {
payload.Service = s.Name

// Check if there are any registered functions
if len(s.Functions) == 0 {
return "", fmt.Errorf("cannot register service '%s': no functions registered", s.Name)
}

// Add registered functions to the payload
for _, fn := range s.Functions {
schemaJSON, err := json.Marshal(fn.schema)
if err != nil {
return "", fmt.Errorf("failed to marshal schema for function '%s': %v", fn.Name, err)
}

payload.Functions = append(payload.Functions, struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema string `json:"schema,omitempty"`
}{
Name: fn.Name,
Description: fn.Description,
Schema: string(schemaJSON),
})
}
}

// Marshal the payload to JSON
jsonPayload, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %v", err)
}

// Prepare headers
headers := map[string]string{
"Authorization": "Bearer " + i.apiSecret,
"X-Machine-ID": i.machineID,
"X-Machine-SDK-Version": Version,
"X-Machine-SDK-Language": "go",
}

// Call the registerMachine endpoint
options := client.FetchDataOptions{
Path: "/machines",
Method: "POST",
Headers: headers,
Body: string(jsonPayload),
}

responseData, _, err, _ := i.fetchData(options)
if err != nil {
return "", fmt.Errorf("failed to register machine: %v", err)
}

// Parse the response
var response struct {
ClusterId string `json:"clusterId"`
}

err = json.Unmarshal(responseData, &response)
if err != nil {
return "", fmt.Errorf("failed to parse registration response: %v", err)
}

return response.ClusterId, nil
}
3 changes: 1 addition & 2 deletions sdk-go/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,11 @@ func TestInferableFunctions(t *testing.T) {

// This should match the example in the readme
func TestInferableE2E(t *testing.T) {
machineSecret, _, clusterID, apiEndpoint := util.GetTestVars()
machineSecret, _, _, apiEndpoint := util.GetTestVars()

client, err := New(InferableOptions{
APIEndpoint: apiEndpoint,
APISecret: machineSecret,
ClusterID: clusterID,
})

if err != nil {
Expand Down
97 changes: 14 additions & 83 deletions sdk-go/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ type service struct {
Name string
Functions map[string]Function
inferable *Inferable
clusterId string
ctx context.Context
cancel context.CancelFunc
retryAfter int
Expand Down Expand Up @@ -144,87 +143,9 @@ func (s *service) RegisterFunc(fn Function) (*FunctionReference, error) {
return &FunctionReference{Service: s.Name, Function: fn.Name}, nil
}

func (s *service) registerMachine() error {
// Check if there are any registered functions
if len(s.Functions) == 0 {
return fmt.Errorf("cannot register service '%s': no functions registered", s.Name)
}

// Prepare the payload for registration
payload := struct {
Service string `json:"service"`
Functions []struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema string `json:"schema,omitempty"`
} `json:"functions,omitempty"`
}{
Service: s.Name,
}

// Add registered functions to the payload
for _, fn := range s.Functions {
schemaJSON, err := json.Marshal(fn.schema)
if err != nil {
return fmt.Errorf("failed to marshal schema for function '%s': %v", fn.Name, err)
}

payload.Functions = append(payload.Functions, struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema string `json:"schema,omitempty"`
}{
Name: fn.Name,
Description: fn.Description,
Schema: string(schemaJSON),
})
}

// Marshal the payload to JSON
jsonPayload, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %v", err)
}

// Prepare headers
headers := map[string]string{
"Authorization": "Bearer " + s.inferable.apiSecret,
"X-Machine-ID": s.inferable.machineID,
"X-Machine-SDK-Version": Version,
"X-Machine-SDK-Language": "go",
}

// Call the registerMachine endpoint
options := client.FetchDataOptions{
Path: "/machines",
Method: "POST",
Headers: headers,
Body: string(jsonPayload),
}

responseData, _, err, _ := s.inferable.fetchData(options)
if err != nil {
return fmt.Errorf("failed to register machine: %v", err)
}

// Parse the response
var response struct {
ClusterId string `json:"clusterId"`
}

err = json.Unmarshal(responseData, &response)
if err != nil {
return fmt.Errorf("failed to parse registration response: %v", err)
}

s.clusterId = response.ClusterId

return nil
}

// Start initializes the service, registers the machine, and starts polling for messages
func (s *service) Start() error {
err := s.registerMachine()
_, err := s.inferable.registerMachine(s)
if err != nil {
return fmt.Errorf("failed to register machine: %v", err)
}
Expand Down Expand Up @@ -277,16 +198,21 @@ func (s *service) poll() error {
"X-Machine-SDK-Language": "go",
}

clusterId, err := s.inferable.getClusterId()
if err != nil {
return fmt.Errorf("failed to get cluster id: %v", err)
}

options := client.FetchDataOptions{
Path: fmt.Sprintf("/clusters/%s/calls?acknowledge=true&service=%s&status=pending&limit=10", s.clusterId, s.Name),
Path: fmt.Sprintf("/clusters/%s/calls?acknowledge=true&service=%s&status=pending&limit=10", clusterId, s.Name),
Method: "GET",
Headers: headers,
}

result, respHeaders, err, status := s.inferable.fetchData(options)

if status == 410 {
s.registerMachine()
s.inferable.registerMachine(s)
}

if err != nil {
Expand Down Expand Up @@ -394,8 +320,13 @@ func (s *service) persistJobResult(jobID string, result callResult) error {
"X-Machine-SDK-Language": "go",
}

clusterId, err := s.inferable.getClusterId()
if err != nil {
return fmt.Errorf("failed to get cluster id: %v", err)
}

options := client.FetchDataOptions{
Path: fmt.Sprintf("/clusters/%s/calls/%s/result", s.clusterId, jobID),
Path: fmt.Sprintf("/clusters/%s/calls/%s/result", clusterId, jobID),
Method: "POST",
Headers: headers,
Body: string(payloadJSON),
Expand Down

0 comments on commit 1998d9e

Please sign in to comment.