diff --git a/go/sqltypes/named_result_test.go b/go/sqltypes/named_result_test.go
index 8c9c32554da..ae42d4257dd 100644
--- a/go/sqltypes/named_result_test.go
+++ b/go/sqltypes/named_result_test.go
@@ -20,12 +20,15 @@ import (
"fmt"
"testing"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
querypb "vitess.io/vitess/go/vt/proto/query"
)
func TestToNamedResult(t *testing.T) {
+ require.Nil(t, ToNamedResult(nil))
+
in := &Result{
Fields: []*querypb.Field{{
Name: "id",
@@ -57,3 +60,116 @@ func TestToNamedResult(t *testing.T) {
require.Equal(t, uint64(i), named.Rows[i].AsUint64("uid", 0))
}
}
+
+func TestToNumericTypes(t *testing.T) {
+ row := RowNamedValues{
+ "test": Value{
+ val: []byte("0x1234"),
+ },
+ }
+ tests := []struct {
+ name string
+ fieldName string
+ expectedErr string
+ }{
+ {
+ name: "random fieldName",
+ fieldName: "random",
+ expectedErr: "No such field in RowNamedValues",
+ },
+ {
+ name: "right fieldName",
+ fieldName: "test",
+ expectedErr: "Cannot convert value to desired type",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := row.ToInt(tt.fieldName)
+ if tt.expectedErr != "" {
+ require.ErrorContains(t, err, tt.expectedErr)
+ } else {
+ require.NoError(t, err)
+ }
+
+ _, err = row.ToInt32(tt.fieldName)
+ if tt.expectedErr != "" {
+ require.ErrorContains(t, err, tt.expectedErr)
+ } else {
+ require.NoError(t, err)
+ }
+
+ _, err = row.ToInt64(tt.fieldName)
+ if tt.expectedErr != "" {
+ require.ErrorContains(t, err, tt.expectedErr)
+ } else {
+ require.NoError(t, err)
+ }
+
+ _, err = row.ToUint64(tt.fieldName)
+ if tt.expectedErr != "" {
+ require.ErrorContains(t, err, tt.expectedErr)
+ } else {
+ require.NoError(t, err)
+ }
+
+ _, err = row.ToFloat64(tt.fieldName)
+ if tt.expectedErr != "" {
+ require.ErrorContains(t, err, tt.expectedErr)
+ } else {
+ require.NoError(t, err)
+ }
+
+ _, err = row.ToBool(tt.fieldName)
+ if tt.expectedErr != "" {
+ require.ErrorContains(t, err, tt.expectedErr)
+ } else {
+ require.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestToBytes(t *testing.T) {
+ row := RowNamedValues{
+ "test": Value{
+ val: []byte("0x1234"),
+ },
+ }
+
+ _, err := row.ToBytes("random")
+ require.ErrorContains(t, err, "No such field in RowNamedValues")
+
+ val, err := row.ToBytes("test")
+ require.NoError(t, err)
+ require.Equal(t, []byte{0x30, 0x78, 0x31, 0x32, 0x33, 0x34}, val)
+}
+
+func TestRow(t *testing.T) {
+ row := RowNamedValues{}
+ tests := []struct {
+ name string
+ res *NamedResult
+ expectedRow RowNamedValues
+ }{
+ {
+ name: "empty results",
+ res: &NamedResult{},
+ expectedRow: nil,
+ },
+ {
+ name: "non-empty results",
+ res: &NamedResult{
+ Rows: []RowNamedValues{row},
+ },
+ expectedRow: row,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.expectedRow, tt.res.Row())
+ })
+ }
+}
diff --git a/go/sqltypes/parse_rows_test.go b/go/sqltypes/parse_rows_test.go
index a32f2fd35b0..45c55da019b 100644
--- a/go/sqltypes/parse_rows_test.go
+++ b/go/sqltypes/parse_rows_test.go
@@ -168,20 +168,123 @@ func TestRowParsing(t *testing.T) {
}
func TestRowsEquals(t *testing.T) {
- var cases = []struct {
+ tests := []struct {
+ name string
left, right string
+ expectedErr string
}{
- {"[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]", "[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]"},
+ {
+ name: "Both equal",
+ left: "[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]",
+ right: "[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]",
+ },
+ {
+ name: "length mismatch",
+ left: "[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]",
+ right: "[[INT64(2)] [INT64(2)] [INT64(1)]]",
+ expectedErr: "results differ: expected 4 rows in result, got 3\n\twant: [[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]\n\tgot: [[INT64(2)] [INT64(2)] [INT64(1)]]",
+ },
+ {
+ name: "elements mismatch",
+ left: "[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]",
+ right: "[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(4)]]",
+ expectedErr: "results differ: row [INT64(1)] is missing from result\n\twant: [[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]\n\tgot: [[INT64(1)] [INT64(2)] [INT64(2)] [INT64(4)]]",
+ },
}
- for _, tc := range cases {
- left, err := ParseRows(tc.left)
- require.NoError(t, err)
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ left, err := ParseRows(tt.left)
+ require.NoError(t, err)
- right, err := ParseRows(tc.right)
- require.NoError(t, err)
+ right, err := ParseRows(tt.right)
+ require.NoError(t, err)
- err = RowsEquals(left, right)
- require.NoError(t, err)
+ err = RowsEquals(left, right)
+ if tt.expectedErr == "" {
+ require.NoError(t, err)
+ } else {
+ require.ErrorContains(t, err, tt.expectedErr)
+ }
+ })
+ }
+}
+
+func TestRowsEqualStr(t *testing.T) {
+ tests := []struct {
+ name string
+ want string
+ got []Row
+ expectedErr string
+ }{
+ {
+ name: "Unknown type",
+ want: "[[RANDOM(1)]]",
+ got: []Row{
+ {
+ NewInt64(1),
+ },
+ },
+ expectedErr: "malformed row assertion: unknown SQL type \"RANDOM\" at :1:3",
+ },
+ {
+ name: "Invalid row",
+ want: "[[INT64(1]]",
+ got: []Row{
+ {
+ NewInt64(1),
+ },
+ },
+ expectedErr: "malformed row assertion: unexpected token ']' at :1:10",
+ },
+ {
+ name: "Both equal",
+ want: "[[INT64(1)]]",
+ got: []Row{
+ {
+ NewInt64(1),
+ },
+ },
+ },
+ {
+ name: "length mismatch",
+ want: "[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]",
+ got: []Row{
+ {
+ NewInt64(1),
+ },
+ },
+ expectedErr: "results differ: expected 4 rows in result, got 1\n\twant: [[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]\n\tgot: [[INT64(1)]]",
+ },
+ {
+ name: "elements mismatch",
+ want: "[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]",
+ got: []Row{
+ {
+ NewInt64(1),
+ },
+ {
+ NewInt64(1),
+ },
+ {
+ NewInt64(1),
+ },
+ {
+ NewInt64(1),
+ },
+ },
+ expectedErr: "results differ: row [INT64(2)] is missing from result\n\twant: [[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]\n\tgot: [[INT64(1)] [INT64(1)] [INT64(1)] [INT64(1)]]",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := RowsEqualsStr(tt.want, tt.got)
+ if tt.expectedErr == "" {
+ require.NoError(t, err)
+ } else {
+ require.ErrorContains(t, err, tt.expectedErr)
+ }
+ })
}
}
diff --git a/go/sqltypes/query_response_test.go b/go/sqltypes/query_response_test.go
new file mode 100644
index 00000000000..30b6fe62e14
--- /dev/null
+++ b/go/sqltypes/query_response_test.go
@@ -0,0 +1,105 @@
+/*
+Copyright 2024 The Vitess Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package sqltypes
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestQueryResponsesEqual(t *testing.T) {
+ tests := []struct {
+ name string
+ r1 []QueryResponse
+ r2 []QueryResponse
+ isEqual bool
+ }{
+ {
+ name: "1 response in each",
+ r1: []QueryResponse{
+ {
+ QueryResult: &Result{},
+ QueryError: nil,
+ },
+ },
+ r2: []QueryResponse{
+ {
+ QueryResult: &Result{},
+ QueryError: nil,
+ },
+ },
+ isEqual: true,
+ },
+ {
+ name: "different lengths",
+ r1: []QueryResponse{
+ {
+ QueryResult: &Result{},
+ QueryError: nil,
+ },
+ },
+ r2: []QueryResponse{},
+ isEqual: false,
+ },
+ {
+ name: "different query errors",
+ r1: []QueryResponse{
+ {
+ QueryResult: &Result{},
+ QueryError: fmt.Errorf("some error"),
+ },
+ },
+ r2: []QueryResponse{
+ {
+ QueryResult: &Result{
+ Info: "Test",
+ },
+ QueryError: nil,
+ },
+ },
+ isEqual: false,
+ },
+ {
+ name: "different query results",
+ r1: []QueryResponse{
+ {
+ QueryResult: &Result{
+ RowsAffected: 7,
+ },
+ QueryError: nil,
+ },
+ },
+ r2: []QueryResponse{
+ {
+ QueryResult: &Result{
+ RowsAffected: 10,
+ },
+ QueryError: nil,
+ },
+ },
+ isEqual: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.isEqual, QueryResponsesEqual(tt.r1, tt.r2))
+ })
+ }
+}
diff --git a/go/sqltypes/type_test.go b/go/sqltypes/type_test.go
index edf340b2abb..8493dc23e05 100644
--- a/go/sqltypes/type_test.go
+++ b/go/sqltypes/type_test.go
@@ -20,6 +20,8 @@ import (
"strings"
"testing"
+ "github.com/stretchr/testify/assert"
+
querypb "vitess.io/vitess/go/vt/proto/query"
)
@@ -512,3 +514,88 @@ func TestPrintTypeChecks(t *testing.T) {
t.Logf("%s(): %s", f.name, strings.Join(match, ", "))
}
}
+
+func TestIsTextOrBinary(t *testing.T) {
+ tests := []struct {
+ name string
+ ty querypb.Type
+ isTextorBinary bool
+ }{
+ {
+ name: "null type",
+ ty: querypb.Type_NULL_TYPE,
+ isTextorBinary: false,
+ },
+ {
+ name: "blob type",
+ ty: querypb.Type_BLOB,
+ isTextorBinary: true,
+ },
+ {
+ name: "text type",
+ ty: querypb.Type_TEXT,
+ isTextorBinary: true,
+ },
+ {
+ name: "binary type",
+ ty: querypb.Type_BINARY,
+ isTextorBinary: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.isTextorBinary, IsTextOrBinary(tt.ty))
+ })
+ }
+}
+
+func TestIsDateOrTime(t *testing.T) {
+ tests := []struct {
+ name string
+ ty querypb.Type
+ isDateOrTime bool
+ }{
+ {
+ name: "null type",
+ ty: querypb.Type_NULL_TYPE,
+ isDateOrTime: false,
+ },
+ {
+ name: "blob type",
+ ty: querypb.Type_BLOB,
+ isDateOrTime: false,
+ },
+ {
+ name: "timestamp type",
+ ty: querypb.Type_TIMESTAMP,
+ isDateOrTime: true,
+ },
+ {
+ name: "date type",
+ ty: querypb.Type_DATE,
+ isDateOrTime: true,
+ },
+ {
+ name: "time type",
+ ty: querypb.Type_TIME,
+ isDateOrTime: true,
+ },
+ {
+ name: "date time type",
+ ty: querypb.Type_DATETIME,
+ isDateOrTime: true,
+ },
+ {
+ name: "year type",
+ ty: querypb.Type_YEAR,
+ isDateOrTime: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.isDateOrTime, IsDateOrTime(tt.ty))
+ })
+ }
+}