diff --git a/core/integration/storagetest/mutation_logs.go b/core/integration/storagetest/mutation_logs.go index ca57576bd..66825a0f0 100644 --- a/core/integration/storagetest/mutation_logs.go +++ b/core/integration/storagetest/mutation_logs.go @@ -63,7 +63,7 @@ func (mutationLogsTests) TestReadLog(ctx context.Context, t *testing.T, newForTe // Write ten batches, three entries each. for i := byte(0); i < 10; i++ { entry := &pb.EntryUpdate{Mutation: &pb.SignedEntry{Entry: mustMarshal(t, &pb.Entry{Index: []byte{i}})}} - if _, _, err := m.Send(ctx, directoryID, entry, entry, entry); err != nil { + if _, err := m.Send(ctx, directoryID, logID, entry, entry, entry); err != nil { t.Fatalf("Send(): %v", err) } } diff --git a/core/keyserver/keyserver.go b/core/keyserver/keyserver.go index 8097a92c4..32e7b47f5 100644 --- a/core/keyserver/keyserver.go +++ b/core/keyserver/keyserver.go @@ -18,6 +18,7 @@ package keyserver import ( "context" "fmt" + "math/rand" "runtime" "sync" "time" @@ -67,16 +68,18 @@ func createMetrics(mf monitoring.MetricFactory) { // MutationLogs provides sets of time ordered message logs. type MutationLogs interface { - // Send submits the whole group of mutations atomically to a random log. + // Send submits the whole group of mutations atomically to a given log. // TODO(gbelvin): Create a batch level object to make it clear that this a batch of updates. - // Returns the logID and timestamp that the mutation batch got written at. - Send(ctx context.Context, directoryID string, mutation ...*pb.EntryUpdate) (int64, time.Time, error) + // Returns the timestamp that the mutation batch got written at. + Send(ctx context.Context, directoryID string, logID int64, mutation ...*pb.EntryUpdate) (time.Time, error) // ReadLog returns the messages in the (low, high] range stored in the // specified log. ReadLog always returns complete units of the original // batches sent via Send, and will return more items than limit if // needed to do so. ReadLog(ctx context.Context, directoryID string, logID int64, low, high time.Time, limit int32) ([]*mutator.LogMessage, error) + // ListLogs returns a list of logs, optionally filtered by the writable bit. + ListLogs(ctx context.Context, directoryID string, writable bool) ([]int64, error) } // BatchReader reads batch definitions. @@ -655,8 +658,13 @@ func (s *Server) BatchQueueUserUpdate(ctx context.Context, in *pb.BatchQueueUser } tdone() - // Save mutation to the database. - wmLogID, wmTime, err := s.logs.Send(ctx, directory.DirectoryID, in.Updates...) + // Pick a random logID. Note, this effectively picks a random QoS. See issue #1377. + // TODO(gbelvin): Define an explicit QoS / Load ballancing API. + wmLogID, err := s.randLog(ctx, directory.DirectoryID) + if st := status.Convert(err); st.Code() != codes.OK { + return nil, status.Errorf(st.Code(), "Could not pick a log to write to: %v", err) + } + wmTime, err := s.logs.Send(ctx, directory.DirectoryID, wmLogID, in.Updates...) if st := status.Convert(err); st.Code() != codes.OK { glog.Errorf("mutations.Write failed: %v", err) return nil, status.Errorf(st.Code(), "Mutation write error") @@ -666,6 +674,18 @@ func (s *Server) BatchQueueUserUpdate(ctx context.Context, in *pb.BatchQueueUser return &empty.Empty{}, nil } +func (s *Server) randLog(ctx context.Context, directoryID string) (int64, error) { + // TODO(gbelvin): Cache these results. + writable := true + logIDs, err := s.logs.ListLogs(ctx, directoryID, writable) + if err != nil { + return 0, err + } + + // Return a random log. + return logIDs[rand.Intn(len(logIDs))], nil +} + // GetDirectory returns all info tied to the specified directory. // // This API to get all necessary data needed to verify a particular diff --git a/core/keyserver/revisions_test.go b/core/keyserver/revisions_test.go index 117f12296..ea74b1d84 100644 --- a/core/keyserver/revisions_test.go +++ b/core/keyserver/revisions_test.go @@ -71,8 +71,16 @@ func (b batchStorage) ReadBatch(ctx context.Context, dirID string, rev int64) (* type mutations map[int64][]*mutator.LogMessage // Map of logID to Slice of LogMessages -func (m *mutations) Send(ctx context.Context, dirID string, mutation ...*pb.EntryUpdate) (int64, time.Time, error) { - return 0, time.Time{}, errors.New("unimplemented") +func (m *mutations) Send(ctx context.Context, dirID string, _ int64, mutation ...*pb.EntryUpdate) (time.Time, error) { + return time.Time{}, errors.New("unimplemented") +} + +func (m *mutations) ListLogs(ctx context.Context, dirID string, _ bool) ([]int64, error) { + logIDs := []int64{} + for id := range *m { + logIDs = append(logIDs, id) + } + return logIDs, nil } func (m *mutations) ReadLog(ctx context.Context, dirID string, diff --git a/impl/sql/mutationstorage/mutation_logs.go b/impl/sql/mutationstorage/mutation_logs.go index 2f22c3e05..e8c75b40f 100644 --- a/impl/sql/mutationstorage/mutation_logs.go +++ b/impl/sql/mutationstorage/mutation_logs.go @@ -17,7 +17,6 @@ package mutationstorage import ( "context" "database/sql" - "math/rand" "time" "github.com/golang/glog" @@ -66,20 +65,16 @@ func (m *Mutations) AddLogs(ctx context.Context, directoryID string, logIDs ...i // Send writes mutations to the leading edge (by sequence number) of the mutations table. // Returns the logID/watermark pair that was written, or nil if nothing was written. // TODO(gbelvin): Make updates a slice. -func (m *Mutations) Send(ctx context.Context, directoryID string, updates ...*pb.EntryUpdate) (int64, time.Time, error) { +func (m *Mutations) Send(ctx context.Context, directoryID string, logID int64, updates ...*pb.EntryUpdate) (time.Time, error) { glog.Infof("mutationstorage: Send(%v, )", directoryID) if len(updates) == 0 { - return 0, time.Time{}, nil - } - logID, err := m.randLog(ctx, directoryID) - if err != nil { - return 0, time.Time{}, err + return time.Time{}, nil } updateData := make([][]byte, 0, len(updates)) for _, u := range updates { data, err := proto.Marshal(u) if err != nil { - return 0, time.Time{}, err + return time.Time{}, err } updateData = append(updateData, data) } @@ -87,9 +82,9 @@ func (m *Mutations) Send(ctx context.Context, directoryID string, updates ...*pb // we get timestamp contention. ts := time.Now() if err := m.send(ctx, ts, directoryID, logID, updateData...); err != nil { - return 0, time.Time{}, err + return time.Time{}, err } - return logID, ts, nil + return ts, nil } // ListLogs returns a list of all logs for directoryID, optionally filtered for writable logs. @@ -123,19 +118,6 @@ func (m *Mutations) ListLogs(ctx context.Context, directoryID string, writable b return logIDs, nil } -// randLog returns a random, enabled log for directoryID. -func (m *Mutations) randLog(ctx context.Context, directoryID string) (int64, error) { - // TODO(gbelvin): Cache these results. - writable := true - logIDs, err := m.ListLogs(ctx, directoryID, writable) - if err != nil { - return 0, err - } - - // Return a random log. - return logIDs[rand.Intn(len(logIDs))], nil -} - // ts must be greater than all other timestamps currently recorded for directoryID. func (m *Mutations) send(ctx context.Context, ts time.Time, directoryID string, logID int64, mData ...[]byte) (ret error) { diff --git a/impl/sql/mutationstorage/mutation_logs_test.go b/impl/sql/mutationstorage/mutation_logs_test.go index ffa62f95d..94c0e2af7 100644 --- a/impl/sql/mutationstorage/mutation_logs_test.go +++ b/impl/sql/mutationstorage/mutation_logs_test.go @@ -20,7 +20,6 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" "github.com/google/keytransparency/core/adminserver" "github.com/google/keytransparency/core/integration/storagetest" "github.com/google/keytransparency/core/keyserver" @@ -57,45 +56,6 @@ func TestLogsAdminIntegration(t *testing.T) { }) } -func TestRandLog(t *testing.T) { - ctx := context.Background() - directoryID := "TestRandLog" - - for _, tc := range []struct { - desc string - send []int64 - wantCode codes.Code - wantLogs map[int64]bool - }{ - {desc: "no rows", wantCode: codes.NotFound, wantLogs: map[int64]bool{}}, - {desc: "one row", send: []int64{10}, wantLogs: map[int64]bool{10: true}}, - {desc: "second", send: []int64{1, 2, 3}, wantLogs: map[int64]bool{ - 1: true, - 2: true, - 3: true, - }}, - } { - t.Run(tc.desc, func(t *testing.T) { - m, done := newForTest(ctx, t, directoryID, tc.send...) - defer done(ctx) - logs := make(map[int64]bool) - for i := 0; i < 10*len(tc.wantLogs); i++ { - logID, err := m.randLog(ctx, directoryID) - if got, want := status.Code(err), tc.wantCode; got != want { - t.Errorf("randLog(): %v, want %v", got, want) - } - if err != nil { - break - } - logs[logID] = true - } - if got, want := logs, tc.wantLogs; !cmp.Equal(got, want) { - t.Errorf("logs: %v, want %v", got, want) - } - }) - } -} - func BenchmarkSend(b *testing.B) { ctx := context.Background() directoryID := "BenchmarkSend" @@ -123,7 +83,7 @@ func BenchmarkSend(b *testing.B) { updates = append(updates, update) } for n := 0; n < b.N; n++ { - if _, _, err := m.Send(ctx, directoryID, updates...); err != nil { + if _, err := m.Send(ctx, directoryID, logID, updates...); err != nil { b.Errorf("Send(): %v", err) } }