Skip to content

Commit

Permalink
feature: support nextval and currval (#718)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lvnszn authored Jul 27, 2023
1 parent d085199 commit 17b2133
Show file tree
Hide file tree
Showing 6 changed files with 398 additions and 5 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ require (
github.com/docker/distribution v2.7.1+incompatible // indirect
github.com/docker/docker v20.10.11+incompatible // indirect
github.com/docker/go-connections v0.4.0 // indirect
github.com/dubbogo/tools v1.0.9 // indirect
github.com/dustin/go-humanize v1.0.0 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-errors/errors v1.0.1 // indirect
Expand Down
209 changes: 209 additions & 0 deletions go.sum

Large diffs are not rendered by default.

30 changes: 26 additions & 4 deletions pkg/runtime/optimize/dml/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,13 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err
stmt := o.Stmt.(*ast.SelectStatement)
enableLocalMathComputation := ctx.Value(proto.ContextKeyEnableLocalComputation{}).(bool)
if enableLocalMathComputation && len(stmt.From) == 0 {
isLocalFlag := true
var columnList []string
var valueList []proto.Value
var (
isLocalFlag = true
isSequence = false
columnList []string
valueList []proto.Value
vts []*rule.VTable
)
for i := range stmt.Select {
switch selectItem := stmt.Select[i].(type) {
case *ast.SelectElementExpr:
Expand Down Expand Up @@ -95,9 +99,27 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err
}
valueList = append(valueList, calculateRes)
columnList = append(columnList, stmt.Select[i].DisplayName())

case *ast.SelectElementColumn:
if len(selectItem.Name) == 2 &&
(strings.EqualFold(selectItem.Name[1], "currval") || strings.EqualFold(selectItem.Name[1], "nextval")) {
isSequence = true
vt, ok := o.Rule.VTable(selectItem.Name[0])
if !ok {
return nil, proto.ErrorNotFoundSequence
}
vts = append(vts, vt)
}
}
}
if isSequence {
ret := &dml.LocalSequencePlan{
Stmt: stmt,
VTs: vts,
ColumnList: columnList,
}
ret.BindArgs(o.Args)
return ret, nil
}
if isLocalFlag {

ret := &dml.LocalSelectPlan{
Expand Down
4 changes: 3 additions & 1 deletion pkg/runtime/plan/dml/local_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ func (s *LocalSelectPlan) Type() proto.PlanType {
func (s *LocalSelectPlan) ExecIn(ctx context.Context, _ proto.VConn) (proto.Result, error) {
_, span := plan.Tracer.Start(ctx, "LocalSelectPlan.ExecIn")
defer span.End()
var theadLocalSelect thead.Thead
var (
theadLocalSelect thead.Thead
)

for i, item := range s.ColumnList {
sRes := s.Result[i].String()
Expand Down
99 changes: 99 additions & 0 deletions pkg/runtime/plan/dml/local_sequence.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package dml

import (
"context"
"strings"
)

import (
"github.com/pkg/errors"
)

import (
consts "github.com/arana-db/arana/pkg/constants/mysql"
"github.com/arana-db/arana/pkg/dataset"
"github.com/arana-db/arana/pkg/mysql/rows"
"github.com/arana-db/arana/pkg/mysql/thead"
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/proto/rule"
"github.com/arana-db/arana/pkg/resultx"
"github.com/arana-db/arana/pkg/runtime/ast"
rcontext "github.com/arana-db/arana/pkg/runtime/context"
"github.com/arana-db/arana/pkg/runtime/plan"
)

var _ proto.Plan = (*LocalSequencePlan)(nil)

type LocalSequencePlan struct {
plan.BasePlan
Stmt *ast.SelectStatement
VTs []*rule.VTable
ColumnList []string
}

func (s *LocalSequencePlan) Type() proto.PlanType {
return proto.PlanTypeQuery
}

func (s *LocalSequencePlan) ExecIn(ctx context.Context, _ proto.VConn) (proto.Result, error) {
_, span := plan.Tracer.Start(ctx, "LocalSequencePlan.ExecIn")

defer span.End()
var (
theadLocalSelect thead.Thead
columns []proto.Field
values []proto.Value
)

for idx := 0; s.Stmt.From == nil && idx < len(s.Stmt.Select); idx++ {
if seqColumn, ok := s.Stmt.Select[idx].(*ast.SelectElementColumn); ok && len(seqColumn.Name) == 2 {
seqName, seqFunc := seqColumn.Name[0], seqColumn.Name[1]
colName := seqColumn.Alias()
if colName == "" {
colName = strings.Join(seqColumn.Name, ".")
}
theadLocalSelect = append(theadLocalSelect, thead.Col{Name: colName, FieldType: consts.FieldTypeLong})
seq, err := proto.LoadSequenceManager().GetSequence(ctx, rcontext.Tenant(ctx), rcontext.Schema(ctx), seqName)
if err != nil {
return nil, errors.WithStack(err)
}

switch strings.ToLower(seqFunc) {
case "currval":
values = append(values, proto.NewValueInt64(seq.(proto.EnhancedSequence).CurrentVal()))
case "nextval":
nextSeqVal, err := seq.Acquire(ctx)
if err != nil {
return nil, err
}
values = append(values, proto.NewValueInt64(nextSeqVal))
}
}
}

columns = theadLocalSelect.ToFields()
ds := &dataset.VirtualDataset{
Columns: columns,
}

ds.Rows = append(ds.Rows, rows.NewTextVirtualRow(columns, values))
return resultx.New(resultx.WithDataset(ds)), nil

}
60 changes: 60 additions & 0 deletions test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package test

import (
"context"
"database/sql"
"fmt"
"sort"
Expand All @@ -40,6 +41,8 @@ import (
)

import (
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/runtime"
"github.com/arana-db/arana/pkg/util/rand2"
utils "github.com/arana-db/arana/pkg/util/tableprint"
)
Expand Down Expand Up @@ -1376,3 +1379,60 @@ func (s *IntegrationSuite) TestExplain() {
})
}
}

func (s *IntegrationSuite) TestSequence() {
var (
db = s.DB()
t = s.T()
)

rt, err := runtime.Load("arana", "employees")
if err != nil {
panic(err)
}
ctx := context.WithValue(context.Background(), proto.RuntimeCtxKey{}, rt)
ctx = context.WithValue(ctx, proto.ContextKeyTenant{}, "arana")
ctx = context.WithValue(ctx, proto.ContextKeySchema{}, "employees")
_, err = proto.LoadSequenceManager().CreateSequence(ctx, "arana", "employees", proto.SequenceConfig{Name: "student", Type: "group"})
if err != nil {
panic(err)
}

type testCase struct {
sql string
exceptVal int64
}

for _, it := range [...]testCase{
{
"select student.nextVal",
1,
},
{
"select student.currVal",
1,
},
{
"select student.nextVal",
2,
},
{
"select notexist.currVal",
-1,
},
} {
t.Run(it.sql, func(t *testing.T) {
rows, err := db.Query(it.sql)
if it.exceptVal == -1 {
assert.True(t, err != nil, err)
return
}
defer rows.Close()
assert.NoError(t, err, "should query successfully")
var val int64
records, _ := utils.PrintTable(rows)
val, err = strconv.ParseInt(records[0][0], 10, 64)
assert.Equal(t, it.exceptVal, val)
})
}
}

0 comments on commit 17b2133

Please sign in to comment.