Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(policy): 1435 tech debt migrate attribute definition object queries to sqlc #1450

Merged
Merged
2 changes: 1 addition & 1 deletion service/integration/attribute_values_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ func (s *AttributeValuesSuite) Test_DeactivateAttribute_Cascades_List() {
}

listAttributes := func(state string) bool {
listedAttrs, err := s.db.PolicyClient.ListAllAttributes(s.ctx, state, "")
listedAttrs, err := s.db.PolicyClient.ListAttributes(s.ctx, state, "")
s.Require().NoError(err)
s.NotNil(listedAttrs)
for _, a := range listedAttrs {
Expand Down
46 changes: 27 additions & 19 deletions service/integration/attributes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"slices"
"strings"
"testing"
"time"

"github.com/opentdf/platform/protocol/go/common"
"github.com/opentdf/platform/protocol/go/policy"
Expand Down Expand Up @@ -367,10 +366,10 @@ func (s *AttributesSuite) Test_GetAttribute_ContainsKASGrants() {
s.Equal(createdKAS.GetId(), gotAttr.GetGrants()[0].GetId())
}

func (s *AttributesSuite) Test_ListAttribute() {
func (s *AttributesSuite) Test_ListAttributes() {
fixtures := s.getAttributeFixtures()

list, err := s.db.PolicyClient.ListAllAttributes(s.ctx, policydb.StateActive, "")
list, err := s.db.PolicyClient.ListAttributes(s.ctx, policydb.StateActive, "")
s.Require().NoError(err)
s.NotNil(list)

Expand All @@ -387,7 +386,7 @@ func (s *AttributesSuite) Test_ListAttribute() {
}
}

func (s *AttributesSuite) Test_ListAttribute_FqnsIncluded() {
func (s *AttributesSuite) Test_ListAttributes_FqnsIncluded() {
// create an attribute
attr := &attributes.CreateAttributeRequest{
Name: "list_attribute_fqns_new_attr",
Expand All @@ -399,7 +398,7 @@ func (s *AttributesSuite) Test_ListAttribute_FqnsIncluded() {
s.Require().NoError(err)
s.NotNil(createdAttr)

list, err := s.db.PolicyClient.ListAllAttributes(s.ctx, policydb.StateActive, fixtureNamespaceID)
list, err := s.db.PolicyClient.ListAttributes(s.ctx, policydb.StateActive, fixtureNamespaceID)
s.Require().NoError(err)
s.NotNil(list)

Expand All @@ -420,15 +419,15 @@ func (s *AttributesSuite) Test_ListAttribute_FqnsIncluded() {
}
}

func (s *AttributesSuite) Test_ListAttributesByNamespace() {
func (s *AttributesSuite) Test_ListAttributes_ByNamespaceIdOrName() {
// get all unique namespace_ids
namespaces := map[string]string{}
for _, f := range s.getAttributeFixtures() {
namespaces[f.NamespaceID] = ""
}
// list attributes by namespace id
for nsID := range namespaces {
list, err := s.db.PolicyClient.ListAllAttributes(s.ctx, policydb.StateAny, nsID)
list, err := s.db.PolicyClient.ListAttributes(s.ctx, policydb.StateAny, nsID)
s.Require().NoError(err)
s.NotNil(list)
s.NotEmpty(list)
Expand All @@ -440,7 +439,19 @@ func (s *AttributesSuite) Test_ListAttributesByNamespace() {

// list attributes by namespace name
for _, nsName := range namespaces {
list, err := s.db.PolicyClient.ListAllAttributes(s.ctx, policydb.StateAny, nsName)
list, err := s.db.PolicyClient.ListAttributes(s.ctx, policydb.StateAny, nsName)
s.Require().NoError(err)
s.NotNil(list)
s.NotEmpty(list)
for _, l := range list {
s.Equal(nsName, l.GetNamespace().GetName())
}
}

// list attributes by namespace name with case insensitivity
for _, nsName := range namespaces {
upperNsName := strings.ToUpper(nsName)
list, err := s.db.PolicyClient.ListAttributes(s.ctx, policydb.StateAny, upperNsName)
s.Require().NoError(err)
s.NotNil(list)
s.NotEmpty(list)
Expand Down Expand Up @@ -478,14 +489,7 @@ func (s *AttributesSuite) Test_UpdateAttribute() {
Labels: labels,
},
}
start := time.Now().Add(-time.Second)
created, err := s.db.PolicyClient.CreateAttribute(s.ctx, attr)
end := time.Now().Add(time.Second)
metadata := created.GetMetadata()
updatedAt := metadata.GetUpdatedAt()
createdAt := metadata.GetCreatedAt()
s.True(createdAt.AsTime().After(start))
s.True(createdAt.AsTime().Before(end))
s.Require().NoError(err)
s.NotNil(created)

Expand All @@ -511,7 +515,12 @@ func (s *AttributesSuite) Test_UpdateAttribute() {
s.NotNil(got)
s.Equal(created.GetId(), got.GetId())
s.EqualValues(expectedLabels, got.GetMetadata().GetLabels())
s.True(got.GetMetadata().GetUpdatedAt().AsTime().After(updatedAt.AsTime()))
metadata := got.GetMetadata()
createdAt := metadata.GetCreatedAt()
updatedAt := metadata.GetUpdatedAt()
s.False(createdAt.AsTime().IsZero())
s.False(updatedAt.AsTime().IsZero())
s.True(updatedAt.AsTime().After(createdAt.AsTime()))
}

func (s *AttributesSuite) Test_UpdateAttribute_WithInvalidIdFails() {
Expand Down Expand Up @@ -752,7 +761,6 @@ func (s *AttributesSuite) Test_UnsafeUpdateAttribute_ReplaceValuesOrder() {
})
s.Require().NoError(err)
s.NotNil(updated)
s.Len(updated.GetValues(), 3)

// get attribute and ensure the order of the values is preserved and successfully reversed
got, err := s.db.PolicyClient.GetAttribute(s.ctx, created.GetId())
Expand Down Expand Up @@ -801,7 +809,7 @@ func (s *AttributesSuite) Test_UnsafeDeleteAttribute() {
s.NotEqual("", ns.GetId())

// attribute should not be listed anymore
list, err := s.db.PolicyClient.ListAllAttributes(s.ctx, policydb.StateAny, fixtureNamespaceID)
list, err := s.db.PolicyClient.ListAttributes(s.ctx, policydb.StateAny, fixtureNamespaceID)
s.Require().NoError(err)
s.NotNil(list)
for _, l := range list {
Expand Down Expand Up @@ -915,7 +923,7 @@ func (s *AttributesSuite) Test_DeactivateAttribute_Cascades_List() {
}

listAttributes := func(state string) bool {
listedAttrs, err := s.db.PolicyClient.ListAllAttributes(s.ctx, state, "")
listedAttrs, err := s.db.PolicyClient.ListAttributes(s.ctx, state, "")
s.Require().NoError(err)
s.NotNil(listedAttrs)
for _, a := range listedAttrs {
Expand Down
2 changes: 1 addition & 1 deletion service/integration/namespaces_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ func (s *NamespacesSuite) Test_DeactivateNamespace_Cascades_List() {
}

listAttributes := func(state string) bool {
listedAttrs, err := s.db.PolicyClient.ListAllAttributes(s.ctx, state, "")
listedAttrs, err := s.db.PolicyClient.ListAttributes(s.ctx, state, "")
s.Require().NoError(err)
s.NotNil(listedAttrs)
for _, a := range listedAttrs {
Expand Down
32 changes: 14 additions & 18 deletions service/policy/attributes/attributes.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func (s AttributesService) CreateAttribute(ctx context.Context,
s.logger.Debug("created new attribute definition", slog.String("name", req.GetName()))

auditParams.ObjectID = item.GetId()
auditParams.Original = item
s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)

rsp.Attribute = item
Expand All @@ -69,7 +70,7 @@ func (s *AttributesService) ListAttributes(ctx context.Context,
s.logger.Debug("listing attribute definitions", slog.String("state", state))
rsp := &attributes.ListAttributesResponse{}

list, err := s.dbClient.ListAllAttributes(ctx, state, namespace)
list, err := s.dbClient.ListAttributes(ctx, state, namespace)
if err != nil {
return nil, db.StatusifyError(err, db.ErrTextListRetrievalFailed)
}
Expand Down Expand Up @@ -124,25 +125,20 @@ func (s *AttributesService) UpdateAttribute(ctx context.Context,
return nil, db.StatusifyError(err, db.ErrTextGetRetrievalFailed, slog.String("id", attributeID))
}

item, err := s.dbClient.UpdateAttribute(ctx, req.GetId(), req)
updated, err := s.dbClient.UpdateAttribute(ctx, attributeID, req)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextUpdateFailed, slog.String("id", req.GetId()), slog.String("attribute", req.String()))
}

// Item above only contains the attribute ID so we need to get the full
// attribute definition to compute the diff.
updated, err := s.dbClient.GetAttribute(ctx, attributeID)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextGetRetrievalFailed, slog.String("id", attributeID))
}

auditParams.Original = original
auditParams.Updated = updated
s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)

rsp.Attribute = item
rsp.Attribute = &policy.Attribute{
Id: attributeID,
}

return rsp, nil
}

Expand All @@ -158,25 +154,25 @@ func (s *AttributesService) DeactivateAttribute(ctx context.Context,
ObjectID: attributeID,
}

originalAttribute, err := s.dbClient.GetAttribute(ctx, attributeID)
original, err := s.dbClient.GetAttribute(ctx, attributeID)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextGetRetrievalFailed, slog.String("id", attributeID))
}

// DeactivateAttribute actually returns the entire attribute so we can use it
// to compute the diff.
deactivatedAttribute, err := s.dbClient.DeactivateAttribute(ctx, attributeID)
updated, err := s.dbClient.DeactivateAttribute(ctx, attributeID)
if err != nil {
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
return nil, db.StatusifyError(err, db.ErrTextDeactivationFailed, slog.String("id", attributeID))
}

auditParams.Original = originalAttribute
auditParams.Updated = deactivatedAttribute
auditParams.Original = original
auditParams.Updated = updated
s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)

rsp.Attribute = deactivatedAttribute
rsp.Attribute = &policy.Attribute{
Id: attributeID,
}
return rsp, nil
}

Expand Down
2 changes: 1 addition & 1 deletion service/policy/db/attribute_fqn.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (c *PolicyDBClient) AttrFqnReindex(ctx context.Context) (res struct { //nol
}

// Get all attributes
attrs, err := c.ListAllAttributesWithout(ctx, StateAny)
attrs, err := c.ListAllAttributes(ctx)
if err != nil {
panic(fmt.Errorf("could not get attributes: %w", err))
}
Expand Down
Loading
Loading