Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: do not roll back transaction on partial identity insert error #4211

Merged
merged 6 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ func TestHandler(t *testing.T) {
})

t.Run("suite=PATCH identities", func(t *testing.T) {
t.Run("case=fails on > 100 identities", func(t *testing.T) {
t.Run("case=fails with too many patches", func(t *testing.T) {
tooMany := make([]*identity.BatchIdentityPatch, identity.BatchPatchIdentitiesLimit+1)
for i := range tooMany {
tooMany[i] = &identity.BatchIdentityPatch{Create: validCreateIdentityBody("too-many-patches", i)}
Expand All @@ -767,8 +767,8 @@ func TestHandler(t *testing.T) {
t.Run("case=fails some on a bad identity", func(t *testing.T) {
// Test setup: we have a list of valid identitiy patches and a list of invalid ones.
// Each run adds one invalid patch to the list and sends it to the server.
// --> we expect the server to fail all patches in the list.
// Finally, we send just the valid patches
// --> we expect the server to fail only the bad patches in the list.
// Finally, we send just valid patches
// --> we expect the server to succeed all patches in the list.

t.Run("case=invalid patches fail", func(t *testing.T) {
Expand All @@ -782,24 +782,23 @@ func TestHandler(t *testing.T) {
{Create: &identity.CreateIdentityBody{Traits: json.RawMessage(`"invalid traits"`)}}, // <-- invalid traits
{Create: validCreateIdentityBody("valid", 4)},
}
expectedToPass := []*identity.BatchIdentityPatch{patches[0], patches[1], patches[3], patches[5], patches[7]}

// Create unique IDs for each patch
var patchIDs []string
patchIDs := make([]string, len(patches))
for i, p := range patches {
id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%d", i))
p.ID = &id
patchIDs = append(patchIDs, id.String())
patchIDs[i] = id.String()
}

req := &identity.BatchPatchIdentitiesBody{Identities: patches}
body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req)
var actions []string
for _, a := range body.Get("identities.#.action").Array() {
actions = append(actions, a.String())
}
assert.Equal(t,
require.NoErrorf(t, json.Unmarshal(([]byte)(body.Get("identities.#.action").Raw), &actions), "%s", body)
assert.Equalf(t,
[]string{"create", "create", "error", "create", "error", "create", "error", "create"},
actions, body)
actions, "%s", body)

// Check that all patch IDs are returned
for i, gotPatchID := range body.Get("identities.#.patch_id").Array() {
Expand All @@ -811,6 +810,27 @@ func TestHandler(t *testing.T) {
assert.Equal(t, "Conflict", body.Get("identities.4.error.status").String())
assert.Equal(t, "Bad Request", body.Get("identities.6.error.status").String())

var identityIDs []uuid.UUID
require.NoErrorf(t, json.Unmarshal(([]byte)(body.Get("identities.#.identity").Raw), &identityIDs), "%s", body)

actualIdentities, _, err := reg.Persister().ListIdentities(ctx, identity.ListIdentityParameters{IdsFilter: identityIDs})
require.NoError(t, err)
actualIdentityIDs := make([]uuid.UUID, len(actualIdentities))
for i, id := range actualIdentities {
actualIdentityIDs[i] = id.ID
}
assert.ElementsMatchf(t, identityIDs, actualIdentityIDs, "%s", body)

expectedTraits := make(map[string]string, len(expectedToPass))
for i, p := range expectedToPass {
expectedTraits[identityIDs[i].String()] = string(p.Create.Traits)
}
actualTraits := make(map[string]string, len(actualIdentities))
for _, id := range actualIdentities {
actualTraits[id.ID.String()] = string(id.Traits)
}

assert.Equal(t, expectedTraits, actualTraits)
})

t.Run("valid patches succeed", func(t *testing.T) {
Expand Down Expand Up @@ -1928,7 +1948,7 @@ func validCreateIdentityBody(prefix string, i int) *identity.CreateIdentityBody
identity.VerifiableAddressStatusCompleted,
}

for j := 0; j < 4; j++ {
for j := range 4 {
email := fmt.Sprintf("%s-%d-%[email protected]", prefix, i, j)
traits.Emails = append(traits.Emails, email)
verifiableAddresses = append(verifiableAddresses, identity.VerifiableAddress{
Expand Down
10 changes: 8 additions & 2 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,12 @@ type CreateIdentitiesError struct {
failedIdentities map[*Identity]*herodot.DefaultError
}

func NewCreateIdentitiesError(capacity int) *CreateIdentitiesError {
return &CreateIdentitiesError{
failedIdentities: make(map[*Identity]*herodot.DefaultError, capacity),
}
}

func (e *CreateIdentitiesError) Error() string {
e.init()
return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities))
Expand Down Expand Up @@ -370,7 +376,7 @@ func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity {
return nil
}
func (e *CreateIdentitiesError) ErrOrNil() error {
if len(e.failedIdentities) == 0 {
if e == nil || len(e.failedIdentities) == 0 {
return nil
}
return e
Expand All @@ -385,7 +391,7 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity,
ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities")
defer otelx.End(span, &err)

createIdentitiesError := &CreateIdentitiesError{}
createIdentitiesError := NewCreateIdentitiesError(len(identities))
validIdentities := make([]*Identity, 0, len(identities))
for _, ident := range identities {
if ident.SchemaID == "" {
Expand Down
48 changes: 48 additions & 0 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,12 +350,60 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager,
assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers)
assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute)
assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute)
// because of mysql precision
assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second)
assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second)

require.NoError(t, p.DeleteIdentity(ctx, id.ID))
}
})

t.Run("create exactly the non-conflicting ones", func(t *testing.T) {
identities := make([]*identity.Identity, 100)
for i := range identities {
identities[i] = NewTestIdentity(4, "persister-create-multiple-2", i%60)
}
err := p.CreateIdentities(ctx, identities...)
if dbname == "mysql" {
// partial inserts are not supported on mysql
assert.ErrorIs(t, err, sqlcon.ErrUniqueViolation)
return
}

errWithCtx := new(identity.CreateIdentitiesError)
require.ErrorAsf(t, err, &errWithCtx, "%#v", err)

for _, id := range identities[:60] {
require.NotZero(t, id.ID)

idFromDB, err := p.GetIdentity(ctx, id.ID, identity.ExpandEverything)
require.NoError(t, err)

credFromDB := idFromDB.Credentials[identity.CredentialsTypePassword]
assert.Equal(t, id.ID, idFromDB.ID)
assert.Equal(t, id.SchemaID, idFromDB.SchemaID)
assert.Equal(t, id.SchemaURL, idFromDB.SchemaURL)
assert.Equal(t, id.State, idFromDB.State)

// We test that the values are plausible in the handler test already.
assert.Equal(t, len(id.VerifiableAddresses), len(idFromDB.VerifiableAddresses))
assert.Equal(t, len(id.RecoveryAddresses), len(idFromDB.RecoveryAddresses))

assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers)
assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute)
assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute)
// because of mysql precision
assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second)
assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second)

require.NoError(t, p.DeleteIdentity(ctx, id.ID))
}

for _, id := range identities[60:] {
failed := errWithCtx.Find(id)
assert.NotNil(t, failed)
}
})
})

t.Run("case=should error when the identity ID does not exist", func(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions internal/client-go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down
13 changes: 9 additions & 4 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,14 +561,16 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
}
}()

return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
var partialErr *identity.CreateIdentitiesError
if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
conn := &batch.TracerConnection{
Tracer: p.r.Tracer(ctx),
Connection: tx,
}

succeededIDs = make([]uuid.UUID, 0, len(identities))
failedIdentityIDs := make(map[uuid.UUID]struct{})
partialErr = nil

// Don't use batch.WithPartialInserts, because identities have no other
// constraints other than the primary key that could cause conflicts.
Expand Down Expand Up @@ -620,7 +622,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
// If any of the batch inserts failed on conflict, let's delete the corresponding
// identities and return a list of failed identities in the error.
if len(failedIdentityIDs) > 0 {
partialErr := &identity.CreateIdentitiesError{}
partialErr = identity.NewCreateIdentitiesError(len(failedIdentityIDs))
failedIDs := make([]uuid.UUID, 0, len(failedIdentityIDs))

for _, ident := range identities {
Expand All @@ -637,7 +639,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
return sqlcon.HandleError(err)
}

return partialErr
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the problem.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that rolling back the transaction is unwanted, but we do want to return information about which identity inserts failed. Don't we lose that information here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry, I misread, looks good!

return nil
} else {
// No failures: report all identities as created.
for _, ident := range identities {
Expand All @@ -646,7 +648,10 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
}

return nil
})
}); err != nil {
return err
}
return partialErr.ErrOrNil()
}

func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *identity.Identity, expand identity.Expandables) (err error) {
Expand Down
Loading