diff --git a/go/sqltypes/result.go b/go/sqltypes/result.go index 4fd8f29d57a..b2818f4fb13 100644 --- a/go/sqltypes/result.go +++ b/go/sqltypes/result.go @@ -93,7 +93,7 @@ func (result *Result) Copy() *Result { out := &Result{ RowsAffected: result.RowsAffected, InsertID: result.InsertID, - InsertIDChanged: result.InsertIDUpdated(), + InsertIDChanged: result.InsertIDChanged, SessionStateChanges: result.SessionStateChanges, StatusFlags: result.StatusFlags, Info: result.Info, @@ -132,7 +132,7 @@ func (result *Result) Metadata() *Result { return &Result{ Fields: result.Fields, InsertID: result.InsertID, - InsertIDChanged: result.InsertIDUpdated(), + InsertIDChanged: result.InsertIDChanged, RowsAffected: result.RowsAffected, Info: result.Info, SessionStateChanges: result.SessionStateChanges, @@ -157,7 +157,7 @@ func (result *Result) Truncate(l int) *Result { out := &Result{ InsertID: result.InsertID, - InsertIDChanged: result.InsertIDUpdated(), + InsertIDChanged: result.InsertIDChanged, RowsAffected: result.RowsAffected, Info: result.Info, SessionStateChanges: result.SessionStateChanges, @@ -333,8 +333,8 @@ func (result *Result) AppendResult(src *Result) { result.RowsAffected += src.RowsAffected if src.InsertIDUpdated() { result.InsertID = src.InsertID + result.InsertIDChanged = true } - result.InsertIDChanged = result.InsertIDUpdated() || src.InsertIDUpdated() if result.Fields == nil { result.Fields = src.Fields } diff --git a/go/sqltypes/result_test.go b/go/sqltypes/result_test.go index d49e184f109..c358497baf3 100644 --- a/go/sqltypes/result_test.go +++ b/go/sqltypes/result_test.go @@ -21,11 +21,17 @@ import ( "github.com/stretchr/testify/assert" - "vitess.io/vitess/go/test/utils" - querypb "vitess.io/vitess/go/vt/proto/query" ) +func assertEqualResults(t *testing.T, a, b *Result) { + t.Helper() + assert.Truef(t, a.Equal(b), "Results are not equal: \n%v\n%v", a, b) + if !a.Equal(b) { + t.Errorf("Results are not equal: %v %v", a, b) + } +} + func TestRepair(t *testing.T) { fields := []*querypb.Field{{ Type: Int64, @@ -45,9 +51,7 @@ func TestRepair(t *testing.T) { }, } in.Repair(fields) - if !in.Equal(want) { - t.Errorf("Repair:\n%#v, want\n%#v", in, want) - } + assertEqualResults(t, in, want) } func TestCopy(t *testing.T) { @@ -67,7 +71,7 @@ func TestCopy(t *testing.T) { }, } out := in.Copy() - utils.MustMatch(t, in, out) + assertEqualResults(t, in, out) } func TestTruncate(t *testing.T) { @@ -88,9 +92,7 @@ func TestTruncate(t *testing.T) { } out := in.Truncate(0) - if !out.Equal(in) { - t.Errorf("Truncate(0):\n%v, want\n%v", out, in) - } + assertEqualResults(t, in, out) out = in.Truncate(1) want := &Result{ @@ -106,9 +108,7 @@ func TestTruncate(t *testing.T) { {TestValue(Int64, "3")}, }, } - if !out.Equal(want) { - t.Errorf("Truncate(1):\n%v, want\n%v", out, want) - } + assertEqualResults(t, out, want) } func TestStripMetaData(t *testing.T) { @@ -286,9 +286,7 @@ func TestStripMetaData(t *testing.T) { t.Run(tcase.name, func(t *testing.T) { inCopy := tcase.in.Copy() out := inCopy.StripMetadata(tcase.includedFields) - if !out.Equal(tcase.expected) { - t.Errorf("StripMetaData unexpected result for %v: %v", tcase.name, out) - } + assertEqualResults(t, out, tcase.expected) if len(tcase.in.Fields) > 0 { // check the out array is different than the in array. if out.Fields[0] == inCopy.Fields[0] && tcase.includedFields != querypb.ExecuteOptions_ALL { @@ -296,7 +294,7 @@ func TestStripMetaData(t *testing.T) { } } // check we didn't change the original result. - utils.MustMatch(t, tcase.in, inCopy) + assertEqualResults(t, tcase.in, inCopy) }) } } @@ -348,10 +346,7 @@ func TestAppendResult(t *testing.T) { } result.AppendResult(src) - - if !result.Equal(want) { - t.Errorf("Got:\n%#v, want:\n%#v", result, want) - } + assertEqualResults(t, result, want) } func TestReplaceKeyspace(t *testing.T) {