diff --git a/pkg/sqlite/blob.go b/pkg/sqlite/blob.go index 0c814ffa183..36703ea1d31 100644 --- a/pkg/sqlite/blob.go +++ b/pkg/sqlite/blob.go @@ -308,17 +308,12 @@ func (qb *BlobStore) readFromDatabase(ctx context.Context, checksum string) (sql // Delete marks a checksum as no longer in use by a single reference. // If no references remain, the blob is deleted from the database and filesystem. func (qb *BlobStore) Delete(ctx context.Context, checksum string) error { - rollid, err := savepoint(ctx) - if err != nil { - return fmt.Errorf("savepoint %s: %w", rollid, err) - } - // try to delete the blob from the database if err := qb.delete(ctx, checksum); err != nil { if qb.isConstraintError(err) { // blob is still referenced - do not delete logger.Debugf("Blob %s is still referenced - not deleting", checksum) - return rollbackToSavepoint(ctx, rollid) + return nil } // unexpected error @@ -358,11 +353,14 @@ func (qb *BlobStore) delete(ctx context.Context, checksum string) error { q := dialect.Delete(table).Where(goqu.C(blobChecksumColumn).Eq(checksum)) - _, err := exec(ctx, q) + err := withSavepoint(ctx, func(ctx context.Context) error { + _, err := exec(ctx, q) + return err + }) + if err != nil { return fmt.Errorf("deleting from %s: %w", table, err) } - return nil } diff --git a/pkg/sqlite/performer_test.go b/pkg/sqlite/performer_test.go index d900eed9f15..254d07dd8bd 100644 --- a/pkg/sqlite/performer_test.go +++ b/pkg/sqlite/performer_test.go @@ -1155,9 +1155,12 @@ func TestPerformerQueryForAutoTag(t *testing.T) { t.Errorf("Error finding performers: %s", err.Error()) } - assert.Len(t, performers, 2) - assert.Equal(t, strings.ToLower(performerNames[performerIdx1WithScene]), strings.ToLower(performers[0].Name)) - assert.Equal(t, strings.ToLower(performerNames[performerIdx1WithScene]), strings.ToLower(performers[1].Name)) + if assert.Len(t, performers, 2) { + assert.Equal(t, strings.ToLower(performerNames[performerIdx1WithScene]), strings.ToLower(performers[0].Name)) + assert.Equal(t, strings.ToLower(performerNames[performerIdx1WithScene]), strings.ToLower(performers[1].Name)) + } else { + t.Errorf("Skipping performer comparison as atleast 1 is missing") + } return nil }) diff --git a/pkg/sqlite/table.go b/pkg/sqlite/table.go index cba0d53c2e9..b2c473b9f0c 100644 --- a/pkg/sqlite/table.go +++ b/pkg/sqlite/table.go @@ -12,7 +12,6 @@ import ( "github.com/jmoiron/sqlx" "gopkg.in/guregu/null.v4" - "github.com/stashapp/stash/pkg/hash" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/sliceutil" @@ -1155,32 +1154,6 @@ func execID(ctx context.Context, stmt sqler) (*int64, error) { return &id, nil } -func savepoint(ctx context.Context) (string, error) { - tx, err := getTx(ctx) - if err != nil { - return "", err - } - - // Generate savepoint - rnd, err := hash.GenerateRandomKey(64) - if err != nil { - return "", err - } - - _, err = tx.QueryxContext(ctx, "SAVEPOINT "+rnd) - return rnd, err -} - -func rollbackToSavepoint(ctx context.Context, id string) error { - tx, err := getTx(ctx) - if err != nil { - return err - } - - _, err = tx.QueryxContext(ctx, "ROLLBACK TO SAVEPOINT "+id) - return err -} - func count(ctx context.Context, q *goqu.SelectDataset) (int, error) { var count int if err := querySimple(ctx, q, &count); err != nil { diff --git a/pkg/sqlite/tx.go b/pkg/sqlite/tx.go index e0aa2265485..46eb68fb76a 100644 --- a/pkg/sqlite/tx.go +++ b/pkg/sqlite/tx.go @@ -7,6 +7,7 @@ import ( "time" "github.com/jmoiron/sqlx" + "github.com/stashapp/stash/pkg/hash" "github.com/stashapp/stash/pkg/logger" ) @@ -174,3 +175,43 @@ func (db *dbWrapperType) ExecStmt(ctx context.Context, stmt *stmt, args ...inter return ret, sqlError(err, stmt.query, args...) } + +type SavepointAction func(ctx context.Context) error + +func withSavepoint(ctx context.Context, action SavepointAction) error { + tx, err := getTx(ctx) + if err != nil { + return err + } + + // Generate savepoint + rnd, err := hash.GenerateRandomKey(64) + if err != nil { + return err + } + rnd = "savepoint_" + rnd + + // Create a savepoint + _, err = tx.Exec("SAVEPOINT " + rnd) + if err != nil { + return fmt.Errorf("failed to create savepoint: %w", err) + } + + // Execute the action + err = action(ctx) + if err != nil { + // Rollback to savepoint on error + if _, rbErr := tx.Exec("ROLLBACK TO SAVEPOINT " + rnd); rbErr != nil { + return fmt.Errorf("action failed and rollback to savepoint failed: %w", rbErr) + } + return fmt.Errorf("action failed: %w", err) + } + + // Release the savepoint on success + _, err = tx.Exec("RELEASE SAVEPOINT " + rnd) + if err != nil { + return fmt.Errorf("failed to release savepoint: %w", err) + } + + return nil +}