diff --git a/internal/hass/client.go b/internal/hass/client.go index ea2c11631..f81ffdd3a 100644 --- a/internal/hass/client.go +++ b/internal/hass/client.go @@ -3,13 +3,10 @@ // This software is released under the MIT License. // https://opensource.org/licenses/MIT -//revive:disable:max-public-structs -//go:generate go run github.com/matryer/moq -out client_mocks_test.go . PostRequest Registry package hass import ( "context" - "encoding/json" "errors" "fmt" "log/slog" @@ -28,7 +25,10 @@ const ( DefaultTimeout = 30 * time.Second ) -var tracker = sensor.NewTracker() +var ( + sensorRegistry Registry + sensorTracker *sensor.Tracker +) var ( ErrGetConfigFailed = errors.New("could not fetch Home Assistant config") @@ -54,26 +54,6 @@ var ( } ) -// GetRequest is a HTTP GET request. -type GetRequest any - -// PostRequest is a HTTP POST request with the request body provided by Body(). -type PostRequest interface { - RequestBody() json.RawMessage -} - -// Authenticated represents a request that requires passing an authentication -// header with the value returned by Auth(). -type Authenticated interface { - Auth() string -} - -// Encrypted represents a request that should be encrypted with the secret -// provided by Secret(). -type Encrypted interface { - Secret() string -} - type Registry interface { SetDisabled(id string, state bool) error SetRegistered(id string, state bool) error @@ -83,22 +63,19 @@ type Registry interface { type Client struct { endpoint *resty.Client - registry Registry } func NewClient(ctx context.Context) (*Client, error) { var err error - reg, err := registry.Load(ctx) + sensorTracker = sensor.NewTracker() + + sensorRegistry, err = registry.Load(ctx) if err != nil { return nil, fmt.Errorf("could not start registry: %w", err) } - client := &Client{ - registry: reg, - } - - return client, nil + return &Client{}, nil } func (c *Client) Endpoint(url string, timeout time.Duration) { @@ -129,149 +106,74 @@ func (c *Client) ProcessEvent(ctx context.Context, details event.Event) error { req := &request{Data: details, RequestType: requestTypeEvent} if err := req.Validate(); err != nil { - return fmt.Errorf("validation failed: %w", err) + return fmt.Errorf("invalid event request: %w", err) } - _, err := send[eventResponse](ctx, c, req) + resp, err := send[response](ctx, c, req) if err != nil { return fmt.Errorf("failed to send event request: %w", err) } - return nil -} - -func (c *Client) ProcessSensor(ctx context.Context, details sensor.Entity) error { - if c.isDisabled(ctx, details) { - logging.FromContext(ctx). - Debug("Not sending request for disabled sensor.", - sensorLogAttrs(details)) - - return nil - } - - if _, ok := details.Value.(*LocationRequest); ok { - // LocationRequest: - return c.handleLocationUpdate(ctx, details) - } - - if c.registry.IsRegistered(details.ID) { - // Sensor Update (existing sensor). - return c.handleSensorUpdate(ctx, details) - } - // Sensor Registration (new sensor). - return c.handleRegistration(ctx, details) -} - -func (c *Client) handleLocationUpdate(ctx context.Context, details sensor.Entity) error { - // req, err := sensor.NewLocationUpdateRequest(details) - req, err := newEntityRequest(requestTypeLocation, details) - if err != nil { - return fmt.Errorf("unable to handle location update: %w", err) - } - - resp, err := send[locationResponse](ctx, c, req) - if err != nil { - return fmt.Errorf("failed to send location update request: %w", err) - } - - if err := resp.updated(); err != nil { //nolint:staticcheck - return fmt.Errorf("location update failed: %w", err) + if _, err := resp.Status(); err != nil { + return err } return nil } -func (c *Client) handleSensorUpdate(ctx context.Context, details sensor.Entity) error { - // req, err := sensor.NewUpdateRequest(details) - req, err := newEntityRequest(requestTypeUpdate, details) - if err != nil { - return fmt.Errorf("unable to handle sensor update: %w", err) - } - - response, err := send[stateUpdateResponse](ctx, c, req) - if err != nil { - return fmt.Errorf("failed to send sensor update request for %s: %w", details.ID, err) - } +func (c *Client) ProcessSensor(ctx context.Context, details sensor.Entity) error { + req := &request{} - if response == nil { - return ErrStateUpdateUnknown + if _, ok := details.Value.(*LocationRequest); ok { + req = &request{Data: details.Value, RequestType: requestTypeLocation} } - // At this point, the sensor update was successful. Any errors are really - // warnings and non-critical. - var warnings error - - for id, update := range response { - success, err := update.success() - if !success { - // The update failed. - warnings = errors.Join(warnings, err) - } - - // If HA reports the sensor as disabled, update the registry. - if c.registry.IsDisabled(id) != update.disabled() { + if sensorRegistry.IsRegistered(details.ID) { + // For sensor updates, if the sensor is disabled, don't continue. + if c.isDisabled(ctx, details) { logging.FromContext(ctx). - Info("Sensor is disabled in Home Assistant. Setting disabled in local registry.", + Debug("Not sending request for disabled sensor.", sensorLogAttrs(details)) - if err := c.registry.SetDisabled(id, update.disabled()); err != nil { - warnings = errors.Join(warnings, fmt.Errorf("%w: %w", ErrRegDisableFailed, err)) - } - } - - // Add the sensor update to the tracker. - if err := tracker.Add(&details); err != nil { - warnings = errors.Join(warnings, fmt.Errorf("%w: %w", ErrTrkUpdateFailed, err)) + return nil } - } - if warnings != nil { - logging.FromContext(ctx). - Debug("Sensor updated with warnings.", - sensorLogAttrs(details), - slog.Any("warnings", warnings)) + req = &request{Data: details.State, RequestType: requestTypeUpdate} } else { - logging.FromContext(ctx). - Debug("Sensor updated.", - sensorLogAttrs(details)) + req = &request{Data: details, RequestType: requestTypeRegister} } - // Return success status and any warnings. - return warnings -} -func (c *Client) handleRegistration(ctx context.Context, details sensor.Entity) error { - req, err := newEntityRequest(requestTypeRegister, details) - if err != nil { - return fmt.Errorf("unable to handle sensor update: %w", err) + if err := req.Validate(); err != nil { + return fmt.Errorf("invalid sensor request: %w", err) } - response, err := send[registrationResponse](ctx, c, req) - if err != nil { - return fmt.Errorf("failed to send sensor registration request for %s: %w", details.ID, err) - } + switch req.RequestType { + case requestTypeLocation: + resp, err := send[response](ctx, c, req) + if err != nil { + return fmt.Errorf("failed to send location update: %w", err) + } - // If the registration failed, log a warning. - success, err := response.registered() - if !success { - return errors.Join(ErrRegistrationFailed, err) - } + if _, err := resp.Status(); err != nil { + return err + } + case requestTypeUpdate: + resp, err := send[bulkSensorUpdateResponse](ctx, c, req) + if err != nil { + return fmt.Errorf("failed to send location update: %w", err) + } - // At this point, the sensor registration was successful. Any errors are really - // warnings and non-critical. - var warnings error + go resp.Process(ctx, details) + case requestTypeRegister: + resp, err := send[registrationResponse](ctx, c, req) + if err != nil { + return fmt.Errorf("failed to send location update: %w", err) + } - // Set the sensor as registered in the registry. - err = c.registry.SetRegistered(details.ID, true) - if err != nil { - warnings = errors.Join(warnings, fmt.Errorf("%w: %w", ErrRegAddFailed, err)) - } - // Update the sensor state in the tracker. - if err := tracker.Add(&details); err != nil { - warnings = errors.Join(warnings, fmt.Errorf("%w: %w", ErrTrkUpdateFailed, err)) + go resp.Process(ctx, details) } - // Return success status and any warnings. - return warnings + return nil } // isDisabled handles processing a sensor that is disabled. For a sensor that is @@ -291,7 +193,7 @@ func (c *Client) isDisabled(ctx context.Context, details sensor.Entity) bool { if !disabledInHA { slog.Info("Sensor re-enabled in Home Assistant, Re-enabling in local registry and sending updates.", sensorLogAttrs(details)) - if err := c.registry.SetDisabled(details.ID, false); err != nil { + if err := sensorRegistry.SetDisabled(details.ID, false); err != nil { slog.Error("Could not re-enable sensor.", sensorLogAttrs(details), slog.Any("error", err)) @@ -309,8 +211,10 @@ func (c *Client) isDisabled(ctx context.Context, details sensor.Entity) bool { // isDisabledInReg returns the disabled state of the sensor from the local // registry. +// +//revive:disable:unused-receiver func (c *Client) isDisabledInReg(id string) bool { - return c.registry.IsDisabled(id) + return sensorRegistry.IsDisabled(id) } // isDisabledInHA returns the disabled state of the sensor from Home Assistant. @@ -338,64 +242,8 @@ func (c *Client) isDisabledInHA(ctx context.Context, details sensor.Entity) bool return status } -func send[T any](ctx context.Context, client *Client, requestDetails any) (T, error) { - var ( - response T - responseErr apiError - responseObj *resty.Response - ) - - if client.endpoint == nil { - return response, ErrInvalidClient - } - - requestObj := client.endpoint.R().SetContext(ctx) - requestObj = requestObj.SetError(&responseErr) - requestObj = requestObj.SetResult(&response) - - // If the request is authenticated, set the auth header with the token. - if a, ok := requestDetails.(Authenticated); ok { - requestObj = requestObj.SetAuthToken(a.Auth()) - } - - switch req := requestDetails.(type) { - case PostRequest: - logging.FromContext(ctx). - LogAttrs(ctx, logging.LevelTrace, - "Sending request.", - slog.String("method", "POST"), - slog.String("body", string(req.RequestBody())), - slog.Time("sent_at", time.Now())) - - responseObj, _ = requestObj.SetBody(req.RequestBody()).Post("") //nolint:errcheck // error is checked with responseObj.IsError() - case GetRequest: - logging.FromContext(ctx). - LogAttrs(ctx, logging.LevelTrace, - "Sending request.", - slog.String("method", "GET"), - slog.Time("sent_at", time.Now())) - - responseObj, _ = requestObj.Get("") //nolint:errcheck // error is checked with responseObj.IsError() - } - - logging.FromContext(ctx). - LogAttrs(ctx, logging.LevelTrace, - "Received response.", - slog.Int("statuscode", responseObj.StatusCode()), - slog.String("status", responseObj.Status()), - slog.String("protocol", responseObj.Proto()), - slog.Duration("time", responseObj.Time()), - slog.String("body", string(responseObj.Body()))) - - if responseObj.IsError() { - return response, &apiError{Code: responseObj.StatusCode(), Message: responseObj.Status()} - } - - return response, nil -} - func GetSensor(id string) (*sensor.Entity, error) { - details, err := tracker.Get(id) + details, err := sensorTracker.Get(id) if err != nil { return nil, fmt.Errorf("could not get sensor details: %w", err) } @@ -404,7 +252,7 @@ func GetSensor(id string) (*sensor.Entity, error) { } func SensorList() []string { - return tracker.SensorList() + return sensorTracker.SensorList() } // sensorLogAttrs is a convienience function that returns some slog attributes diff --git a/internal/hass/client_mocks_test.go b/internal/hass/client_mocks_test.go deleted file mode 100644 index 59630256a..000000000 --- a/internal/hass/client_mocks_test.go +++ /dev/null @@ -1,278 +0,0 @@ -// Code generated by moq; DO NOT EDIT. -// github.com/matryer/moq - -package hass - -import ( - "encoding/json" - "sync" -) - -// Ensure, that PostRequestMock does implement PostRequest. -// If this is not the case, regenerate this file with moq. -var _ PostRequest = &PostRequestMock{} - -// PostRequestMock is a mock implementation of PostRequest. -// -// func TestSomethingThatUsesPostRequest(t *testing.T) { -// -// // make and configure a mocked PostRequest -// mockedPostRequest := &PostRequestMock{ -// RequestBodyFunc: func() json.RawMessage { -// panic("mock out the RequestBody method") -// }, -// } -// -// // use mockedPostRequest in code that requires PostRequest -// // and then make assertions. -// -// } -type PostRequestMock struct { - // RequestBodyFunc mocks the RequestBody method. - RequestBodyFunc func() json.RawMessage - - // calls tracks calls to the methods. - calls struct { - // RequestBody holds details about calls to the RequestBody method. - RequestBody []struct { - } - } - lockRequestBody sync.RWMutex -} - -// RequestBody calls RequestBodyFunc. -func (mock *PostRequestMock) RequestBody() json.RawMessage { - if mock.RequestBodyFunc == nil { - panic("PostRequestMock.RequestBodyFunc: method is nil but PostRequest.RequestBody was just called") - } - callInfo := struct { - }{} - mock.lockRequestBody.Lock() - mock.calls.RequestBody = append(mock.calls.RequestBody, callInfo) - mock.lockRequestBody.Unlock() - return mock.RequestBodyFunc() -} - -// RequestBodyCalls gets all the calls that were made to RequestBody. -// Check the length with: -// -// len(mockedPostRequest.RequestBodyCalls()) -func (mock *PostRequestMock) RequestBodyCalls() []struct { -} { - var calls []struct { - } - mock.lockRequestBody.RLock() - calls = mock.calls.RequestBody - mock.lockRequestBody.RUnlock() - return calls -} - -// Ensure, that RegistryMock does implement Registry. -// If this is not the case, regenerate this file with moq. -var _ Registry = &RegistryMock{} - -// RegistryMock is a mock implementation of Registry. -// -// func TestSomethingThatUsesRegistry(t *testing.T) { -// -// // make and configure a mocked Registry -// mockedRegistry := &RegistryMock{ -// IsDisabledFunc: func(id string) bool { -// panic("mock out the IsDisabled method") -// }, -// IsRegisteredFunc: func(id string) bool { -// panic("mock out the IsRegistered method") -// }, -// SetDisabledFunc: func(id string, state bool) error { -// panic("mock out the SetDisabled method") -// }, -// SetRegisteredFunc: func(id string, state bool) error { -// panic("mock out the SetRegistered method") -// }, -// } -// -// // use mockedRegistry in code that requires Registry -// // and then make assertions. -// -// } -type RegistryMock struct { - // IsDisabledFunc mocks the IsDisabled method. - IsDisabledFunc func(id string) bool - - // IsRegisteredFunc mocks the IsRegistered method. - IsRegisteredFunc func(id string) bool - - // SetDisabledFunc mocks the SetDisabled method. - SetDisabledFunc func(id string, state bool) error - - // SetRegisteredFunc mocks the SetRegistered method. - SetRegisteredFunc func(id string, state bool) error - - // calls tracks calls to the methods. - calls struct { - // IsDisabled holds details about calls to the IsDisabled method. - IsDisabled []struct { - // ID is the id argument value. - ID string - } - // IsRegistered holds details about calls to the IsRegistered method. - IsRegistered []struct { - // ID is the id argument value. - ID string - } - // SetDisabled holds details about calls to the SetDisabled method. - SetDisabled []struct { - // ID is the id argument value. - ID string - // State is the state argument value. - State bool - } - // SetRegistered holds details about calls to the SetRegistered method. - SetRegistered []struct { - // ID is the id argument value. - ID string - // State is the state argument value. - State bool - } - } - lockIsDisabled sync.RWMutex - lockIsRegistered sync.RWMutex - lockSetDisabled sync.RWMutex - lockSetRegistered sync.RWMutex -} - -// IsDisabled calls IsDisabledFunc. -func (mock *RegistryMock) IsDisabled(id string) bool { - if mock.IsDisabledFunc == nil { - panic("RegistryMock.IsDisabledFunc: method is nil but Registry.IsDisabled was just called") - } - callInfo := struct { - ID string - }{ - ID: id, - } - mock.lockIsDisabled.Lock() - mock.calls.IsDisabled = append(mock.calls.IsDisabled, callInfo) - mock.lockIsDisabled.Unlock() - return mock.IsDisabledFunc(id) -} - -// IsDisabledCalls gets all the calls that were made to IsDisabled. -// Check the length with: -// -// len(mockedRegistry.IsDisabledCalls()) -func (mock *RegistryMock) IsDisabledCalls() []struct { - ID string -} { - var calls []struct { - ID string - } - mock.lockIsDisabled.RLock() - calls = mock.calls.IsDisabled - mock.lockIsDisabled.RUnlock() - return calls -} - -// IsRegistered calls IsRegisteredFunc. -func (mock *RegistryMock) IsRegistered(id string) bool { - if mock.IsRegisteredFunc == nil { - panic("RegistryMock.IsRegisteredFunc: method is nil but Registry.IsRegistered was just called") - } - callInfo := struct { - ID string - }{ - ID: id, - } - mock.lockIsRegistered.Lock() - mock.calls.IsRegistered = append(mock.calls.IsRegistered, callInfo) - mock.lockIsRegistered.Unlock() - return mock.IsRegisteredFunc(id) -} - -// IsRegisteredCalls gets all the calls that were made to IsRegistered. -// Check the length with: -// -// len(mockedRegistry.IsRegisteredCalls()) -func (mock *RegistryMock) IsRegisteredCalls() []struct { - ID string -} { - var calls []struct { - ID string - } - mock.lockIsRegistered.RLock() - calls = mock.calls.IsRegistered - mock.lockIsRegistered.RUnlock() - return calls -} - -// SetDisabled calls SetDisabledFunc. -func (mock *RegistryMock) SetDisabled(id string, state bool) error { - if mock.SetDisabledFunc == nil { - panic("RegistryMock.SetDisabledFunc: method is nil but Registry.SetDisabled was just called") - } - callInfo := struct { - ID string - State bool - }{ - ID: id, - State: state, - } - mock.lockSetDisabled.Lock() - mock.calls.SetDisabled = append(mock.calls.SetDisabled, callInfo) - mock.lockSetDisabled.Unlock() - return mock.SetDisabledFunc(id, state) -} - -// SetDisabledCalls gets all the calls that were made to SetDisabled. -// Check the length with: -// -// len(mockedRegistry.SetDisabledCalls()) -func (mock *RegistryMock) SetDisabledCalls() []struct { - ID string - State bool -} { - var calls []struct { - ID string - State bool - } - mock.lockSetDisabled.RLock() - calls = mock.calls.SetDisabled - mock.lockSetDisabled.RUnlock() - return calls -} - -// SetRegistered calls SetRegisteredFunc. -func (mock *RegistryMock) SetRegistered(id string, state bool) error { - if mock.SetRegisteredFunc == nil { - panic("RegistryMock.SetRegisteredFunc: method is nil but Registry.SetRegistered was just called") - } - callInfo := struct { - ID string - State bool - }{ - ID: id, - State: state, - } - mock.lockSetRegistered.Lock() - mock.calls.SetRegistered = append(mock.calls.SetRegistered, callInfo) - mock.lockSetRegistered.Unlock() - return mock.SetRegisteredFunc(id, state) -} - -// SetRegisteredCalls gets all the calls that were made to SetRegistered. -// Check the length with: -// -// len(mockedRegistry.SetRegisteredCalls()) -func (mock *RegistryMock) SetRegisteredCalls() []struct { - ID string - State bool -} { - var calls []struct { - ID string - State bool - } - mock.lockSetRegistered.RLock() - calls = mock.calls.SetRegistered - mock.lockSetRegistered.RUnlock() - return calls -} diff --git a/internal/hass/request.go b/internal/hass/request.go index eeefb9e50..59613250c 100644 --- a/internal/hass/request.go +++ b/internal/hass/request.go @@ -3,15 +3,20 @@ // This software is released under the MIT License. // https://opensource.org/licenses/MIT -//revive:disable:max-public-structs +//go:generate go run github.com/matryer/moq -out request_mocks_test.go . PostRequest package hass import ( + "context" "encoding/json" "errors" "fmt" + "log/slog" + "time" - "github.com/joshuar/go-hass-agent/internal/hass/sensor" + "github.com/go-resty/resty/v2" + + "github.com/joshuar/go-hass-agent/internal/logging" ) const ( @@ -26,6 +31,30 @@ var ( ErrUnknownDetails = errors.New("unknown sensor details") ) +// GetRequest is a HTTP GET request. +type GetRequest any + +// PostRequest is a HTTP POST request with the request body provided by Body(). +type PostRequest interface { + RequestBody() json.RawMessage +} + +// Authenticated represents a request that requires passing an authentication +// header with the value returned by Auth(). +type Authenticated interface { + Auth() string +} + +// Encrypted represents a request that should be encrypted with the secret +// provided by Secret(). +type Encrypted interface { + Secret() string +} + +type Validator interface { + Validate() error +} + // LocationRequest represents the location information that can be sent to HA to // update the location of the agent. This is exposed so that device code can // create location requests directly, as Home Assistant handles these @@ -63,23 +92,58 @@ func (r *request) RequestBody() json.RawMessage { return json.RawMessage(data) } -func newEntityRequest(requestType string, entity sensor.Entity) (*request, error) { - var req *request - - switch requestType { - case requestTypeLocation: - req = &request{Data: entity.Value, RequestType: requestType} - case requestTypeRegister: - req = &request{Data: entity, RequestType: requestType} - case requestTypeUpdate: - req = &request{Data: entity.State, RequestType: requestType} - default: - return nil, ErrUnknownDetails +func send[T any](ctx context.Context, client *Client, requestDetails any) (T, error) { + var ( + response T + responseErr apiError + responseObj *resty.Response + ) + + if client.endpoint == nil { + return response, ErrInvalidClient } - if err := req.Validate(); err != nil { - return nil, fmt.Errorf("validation failed: %w", err) + requestObj := client.endpoint.R().SetContext(ctx) + requestObj = requestObj.SetError(&responseErr) + requestObj = requestObj.SetResult(&response) + + // If the request is authenticated, set the auth header with the token. + if a, ok := requestDetails.(Authenticated); ok { + requestObj = requestObj.SetAuthToken(a.Auth()) + } + + switch req := requestDetails.(type) { + case PostRequest: + logging.FromContext(ctx). + LogAttrs(ctx, logging.LevelTrace, + "Sending request.", + slog.String("method", "POST"), + slog.String("body", string(req.RequestBody())), + slog.Time("sent_at", time.Now())) + + responseObj, _ = requestObj.SetBody(req.RequestBody()).Post("") //nolint:errcheck // error is checked with responseObj.IsError() + case GetRequest: + logging.FromContext(ctx). + LogAttrs(ctx, logging.LevelTrace, + "Sending request.", + slog.String("method", "GET"), + slog.Time("sent_at", time.Now())) + + responseObj, _ = requestObj.Get("") //nolint:errcheck // error is checked with responseObj.IsError() + } + + logging.FromContext(ctx). + LogAttrs(ctx, logging.LevelTrace, + "Received response.", + slog.Int("statuscode", responseObj.StatusCode()), + slog.String("status", responseObj.Status()), + slog.String("protocol", responseObj.Proto()), + slog.Duration("time", responseObj.Time()), + slog.String("body", string(responseObj.Body()))) + + if responseObj.IsError() { + return response, &apiError{Code: responseObj.StatusCode(), Message: responseObj.Status()} } - return req, nil + return response, nil } diff --git a/internal/hass/request_test.go b/internal/hass/request_test.go index ce06d895e..7d0111809 100644 --- a/internal/hass/request_test.go +++ b/internal/hass/request_test.go @@ -7,7 +7,6 @@ package hass import ( - "reflect" "testing" "github.com/joshuar/go-hass-agent/internal/hass/sensor" @@ -70,65 +69,3 @@ func Test_request_Validate(t *testing.T) { }) } } - -func Test_newEntityRequest(t *testing.T) { - locationEntity := sensor.Entity{ - State: &sensor.State{ - Value: &LocationRequest{ - Gps: []float64{0.0, 0.0}, - }, - }, - } - - entity := sensor.Entity{ - Name: "Mock Entity", - State: &sensor.State{ - ID: "mock_entity", - Value: "test", - }, - } - - type args struct { - requestType string - entity sensor.Entity - } - tests := []struct { - want *request - name string - args args - wantErr bool - }{ - { - name: "location request", - args: args{requestType: requestTypeLocation, entity: locationEntity}, - want: &request{Data: locationEntity.Value, RequestType: requestTypeLocation}, - }, - { - name: "update request", - args: args{requestType: requestTypeUpdate, entity: entity}, - want: &request{Data: entity.State, RequestType: requestTypeUpdate}, - }, - { - name: "registration request", - args: args{requestType: requestTypeRegister, entity: entity}, - want: &request{Data: entity, RequestType: requestTypeRegister}, - }, - { - name: "no request type", - args: args{entity: entity}, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := newEntityRequest(tt.args.requestType, tt.args.entity) - if (err != nil) != tt.wantErr { - t.Errorf("newEntityRequest() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("newEntityRequest() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/hass/response.go b/internal/hass/response.go index ef4f705c6..88282c382 100644 --- a/internal/hass/response.go +++ b/internal/hass/response.go @@ -6,10 +6,28 @@ package hass import ( + "context" "fmt" + "log/slog" "strings" + + "github.com/joshuar/go-hass-agent/internal/hass/sensor" + "github.com/joshuar/go-hass-agent/internal/logging" +) + +const ( + Registered responseStatus = iota + 1 + Updated + Disabled + Failed ) +type responseStatus int + +type Response interface { + Status() (responseStatus, error) +} + type apiError struct { Code any `json:"code,omitempty"` Message string `json:"message,omitempty"` @@ -32,47 +50,111 @@ func (e *apiError) Error() string { return strings.Join(msg, ": ") } -type responseStatus struct { - ErrorDetails *apiError - IsSuccess bool `json:"success,omitempty"` +type response struct { + ErrorDetails *apiError `json:"error,omitempty"` + IsSuccess bool `json:"success,omitempty"` } -type updateResponseStatus struct { - responseStatus - IsDisabled bool `json:"is_disabled,omitempty"` +func (r *response) Status() (responseStatus, error) { + if r.IsSuccess { + return Updated, nil + } + + return Failed, r.ErrorDetails } -func (u *updateResponseStatus) disabled() bool { - return u.IsDisabled +type sensorUpdateReponse struct { + response + IsDisabled bool `json:"is_disabled,omitempty"` } -func (u *updateResponseStatus) success() (bool, error) { - if u.IsSuccess { - return true, nil +func (u *sensorUpdateReponse) Status() (responseStatus, error) { + switch { + case !u.IsSuccess: + return Failed, u.ErrorDetails + case u.IsDisabled: + return Disabled, u.ErrorDetails + default: + return Updated, nil } - - return false, u.ErrorDetails } -type stateUpdateResponse map[string]updateResponseStatus +type bulkSensorUpdateResponse map[string]sensorUpdateReponse + +func (u bulkSensorUpdateResponse) Process(ctx context.Context, details sensor.Entity) { + for id, sensorReponse := range u { + status, err := sensorReponse.Status() + + switch status { + case Failed: + logging.FromContext(ctx).Warn("Sensor update failed.", + slog.String("id", id), + slog.Any("error", err)) + + return + case Disabled: + // Already disabled in registry, nothing to do. + if sensorRegistry.IsDisabled(id) { + return + } + // Disable in registry. + logging.FromContext(ctx). + Info("Sensor is disabled in Home Assistant. Setting disabled in local registry.", + slog.String("id", id)) + + if err := sensorRegistry.SetDisabled(id, true); err != nil { + logging.FromContext(ctx).Warn("Unable to disable sensor in registry.", + slog.String("id", id), + slog.Any("error", err)) + } + case Updated: + logging.FromContext(ctx). + Debug("Sensor updated.", + sensorLogAttrs(details)) + } + + // Add the sensor update to the tracker. + if err := sensorTracker.Add(&details); err != nil { + logging.FromContext(ctx).Warn("Unable to update sensor state in tracker.", + slog.String("id", id), + slog.Any("error", err)) + } + } +} -type registrationResponse responseStatus +type registrationResponse response -func (r *registrationResponse) registered() (bool, error) { +func (r *registrationResponse) Status() (responseStatus, error) { if r.IsSuccess { - return true, nil + return Registered, nil } - return false, r.ErrorDetails + return Failed, r.ErrorDetails } -type locationResponse struct { - error -} - -//nolint:staticcheck -func (r *locationResponse) updated() error { - return r +func (r *registrationResponse) Process(ctx context.Context, details sensor.Entity) { + status, err := r.Status() + + switch status { + case Failed: + logging.FromContext(ctx).Warn("Sensor registration failed.", + slog.String("id", details.ID), + slog.Any("error", err)) + + return + case Registered: + // Set registration status in registry. + err = sensorRegistry.SetRegistered(details.ID, true) + if err != nil { + logging.FromContext(ctx).Warn("Unable to set sensor registration in registry.", + slog.String("id", details.ID), + slog.Any("error", err)) + } + // Add the sensor update to the tracker. + if err := sensorTracker.Add(&details); err != nil { + logging.FromContext(ctx).Warn("Unable to update sensor state in tracker.", + slog.String("id", details.ID), + slog.Any("error", err)) + } + } } - -type eventResponse struct{}