diff --git a/go/vt/vtgate/vindexes/consistent_lookup.go b/go/vt/vtgate/vindexes/consistent_lookup.go index 3c2166c0aaf..1f0f1d54af5 100644 --- a/go/vt/vtgate/vindexes/consistent_lookup.go +++ b/go/vt/vtgate/vindexes/consistent_lookup.go @@ -360,8 +360,7 @@ func (lu *clCommon) handleDup(ctx context.Context, vcursor VCursor, values []sql return err } // Lock the target row using normal transaction priority. - // TODO: context needs to be passed on. - qr, err = vcursor.ExecuteKeyspaceID(context.Background(), lu.keyspace, existingksid, lu.lockOwnerQuery, bindVars, false /* rollbackOnError */, false /* autocommit */) + qr, err = vcursor.ExecuteKeyspaceID(ctx, lu.keyspace, existingksid, lu.lockOwnerQuery, bindVars, false /* rollbackOnError */, false /* autocommit */) if err != nil { return err } diff --git a/go/vt/vtgate/vindexes/consistent_lookup_test.go b/go/vt/vtgate/vindexes/consistent_lookup_test.go index 297732325ea..bf209d35fd4 100644 --- a/go/vt/vtgate/vindexes/consistent_lookup_test.go +++ b/go/vt/vtgate/vindexes/consistent_lookup_test.go @@ -75,8 +75,9 @@ func TestConsistentLookupMap(t *testing.T) { lookup := createConsistentLookup(t, "consistent_lookup", false) vc := &loggingVCursor{} vc.AddResult(makeTestResultLookup([]int{2, 2}), nil) + ctx := newTestContext() - got, err := lookup.Map(context.Background(), vc, []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(2)}) + got, err := lookup.Map(ctx, vc, []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(2)}) require.NoError(t, err) want := []key.Destination{ key.DestinationKeyspaceIDs([][]byte{ @@ -94,10 +95,11 @@ func TestConsistentLookupMap(t *testing.T) { vc.verifyLog(t, []string{ "ExecutePre select fromc1, toc from t where fromc1 in ::fromc1 [{fromc1 }] false", }) + vc.verifyContext(t, ctx) // Test query fail. vc.AddResult(nil, fmt.Errorf("execute failed")) - _, err = lookup.Map(context.Background(), vc, []sqltypes.Value{sqltypes.NewInt64(1)}) + _, err = lookup.Map(ctx, vc, []sqltypes.Value{sqltypes.NewInt64(1)}) wantErr := "lookup.Map: execute failed" if err == nil || err.Error() != wantErr { t.Errorf("lookup(query fail) err: %v, want %s", err, wantErr) @@ -126,8 +128,9 @@ func TestConsistentLookupUniqueMap(t *testing.T) { lookup := createConsistentLookup(t, "consistent_lookup_unique", false) vc := &loggingVCursor{} vc.AddResult(makeTestResultLookup([]int{0, 1}), nil) + ctx := newTestContext() - got, err := lookup.Map(context.Background(), vc, []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(2)}) + got, err := lookup.Map(ctx, vc, []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(2)}) require.NoError(t, err) want := []key.Destination{ key.DestinationNone{}, @@ -139,10 +142,11 @@ func TestConsistentLookupUniqueMap(t *testing.T) { vc.verifyLog(t, []string{ "ExecutePre select fromc1, toc from t where fromc1 in ::fromc1 [{fromc1 }] false", }) + vc.verifyContext(t, ctx) // More than one result is invalid vc.AddResult(makeTestResultLookup([]int{2}), nil) - _, err = lookup.Map(context.Background(), vc, []sqltypes.Value{sqltypes.NewInt64(1)}) + _, err = lookup.Map(ctx, vc, []sqltypes.Value{sqltypes.NewInt64(1)}) wanterr := "Lookup.Map: unexpected multiple results from vindex t: INT64(1)" if err == nil || err.Error() != wanterr { t.Errorf("lookup(query fail) err: %v, want %s", err, wanterr) @@ -171,8 +175,9 @@ func TestConsistentLookupMapAbsent(t *testing.T) { lookup := createConsistentLookup(t, "consistent_lookup", false) vc := &loggingVCursor{} vc.AddResult(makeTestResultLookup([]int{0, 0}), nil) + ctx := newTestContext() - got, err := lookup.Map(context.Background(), vc, []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(2)}) + got, err := lookup.Map(ctx, vc, []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(2)}) require.NoError(t, err) want := []key.Destination{ key.DestinationNone{}, @@ -184,6 +189,7 @@ func TestConsistentLookupMapAbsent(t *testing.T) { vc.verifyLog(t, []string{ "ExecutePre select fromc1, toc from t where fromc1 in ::fromc1 [{fromc1 }] false", }) + vc.verifyContext(t, ctx) } func TestConsistentLookupVerify(t *testing.T) { @@ -191,17 +197,19 @@ func TestConsistentLookupVerify(t *testing.T) { vc := &loggingVCursor{} vc.AddResult(makeTestResult(1), nil) vc.AddResult(makeTestResult(1), nil) + ctx := newTestContext() - _, err := lookup.Verify(context.Background(), vc, []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(2)}, [][]byte{[]byte("test1"), []byte("test2")}) + _, err := lookup.Verify(ctx, vc, []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(2)}, [][]byte{[]byte("test1"), []byte("test2")}) require.NoError(t, err) vc.verifyLog(t, []string{ "ExecutePre select fromc1 from t where fromc1 = :fromc1 and toc = :toc [{fromc1 1} {toc test1}] false", "ExecutePre select fromc1 from t where fromc1 = :fromc1 and toc = :toc [{fromc1 2} {toc test2}] false", }) + vc.verifyContext(t, ctx) // Test query fail. vc.AddResult(nil, fmt.Errorf("execute failed")) - _, err = lookup.Verify(context.Background(), vc, []sqltypes.Value{sqltypes.NewInt64(1)}, [][]byte{[]byte("\x16k@\xb4J\xbaK\xd6")}) + _, err = lookup.Verify(ctx, vc, []sqltypes.Value{sqltypes.NewInt64(1)}, [][]byte{[]byte("\x16k@\xb4J\xbaK\xd6")}) want := "lookup.Verify: execute failed" if err == nil || err.Error() != want { t.Errorf("lookup(query fail) err: %v, want %s", err, want) @@ -209,7 +217,7 @@ func TestConsistentLookupVerify(t *testing.T) { // Test write_only. lookup = createConsistentLookup(t, "consistent_lookup", true) - got, err := lookup.Verify(context.Background(), nil, []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(2)}, [][]byte{[]byte(""), []byte("")}) + got, err := lookup.Verify(ctx, nil, []sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewInt64(2)}, [][]byte{[]byte(""), []byte("")}) require.NoError(t, err) wantBools := []bool{true, true} if !reflect.DeepEqual(got, wantBools) { @@ -221,8 +229,9 @@ func TestConsistentLookupCreateSimple(t *testing.T) { lookup := createConsistentLookup(t, "consistent_lookup", false) vc := &loggingVCursor{} vc.AddResult(&sqltypes.Result{}, nil) + ctx := newTestContext() - if err := lookup.(Lookup).Create(context.Background(), vc, [][]sqltypes.Value{{ + if err := lookup.(Lookup).Create(ctx, vc, [][]sqltypes.Value{{ sqltypes.NewInt64(1), sqltypes.NewInt64(2), }, { @@ -234,6 +243,7 @@ func TestConsistentLookupCreateSimple(t *testing.T) { vc.verifyLog(t, []string{ "ExecutePre insert into t(fromc1, fromc2, toc) values(:fromc1_0, :fromc2_0, :toc_0), (:fromc1_1, :fromc2_1, :toc_1) [{fromc1_0 1} {fromc1_1 3} {fromc2_0 2} {fromc2_1 4} {toc_0 test1} {toc_1 test2}] true", }) + vc.verifyContext(t, ctx) } func TestConsistentLookupCreateThenRecreate(t *testing.T) { @@ -242,8 +252,9 @@ func TestConsistentLookupCreateThenRecreate(t *testing.T) { vc.AddResult(nil, mysql.NewSQLError(mysql.ERDupEntry, mysql.SSConstraintViolation, "Duplicate entry")) vc.AddResult(&sqltypes.Result{}, nil) vc.AddResult(&sqltypes.Result{}, nil) + ctx := newTestContext() - if err := lookup.(Lookup).Create(context.Background(), vc, [][]sqltypes.Value{{ + if err := lookup.(Lookup).Create(ctx, vc, [][]sqltypes.Value{{ sqltypes.NewInt64(1), sqltypes.NewInt64(2), }}, [][]byte{[]byte("test1")}, false); err != nil { @@ -254,6 +265,7 @@ func TestConsistentLookupCreateThenRecreate(t *testing.T) { "ExecutePre select toc from t where fromc1 = :fromc1 and fromc2 = :fromc2 for update [{fromc1 1} {fromc2 2} {toc test1}] false", "ExecutePre insert into t(fromc1, fromc2, toc) values(:fromc1, :fromc2, :toc) [{fromc1 1} {fromc2 2} {toc test1}] true", }) + vc.verifyContext(t, ctx) } func TestConsistentLookupCreateThenUpdate(t *testing.T) { @@ -263,8 +275,9 @@ func TestConsistentLookupCreateThenUpdate(t *testing.T) { vc.AddResult(makeTestResult(1), nil) vc.AddResult(&sqltypes.Result{}, nil) vc.AddResult(&sqltypes.Result{}, nil) + ctx := newTestContext() - if err := lookup.(Lookup).Create(context.Background(), vc, [][]sqltypes.Value{{ + if err := lookup.(Lookup).Create(ctx, vc, [][]sqltypes.Value{{ sqltypes.NewInt64(1), sqltypes.NewInt64(2), }}, [][]byte{[]byte("test1")}, false); err != nil { @@ -276,6 +289,7 @@ func TestConsistentLookupCreateThenUpdate(t *testing.T) { "ExecuteKeyspaceID select fc1 from `dot.t1` where fc1 = :fromc1 and fc2 = :fromc2 lock in share mode [{fromc1 1} {fromc2 2} {toc test1}] false", "ExecutePre update t set toc=:toc where fromc1 = :fromc1 and fromc2 = :fromc2 [{fromc1 1} {fromc2 2} {toc test1}] true", }) + vc.verifyContext(t, ctx) } func TestConsistentLookupCreateThenSkipUpdate(t *testing.T) { @@ -285,8 +299,9 @@ func TestConsistentLookupCreateThenSkipUpdate(t *testing.T) { vc.AddResult(makeTestResult(1), nil) vc.AddResult(&sqltypes.Result{}, nil) vc.AddResult(&sqltypes.Result{}, nil) + ctx := newTestContext() - if err := lookup.(Lookup).Create(context.Background(), vc, [][]sqltypes.Value{{ + if err := lookup.(Lookup).Create(ctx, vc, [][]sqltypes.Value{{ sqltypes.NewInt64(1), sqltypes.NewInt64(2), }}, [][]byte{[]byte("1")}, false); err != nil { @@ -297,6 +312,7 @@ func TestConsistentLookupCreateThenSkipUpdate(t *testing.T) { "ExecutePre select toc from t where fromc1 = :fromc1 and fromc2 = :fromc2 for update [{fromc1 1} {fromc2 2} {toc 1}] false", "ExecuteKeyspaceID select fc1 from `dot.t1` where fc1 = :fromc1 and fc2 = :fromc2 lock in share mode [{fromc1 1} {fromc2 2} {toc 1}] false", }) + vc.verifyContext(t, ctx) } func TestConsistentLookupCreateThenDupkey(t *testing.T) { @@ -306,8 +322,9 @@ func TestConsistentLookupCreateThenDupkey(t *testing.T) { vc.AddResult(makeTestResult(1), nil) vc.AddResult(makeTestResult(1), nil) vc.AddResult(&sqltypes.Result{}, nil) + ctx := newTestContext() - err := lookup.(Lookup).Create(context.Background(), vc, [][]sqltypes.Value{{ + err := lookup.(Lookup).Create(ctx, vc, [][]sqltypes.Value{{ sqltypes.NewInt64(1), sqltypes.NewInt64(2), }}, [][]byte{[]byte("test1")}, false) @@ -318,14 +335,16 @@ func TestConsistentLookupCreateThenDupkey(t *testing.T) { "ExecutePre select toc from t where fromc1 = :fromc1 and fromc2 = :fromc2 for update [{fromc1 1} {fromc2 2} {toc test1}] false", "ExecuteKeyspaceID select fc1 from `dot.t1` where fc1 = :fromc1 and fc2 = :fromc2 lock in share mode [{fromc1 1} {fromc2 2} {toc test1}] false", }) + vc.verifyContext(t, ctx) } func TestConsistentLookupCreateNonDupError(t *testing.T) { lookup := createConsistentLookup(t, "consistent_lookup", false) vc := &loggingVCursor{} vc.AddResult(nil, errors.New("general error")) + ctx := newTestContext() - err := lookup.(Lookup).Create(context.Background(), vc, [][]sqltypes.Value{{ + err := lookup.(Lookup).Create(ctx, vc, [][]sqltypes.Value{{ sqltypes.NewInt64(1), sqltypes.NewInt64(2), }}, [][]byte{[]byte("test1")}, false) @@ -336,6 +355,7 @@ func TestConsistentLookupCreateNonDupError(t *testing.T) { vc.verifyLog(t, []string{ "ExecutePre insert into t(fromc1, fromc2, toc) values(:fromc1_0, :fromc2_0, :toc_0) [{fromc1_0 1} {fromc2_0 2} {toc_0 test1}] true", }) + vc.verifyContext(t, ctx) } func TestConsistentLookupCreateThenBadRows(t *testing.T) { @@ -343,8 +363,9 @@ func TestConsistentLookupCreateThenBadRows(t *testing.T) { vc := &loggingVCursor{} vc.AddResult(nil, vterrors.New(vtrpcpb.Code_ALREADY_EXISTS, "(errno 1062) (sqlstate 23000) Duplicate entry")) vc.AddResult(makeTestResult(2), nil) + ctx := newTestContext() - err := lookup.(Lookup).Create(context.Background(), vc, [][]sqltypes.Value{{ + err := lookup.(Lookup).Create(ctx, vc, [][]sqltypes.Value{{ sqltypes.NewInt64(1), sqltypes.NewInt64(2), }}, [][]byte{[]byte("test1")}, false) @@ -356,14 +377,16 @@ func TestConsistentLookupCreateThenBadRows(t *testing.T) { "ExecutePre insert into t(fromc1, fromc2, toc) values(:fromc1_0, :fromc2_0, :toc_0) [{fromc1_0 1} {fromc2_0 2} {toc_0 test1}] true", "ExecutePre select toc from t where fromc1 = :fromc1 and fromc2 = :fromc2 for update [{fromc1 1} {fromc2 2} {toc test1}] false", }) + vc.verifyContext(t, ctx) } func TestConsistentLookupDelete(t *testing.T) { lookup := createConsistentLookup(t, "consistent_lookup", false) vc := &loggingVCursor{} vc.AddResult(&sqltypes.Result{}, nil) + ctx := newTestContext() - if err := lookup.(Lookup).Delete(context.Background(), vc, [][]sqltypes.Value{{ + if err := lookup.(Lookup).Delete(ctx, vc, [][]sqltypes.Value{{ sqltypes.NewInt64(1), sqltypes.NewInt64(2), }}, []byte("test")); err != nil { @@ -372,6 +395,7 @@ func TestConsistentLookupDelete(t *testing.T) { vc.verifyLog(t, []string{ "ExecutePost delete from t where fromc1 = :fromc1 and fromc2 = :fromc2 and toc = :toc [{fromc1 1} {fromc2 2} {toc test}] true", }) + vc.verifyContext(t, ctx) } func TestConsistentLookupUpdate(t *testing.T) { @@ -379,8 +403,9 @@ func TestConsistentLookupUpdate(t *testing.T) { vc := &loggingVCursor{} vc.AddResult(&sqltypes.Result{}, nil) vc.AddResult(&sqltypes.Result{}, nil) + ctx := newTestContext() - if err := lookup.(Lookup).Update(context.Background(), vc, []sqltypes.Value{ + if err := lookup.(Lookup).Update(ctx, vc, []sqltypes.Value{ sqltypes.NewInt64(1), sqltypes.NewInt64(2), }, []byte("test"), []sqltypes.Value{ @@ -393,6 +418,7 @@ func TestConsistentLookupUpdate(t *testing.T) { "ExecutePost delete from t where fromc1 = :fromc1 and fromc2 = :fromc2 and toc = :toc [{fromc1 1} {fromc2 2} {toc test}] true", "ExecutePre insert into t(fromc1, fromc2, toc) values(:fromc1_0, :fromc2_0, :toc_0) [{fromc1_0 3} {fromc2_0 4} {toc_0 test}] true", }) + vc.verifyContext(t, ctx) } func TestConsistentLookupNoUpdate(t *testing.T) { @@ -469,13 +495,19 @@ func createConsistentLookup(t *testing.T, name string, writeOnly bool) SingleCol return l.(SingleColumn) } +func newTestContext() context.Context { + type testContextKey string // keep static checks from complaining about built-in types as context keys + return context.WithValue(context.Background(), (testContextKey)("test"), "foo") +} + var _ VCursor = (*loggingVCursor)(nil) type loggingVCursor struct { - results []*sqltypes.Result - errors []error - index int - log []string + results []*sqltypes.Result + errors []error + index int + log []string + contexts []context.Context } func (vc *loggingVCursor) LookupRowLockShardSession() vtgatepb.CommitOrder { @@ -508,14 +540,14 @@ func (vc *loggingVCursor) Execute(ctx context.Context, method string, query stri case vtgatepb.CommitOrder_AUTOCOMMIT: name = "ExecuteAutocommit" } - return vc.execute(name, query, bindvars, rollbackOnError) + return vc.execute(ctx, name, query, bindvars, rollbackOnError) } func (vc *loggingVCursor) ExecuteKeyspaceID(ctx context.Context, keyspace string, ksid []byte, query string, bindVars map[string]*querypb.BindVariable, rollbackOnError, autocommit bool) (*sqltypes.Result, error) { - return vc.execute("ExecuteKeyspaceID", query, bindVars, rollbackOnError) + return vc.execute(ctx, "ExecuteKeyspaceID", query, bindVars, rollbackOnError) } -func (vc *loggingVCursor) execute(method string, query string, bindvars map[string]*querypb.BindVariable, rollbackOnError bool) (*sqltypes.Result, error) { +func (vc *loggingVCursor) execute(ctx context.Context, method string, query string, bindvars map[string]*querypb.BindVariable, rollbackOnError bool) (*sqltypes.Result, error) { if vc.index >= len(vc.results) { return nil, fmt.Errorf("ran out of results to return: %s", query) } @@ -525,6 +557,7 @@ func (vc *loggingVCursor) execute(method string, query string, bindvars map[stri } sort.Slice(bvl, func(i, j int) bool { return bvl[i].Name < bvl[j].Name }) vc.log = append(vc.log, fmt.Sprintf("%s %s %v %v", method, query, bvl, rollbackOnError)) + vc.contexts = append(vc.contexts, ctx) idx := vc.index vc.index++ if vc.errors[idx] != nil { @@ -548,6 +581,15 @@ func (vc *loggingVCursor) verifyLog(t *testing.T, want []string) { } } +func (vc *loggingVCursor) verifyContext(t *testing.T, want context.Context) { + t.Helper() + for i, got := range vc.contexts { + if got != want { + t.Errorf("context(%d):\ngot: %v\nwant: %v", i, got, want) + } + } +} + // create lookup result with one to one mapping func makeTestResult(numRows int) *sqltypes.Result { result := &sqltypes.Result{