diff --git a/storage/mysql/mysql.go b/storage/mysql/mysql.go index ec9a9529..2adc33c6 100644 --- a/storage/mysql/mysql.go +++ b/storage/mysql/mysql.go @@ -20,6 +20,7 @@ import ( "context" "database/sql" "fmt" + "strings" "time" _ "github.com/go-sql-driver/mysql" @@ -27,7 +28,6 @@ import ( tessera "github.com/transparency-dev/trillian-tessera" "github.com/transparency-dev/trillian-tessera/api" "github.com/transparency-dev/trillian-tessera/storage" - "golang.org/x/sync/errgroup" "k8s.io/klog/v2" ) @@ -259,35 +259,48 @@ func (s *Storage) sequenceBatch(ctx context.Context, entries []*tessera.Entry) e // integrate incorporates the provided entries into the log starting at fromSeq. func (s *Storage) integrate(ctx context.Context, tx *sql.Tx, fromSeq uint64, entries []*tessera.Entry) error { tb := storage.NewTreeBuilder(func(ctx context.Context, tileIDs []storage.TileID, treeSize uint64) ([]*api.HashTile, error) { - r := make([]*api.HashTile, len(tileIDs)) + hashTiles := make([]*api.HashTile, len(tileIDs)) - // TODO(#21): Refactor the following to fully utilise the MySQL for fetching multiple tiles in one query with the same ordering. - errG := errgroup.Group{} + // Build the SQL and args to fetch the hash tiles. + var sql strings.Builder + args := make([]uint64, 0, len(tileIDs)*2) for i, id := range tileIDs { - i := i - id := id - errG.Go(func() error { - row := tx.QueryRowContext(ctx, selectSubtreeByLevelAndIndexSQL, id.Level, id.Index) - if err := row.Err(); err != nil { - return err - } - - var tile []byte - if err := row.Scan(&tile); err != nil { - return err - } - t := &api.HashTile{} - if err := t.UnmarshalText(tile); err != nil { - return fmt.Errorf("api.HashTile.unmarshalText(level: %d, index: %d): %w", id.Level, id.Index, err) - } - r[i] = t - return nil - }) + if i != 0 { + sql.WriteString(" UNION ALL ") + } + _, err := sql.WriteString(selectSubtreeByLevelAndIndexSQL) + if err != nil { + return nil, err + } + args = append(args, id.Level, id.Index) } - if err := errG.Wait(); err != nil { - return nil, err + + rows, err := tx.QueryContext(ctx, sql.String(), args) + if err != nil { + return nil, fmt.Errorf("failed to query the hash tiles with SQL (%s): %w", sql.String(), err) + } + defer func() { + if err := rows.Close(); err != nil { + klog.Warningf("Failed to close the rows: %v", err) + } + }() + + for rows.Next() { + var tile []byte + if err := rows.Scan(&tile); err != nil { + return nil, fmt.Errorf("rows.Scan: %w", err) + } + t := &api.HashTile{} + if err := t.UnmarshalText(tile); err != nil { + return nil, fmt.Errorf("api.HashTile.unmarshalText: %w", err) + } + hashTiles = append(hashTiles, t) + } + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("rows.Err: %w", err) } - return r, nil + + return hashTiles, nil }) sequencedEntries := make([]storage.SequencedEntry, len(entries))