diff --git a/impl/sql/mutationstorage/queue.go b/impl/sql/mutationstorage/queue.go index dc1f9190e..f8421b05b 100644 --- a/impl/sql/mutationstorage/queue.go +++ b/impl/sql/mutationstorage/queue.go @@ -17,6 +17,7 @@ package mutationstorage import ( "context" "database/sql" + "math/rand" "time" "github.com/google/keytransparency/core/mutator" @@ -42,18 +43,38 @@ func (m *Mutations) AddShards(ctx context.Context, domainID string, shardIDs ... return nil } +// randShard returns a random, enabled shard for domainID. +func (m *Mutations) randShard(ctx context.Context, domainID string) (int64, error) { + // Read all enabled shards for domainID. + // TODO(gbelvin): Cache these results. + var shardIDs []int64 + rows, err := m.db.QueryContext(ctx, + `SELECT ShardID from Shards WHERE DomainID = ? AND Enabled = ?;`, + domainID, true) + if err != nil { + return 0, err + } + for rows.Next() { + var shardID int64 + rows.Scan(&shardID) + shardIDs = append(shardIDs, shardID) + } + if err := rows.Err(); err != nil { + return 0, err + } + if len(shardIDs) == 0 { + return 0, status.Errorf(codes.NotFound, "No shard found for domain %v", domainID) + } + + // Return a random shard. + return shardIDs[rand.Intn(len(shardIDs))], nil +} + // Send writes mutations to the leading edge (by sequence number) of the mutations table. func (m *Mutations) Send(ctx context.Context, domainID string, update *pb.EntryUpdate) error { glog.Infof("mutationstorage: Send(%v, )", domainID) - // Select a shard to write to - var shardID int64 - err := m.db.QueryRowContext(ctx, - `SELECT ShardID from Shards WHERE DomainID = ? AND Enabled = ? ORDER BY RANDOM() LIMIT 1;`, - domainID, true).Scan(&shardID) - switch { - case err == sql.ErrNoRows: - return status.Errorf(codes.NotFound, "No shard found for domain %v", domainID) - case err != nil: + shardID, err := m.randShard(ctx, domainID) + if err != nil { return err } diff --git a/impl/sql/mutationstorage/queue_test.go b/impl/sql/mutationstorage/queue_test.go index c75e16144..3c4ebf211 100644 --- a/impl/sql/mutationstorage/queue_test.go +++ b/impl/sql/mutationstorage/queue_test.go @@ -19,11 +19,59 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + _ "github.com/mattn/go-sqlite3" pb "github.com/google/keytransparency/core/api/v1/keytransparency_go_proto" ) +func TestRandShard(t *testing.T) { + ctx := context.Background() + db := newDB(t) + m, err := New(db) + if err != nil { + t.Fatalf("Failed to create mutations: %v", err) + } + domainID := "foo" + + for _, tc := range []struct { + desc string + send bool + wantCode codes.Code + wantShards map[int64]bool + }{ + {desc: "no rows", wantCode: codes.NotFound, wantShards: map[int64]bool{}}, + {desc: "second", send: true, wantShards: map[int64]bool{ + 1: true, + 2: true, + 3: true, + }}, + } { + if tc.send { + if err := m.AddShards(ctx, domainID, 1, 2, 3); err != nil { + t.Fatalf("AddShards(): %v", err) + } + } + shards := make(map[int64]bool) + for i := 0; i < 20; i++ { + shard, err := m.randShard(ctx, domainID) + if got, want := status.Code(err), tc.wantCode; got != want { + t.Errorf("randShard(): %v, want %v", got, want) + } + if err != nil { + break + } + shards[shard] = true + } + if got, want := shards, tc.wantShards; !cmp.Equal(got, want) { + t.Errorf("shards: %v, want %v", got, want) + } + } +} + func TestWatermark(t *testing.T) { ctx := context.Background() db := newDB(t)