Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release-21.0] fix issue with json unmarshalling of operators with space in them #16933

Merged
merged 1 commit into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1562,36 +1562,36 @@ func (op ComparisonExprOperator) ToString() string {
}
}

func ComparisonExprOperatorFromJson(s string) ComparisonExprOperator {
func ComparisonExprOperatorFromJson(s string) (ComparisonExprOperator, error) {
switch s {
case EqualStr:
return EqualOp
return EqualOp, nil
case JsonLessThanStr:
return LessThanOp
return LessThanOp, nil
case JsonGreaterThanStr:
return GreaterThanOp
return GreaterThanOp, nil
case JsonLessThanOrEqualStr:
return LessEqualOp
return LessEqualOp, nil
case JsonGreaterThanOrEqualStr:
return GreaterEqualOp
return GreaterEqualOp, nil
case NotEqualStr:
return NotEqualOp
return NotEqualOp, nil
case NullSafeEqualStr:
return NullSafeEqualOp
return NullSafeEqualOp, nil
case InStr:
return InOp
return InOp, nil
case NotInStr:
return NotInOp
return NotInOp, nil
case LikeStr:
return LikeOp
return LikeOp, nil
case NotLikeStr:
return NotLikeOp
return NotLikeOp, nil
case RegexpStr:
return RegexpOp
return RegexpOp, nil
case NotRegexpStr:
return NotRegexpOp
return NotRegexpOp, nil
default:
return 0
return 0, fmt.Errorf("unknown ComparisonExpOperator: %s", s)
}
}

Expand Down
177 changes: 40 additions & 137 deletions go/vt/vtgate/executor_vexplain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ package vtgate

import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -115,153 +118,53 @@ func TestSimpleVexplainTrace(t *testing.T) {
}

func TestVExplainKeys(t *testing.T) {
tests := []struct {
query string
expectedRowString string
}{
{
query: "select count(*), col2 from music group by col2",
expectedRowString: `{
"statementType": "SELECT",
"groupingColumns": [
"music.col2"
],
"selectColumns": [
"music.col2"
]
}`,
}, {
query: "select * from user u join user_extra ue on u.id = ue.user_id where u.col1 > 100 and ue.noLimit = 'foo'",
expectedRowString: `{
"statementType": "SELECT",
"joinColumns": [
"user.id =",
"user_extra.user_id ="
],
"filterColumns": [
"user.col1 gt",
"user_extra.noLimit ="
]
}`,
}, {
// same as above, but written differently
query: "select * from user_extra ue, user u where ue.noLimit = 'foo' and u.col1 > 100 and ue.user_id = u.id",
expectedRowString: `{
"statementType": "SELECT",
"joinColumns": [
"user.id =",
"user_extra.user_id ="
],
"filterColumns": [
"user.col1 gt",
"user_extra.noLimit ="
]
}`,
},
{
query: "select u.foo, ue.bar, count(*) from user u join user_extra ue on u.id = ue.user_id where u.name = 'John Doe' group by 1, 2",
expectedRowString: `{
"statementType": "SELECT",
"groupingColumns": [
"user.foo",
"user_extra.bar"
],
"joinColumns": [
"user.id =",
"user_extra.user_id ="
],
"filterColumns": [
"user.name ="
],
"selectColumns": [
"user.foo",
"user_extra.bar"
]
}`,
},
{
query: "select * from (select * from user) as derived where derived.amount > 1000",
expectedRowString: `{
"statementType": "SELECT"
}`,
},
{
query: "select name, sum(amount) from user group by name",
expectedRowString: `{
"statementType": "SELECT",
"groupingColumns": [
"user.name"
],
"selectColumns": [
"user.amount",
"user.name"
]
}`,
},
{
query: "select name from user where age > 30",
expectedRowString: `{
"statementType": "SELECT",
"filterColumns": [
"user.age gt"
],
"selectColumns": [
"user.name"
]
}`,
},
{
query: "select * from user where name = 'apa' union select * from user_extra where name = 'monkey'",
expectedRowString: `{
"statementType": "SELECT",
"filterColumns": [
"user.name =",
"user_extra.name ="
]
}`,
},
{
query: "update user set name = 'Jane Doe' where id = 1",
expectedRowString: `{
"statementType": "UPDATE",
"filterColumns": [
"user.id ="
]
}`,
},
{
query: "delete from user where order_date < '2023-01-01'",
expectedRowString: `{
"statementType": "DELETE",
"filterColumns": [
"user.order_date lt"
]
}`,
},
{
query: "select * from user where name between 'A' and 'C'",
expectedRowString: `{
"statementType": "SELECT",
"filterColumns": [
"user.name ge",
"user.name le"
]
}`,
},
type testCase struct {
Query string `json:"query"`
Expected json.RawMessage `json:"expected"`
}

var tests []testCase
data, err := os.ReadFile("testdata/executor_vexplain.json")
require.NoError(t, err)

err = json.Unmarshal(data, &tests)
require.NoError(t, err)

var updatedTests []testCase

for _, tt := range tests {
t.Run(tt.query, func(t *testing.T) {
t.Run(tt.Query, func(t *testing.T) {
executor, _, _, _, _ := createExecutorEnv(t)
session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary"})
gotResult, err := executor.Execute(context.Background(), nil, "Execute", session, "vexplain keys "+tt.query, nil)
gotResult, err := executor.Execute(context.Background(), nil, "Execute", session, "vexplain keys "+tt.Query, nil)
require.NoError(t, err)

gotRowString := gotResult.Rows[0][0].ToString()
assert.Equal(t, tt.expectedRowString, gotRowString)
assert.JSONEq(t, string(tt.Expected), gotRowString)

updatedTests = append(updatedTests, testCase{
Query: tt.Query,
Expected: json.RawMessage(gotRowString),
})

if t.Failed() {
fmt.Println(gotRowString)
fmt.Println("Test failed for query:", tt.Query)
fmt.Println("Got result:", gotRowString)
}
})
}

// If anything failed, write the updated test cases to a temp file
if t.Failed() {
tempFilePath := filepath.Join(os.TempDir(), "updated_vexplain_keys_tests.json")
fmt.Println("Writing updated tests to:", tempFilePath)

updatedTestsData, err := json.MarshalIndent(updatedTests, "", "\t")
require.NoError(t, err)

err = os.WriteFile(tempFilePath, updatedTestsData, 0644)
require.NoError(t, err)

fmt.Println("Updated tests written to:", tempFilePath)
}
}
39 changes: 33 additions & 6 deletions go/vt/vtgate/planbuilder/operators/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,37 @@ func (cu *ColumnUse) UnmarshalJSON(data []byte) error {
if err := json.Unmarshal(data, &s); err != nil {
return err
}
parts := strings.Fields(s)
if len(parts) != 2 {
spaceIdx := strings.LastIndex(s, " ")
if spaceIdx == -1 {
return fmt.Errorf("invalid ColumnUse format: %s", s)
}
if err := cu.Column.UnmarshalJSON([]byte(`"` + parts[0] + `"`)); err != nil {
return err

for i := spaceIdx - 1; i >= 0; i-- {
// table.column not like
// table.`tricky not` like
if s[i] == '`' || s[i] == '.' {
break
}
if s[i] == ' ' {
spaceIdx = i
break
}
if i == 0 {
return fmt.Errorf("invalid ColumnUse format: %s", s)
}
}

colStr, opStr := s[:spaceIdx], s[spaceIdx+1:]

err := cu.Column.UnmarshalJSON([]byte(`"` + colStr + `"`))
if err != nil {
return fmt.Errorf("failed to unmarshal column: %w", err)
}

cu.Uses, err = sqlparser.ComparisonExprOperatorFromJson(strings.ToLower(opStr))
if err != nil {
return fmt.Errorf("failed to unmarshal operator: %w", err)
}
cu.Uses = sqlparser.ComparisonExprOperatorFromJson(strings.ToLower(parts[1]))
return nil
}

Expand Down Expand Up @@ -209,5 +232,9 @@ func createColumn(ctx *plancontext.PlanningContext, col *sqlparser.ColName) *Col
if table == nil {
return nil
}
return &Column{Table: table.Name.String(), Name: col.Name.String()}
return &Column{
// we want the escaped versions of the names
Table: sqlparser.String(table.Name),
Name: sqlparser.String(col.Name),
}
}
7 changes: 4 additions & 3 deletions go/vt/vtgate/planbuilder/operators/keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,21 @@ func TestMarshalUnmarshal(t *testing.T) {
StatementType: "SELECT",
TableName: []string{"users", "orders"},
GroupingColumns: []Column{
{Table: "", Name: "category"},
{Table: "orders", Name: "category"},
{Table: "users", Name: "department"},
},
JoinColumns: []ColumnUse{
{Column: Column{Table: "users", Name: "id"}, Uses: sqlparser.EqualOp},
{Column: Column{Table: "orders", Name: "user_id"}, Uses: sqlparser.EqualOp},
},
FilterColumns: []ColumnUse{
{Column: Column{Table: "", Name: "age"}, Uses: sqlparser.GreaterThanOp},
{Column: Column{Table: "users", Name: "age"}, Uses: sqlparser.GreaterThanOp},
{Column: Column{Table: "orders", Name: "total"}, Uses: sqlparser.LessThanOp},
{Column: Column{Table: "orders", Name: "`tricky name not`"}, Uses: sqlparser.InOp},
},
SelectColumns: []Column{
{Table: "users", Name: "name"},
{Table: "", Name: "email"},
{Table: "users", Name: "email"},
{Table: "orders", Name: "amount"},
},
}
Expand Down
Loading
Loading