diff --git a/command/init.go b/command/init.go index 7757eca9..c7bd578c 100644 --- a/command/init.go +++ b/command/init.go @@ -126,8 +126,9 @@ func init() { "zrevrangebyscore": Desc{Proc: AutoCommit(ZRevRangeByScore), Txn: ZRevRangeByScore, Cons: Constraint{-4, flags("rF"), 1, 1, 1}}, "zrem": Desc{Proc: AutoCommit(ZRem), Txn: ZRem, Cons: Constraint{-3, flags("wF"), 1, 1, 1}}, "zcard": Desc{Proc: AutoCommit(ZCard), Txn: ZCard, Cons: Constraint{2, flags("rF"), 1, 1, 1}}, - //"zcount": Desc{Proc: AutoCommit(ZCount), Txn: ZCount, Cons: Constraint{4, flags("rF"), 1, 1, 1}}, - "zscore": Desc{Proc: AutoCommit(ZScore), Txn: ZScore, Cons: Constraint{3, flags("rF"), 1, 1, 1}}, + "zcount": Desc{Proc: AutoCommit(ZCount), Txn: ZCount, Cons: Constraint{-4, flags("rF"), 1, 1, 1}}, + "zscore": Desc{Proc: AutoCommit(ZScore), Txn: ZScore, Cons: Constraint{3, flags("rF"), 1, 1, 1}}, + "zscan": Desc{Proc: AutoCommit(ZScan), Txn: ZScan, Cons: Constraint{-3, flags("rF"), 1, 1, 1}}, // extension commands "escan": Desc{Proc: AutoCommit(Escan), Txn: Escan, Cons: Constraint{-1, flags("rR"), 0, 0, 0}}, diff --git a/command/zsets.go b/command/zsets.go index 05f633c3..752cd4f5 100644 --- a/command/zsets.go +++ b/command/zsets.go @@ -2,20 +2,18 @@ package command import ( "errors" - "fmt" "math" "strconv" "strings" "github.com/distributedio/titan/db" + "github.com/distributedio/titan/encoding/resp" ) // ZAdd adds the specified members with scores to the sorted set func ZAdd(ctx *Context, txn *db.Transaction) (OnCommit, error) { key := []byte(ctx.Args[0]) - fmt.Println("zadd", ctx.Args) - kvs := ctx.Args[1:] if len(kvs)%2 != 0 { return nil, errors.New("ERR syntax error") @@ -111,6 +109,41 @@ func ZRevRangeByScore(ctx *Context, txn *db.Transaction) (OnCommit, error) { return zAnyOrderRangeByScore(ctx, txn, false) } +func ZCount(ctx *Context, txn *db.Transaction) (OnCommit, error) { + key := []byte(ctx.Args[0]) + startScore, startInclude, err := getFloatAndInclude(ctx.Args[1]) + if err != nil { + return nil, ErrMinOrMaxNotFloat + } + endScore, endInclude, err := getFloatAndInclude(ctx.Args[2]) + if err != nil { + return nil, ErrMinOrMaxNotFloat + } + zset, err := txn.ZSet(key) + if err != nil { + if err == db.ErrTypeMismatch { + return nil, ErrTypeMismatch + } + return nil, errors.New("ERR " + err.Error()) + } + if !zset.Exist() { + return Integer(ctx.Out, 0), nil + } + + items, err := zset.ZAnyOrderRangeByScore(startScore, startInclude, + endScore, endInclude, + false, + int64(0), math.MaxInt64, + true) + if err != nil { + return nil, errors.New("ERR " + err.Error()) + } + if len(items) == 0 { + return Integer(ctx.Out, 0), nil + } + return Integer(ctx.Out, int64(len(items))), nil +} + func zAnyOrderRangeByScore(ctx *Context, txn *db.Transaction, positiveOrder bool) (OnCommit, error) { key := []byte(ctx.Args[0]) @@ -240,3 +273,91 @@ func ZScore(ctx *Context, txn *db.Transaction) (OnCommit, error) { return BulkString(ctx.Out, string(score)), nil } + +func ZScan(ctx *Context, txn *db.Transaction) (OnCommit, error) { + var ( + key []byte + cursor []byte + lastCursor = []byte("0") + count = uint64(defaultScanCount) + kvs = [][]byte{} + pattern []byte + isAll bool + err error + ) + key = []byte(ctx.Args[0]) + if strings.Compare(ctx.Args[1], "0") != 0 { + cursor = []byte(ctx.Args[1]) + } + + // define return result + result := func() { + if _, err := resp.ReplyArray(ctx.Out, 2); err != nil { + return + } + resp.ReplyBulkString(ctx.Out, string(lastCursor)) + if _, err := resp.ReplyArray(ctx.Out, len(kvs)); err != nil { + return + } + for i := range kvs { + resp.ReplyBulkString(ctx.Out, string(kvs[i])) + } + } + zset, err := txn.ZSet(key) + if err != nil { + if err == db.ErrTypeMismatch { + return nil, ErrTypeMismatch + } + return nil, errors.New("ERR " + err.Error()) + } + + if !zset.Exist() { + return result, nil + } + + if len(ctx.Args)%2 != 0 { + return nil, ErrSyntax + } + + for i := 2; i < len(ctx.Args); i += 2 { + arg := strings.ToLower(ctx.Args[i]) + next := ctx.Args[i+1] + switch arg { + case "count": + if count, err = strconv.ParseUint(next, 10, 64); err != nil { + return nil, ErrInteger + } + if count > ScanMaxCount { + count = ScanMaxCount + } + if count == 0 { + count = uint64(defaultScanCount) + } + case "match": + pattern = []byte(next) + isAll = (pattern[0] == '*' && len(pattern) == 1) + } + } + + if len(pattern) == 0 { + isAll = true + } + f := func(member, score []byte) bool { + if count <= 0 { + lastCursor = member + return false + } + if isAll || globMatch(pattern, member, false) { + kvs = append(kvs, member) + kvs = append(kvs, score) + count-- + } + return true + } + + if err := zset.ZScan(cursor, f); err != nil { + return nil, errors.New("ERR " + err.Error()) + } + return result, nil + +} diff --git a/db/zset.go b/db/zset.go index 2e89aed6..82146d8b 100644 --- a/db/zset.go +++ b/db/zset.go @@ -418,9 +418,7 @@ func (zset *ZSet) ZCard() int64 { } func (zset *ZSet) ZScore(member []byte) ([]byte, error) { - dkey := DataKey(zset.txn.db, zset.meta.ID) - memberKey := zsetMemberKey(dkey, member) - bytesScore, err := zset.txn.t.Get(zset.txn.ctx, memberKey) + bScore, err := zset.zScoreBytes(member) if err != nil { if IsErrNotFound(err) { return nil, nil @@ -428,11 +426,60 @@ func (zset *ZSet) ZScore(member []byte) ([]byte, error) { return nil, err } - fscore := DecodeFloat64(bytesScore) + fscore := DecodeFloat64(bScore) sscore := strconv.FormatFloat(fscore, 'f', -1, 64) return []byte(sscore), nil } +func (zset *ZSet) zScoreBytes(member []byte) ([]byte, error) { + dkey := DataKey(zset.txn.db, zset.meta.ID) + memberKey := zsetMemberKey(dkey, member) + bScore, err := zset.txn.t.Get(zset.txn.ctx, memberKey) + if err != nil { + return nil, err + } + return bScore, nil +} + +func (zset *ZSet) ZScan(cursor []byte, f func(key, val []byte) bool) error { + if !zset.Exist() { + return nil + } + dkey := DataKey(zset.txn.db, zset.meta.ID) + prefix := ZSetScorePrefix(dkey) + endPrefix := kv.Key(prefix).PrefixNext() + ikey := prefix + if len(cursor) > 0 { + bScore, err := zset.zScoreBytes(cursor) + if err != nil { + if IsErrNotFound(err) { + return nil + } + return err + } + if len(bScore) > 0 { + ikey = append(ikey, bScore...) + } + } + iter, err := zset.txn.t.Iter(ikey, endPrefix) + if err != nil { + return err + } + for iter.Valid() && iter.Key().HasPrefix(prefix) { + scoreAndMember := iter.Key()[len(prefix):] + member := scoreAndMember[byteScoreLen+len(":"):] + byteScore := scoreAndMember[0:byteScoreLen] + score := []byte(strconv.FormatFloat(DecodeFloat64(byteScore), 'f', -1, 64)) + if !f(member, score) { + break + } + if err := iter.Next(); err != nil { + return err + } + } + return nil +} + func zsetMemberKey(dkey []byte, member []byte) []byte { var memberKey []byte memberKey = append(memberKey, dkey...) diff --git a/db/zset_test.go b/db/zset_test.go index c0b028e7..42b9ac8a 100644 --- a/db/zset_test.go +++ b/db/zset_test.go @@ -464,3 +464,67 @@ func TestZSetZAnyOrderRangeScore(t *testing.T) { }) } } + +func TestZSet_ZScan(t *testing.T) { + var members [][]byte + var score []float64 + + members = append(members, []byte("abc")) + members = append(members, []byte("aec")) + members = append(members, []byte("acc")) + members = append(members, []byte("bc")) + score = append(score, -1.1, -1, 1, 2.1) + + zset, txn, err := getZSet(t, []byte("TestZSet_ZScan")) + assert.NoError(t, err) + assert.NotNil(t, txn) + assert.NotNil(t, zset) + count, err := zset.ZAdd(members, score) + assert.NoError(t, err) + assert.Equal(t, count, int64(len(members))) + txn.Commit(context.TODO()) + + type args struct { + cursor []byte + f func(key, val []byte) bool + } + var value [][]byte + count = 2 + + tests := []struct { + name string + args args + want [][]byte + }{ + { + name: "TestZSet_ZScan", + args: args{ + f: func(member, score []byte) bool { + if count == 0 { + return false + } + value = append(value, member, score) + count-- + return true + + }, + }, + want: append([][]byte{}, []byte("abc"), []byte("-1.1"), []byte("aec"), []byte("-1")), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + zset, txn, err := getZSet(t, []byte("TestZSet_ZScan")) + assert.NoError(t, err) + assert.NotNil(t, txn) + assert.NotNil(t, zset) + + err = zset.ZScan(tt.args.cursor, tt.args.f) + txn.Commit(context.TODO()) + + assert.Equal(t, value, tt.want) + assert.NoError(t, err) + }) + } +} diff --git a/tools/autotest/auto.go b/tools/autotest/auto.go index ed28b560..11408a42 100644 --- a/tools/autotest/auto.go +++ b/tools/autotest/auto.go @@ -162,6 +162,36 @@ func (ac *AutoClient) ZSetCase(t *testing.T) { ac.ez.ZRangeByScoreEqual(t, "key-zset", "(2", "3.6", true, "", "member6 2.05 member3 3.6") ac.ez.ZRangeByScoreEqual(t, "key-zset", "0", "(2", true, "", "member5 0 member2 1.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "+inf", "-inf", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0 member4 -3.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "(+inf", "-inf", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0 member4 -3.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "+inf", "(-inf", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0 member4 -3.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "inf", "-inf", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0 member4 -3.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "inf", "-inf", false, "", "member3 member6 member11 member1 member2 member5 member4") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "inf", "-3.5", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0 member4 -3.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "inf", "(-3.5", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "inf", "0.0", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5 member5 0") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "inf", "(0.0", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "3.6", "(0.0", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "+3.6", "(0.0", true, "", "member3 3.6 member6 2.05 member11 2 member1 2 member2 1.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "(3.6", "(0.0", true, "", "member6 2.05 member11 2 member1 2 member2 1.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "", "member6 2.05 member11 2 member1 2 member2 1.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT -1 1", "") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "limit 0 -1", "member6 2.05 member11 2 member1 2 member2 1.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 0 0", "") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 0 2", "member6 2.05 member11 2") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 0 4", "member6 2.05 member11 2 member1 2 member2 1.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 0 5", "member6 2.05 member11 2 member1 2 member2 1.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 1 2", "member11 2 member1 2") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 3 2", "member2 1.5") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "2.05", "(0.0", true, "LIMIT 4 2", "") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "3.6", "(2", true, "", "member3 3.6 member6 2.05") + ac.ez.ZRevRangeByScoreEqual(t, "key-zset", "(2", "0", true, "", "member2 1.5 member5 0") + + ac.ez.ZCountEqual(t, "key-zset", "0", "(2", int64(2)) + + ac.ez.ZScanEqual(t, "key-zset", "0", "*", 2, "member2 member4 -3.5 member5 0") + ac.ez.ZScanEqual(t, "key-zset", "member2", "member*", 2, "member11 member2 1.5 member1 2") + ac.ez.ZRemEqual(t, "key-zset", "member2", "member1", "member3", "member4", "member1") ac.ez.ZRangeEqual(t, "key-zset", 0, -1, true) diff --git a/tools/autotest/cmd/zset.go b/tools/autotest/cmd/zset.go index 70a5bd4f..c2f51169 100644 --- a/tools/autotest/cmd/zset.go +++ b/tools/autotest/cmd/zset.go @@ -1,6 +1,7 @@ package cmd import ( + "fmt" "sort" "strconv" "strings" @@ -174,6 +175,45 @@ func (ez *ExampleZSet) ZRevRangeEqualErr(t *testing.T, errValue string, args ... assert.EqualError(t, err, errValue) } +func (ez *ExampleZSet) ZScanEqual(t *testing.T, key string, cursor string, pattern string, count int, expected string) { + cmd := "zscan" + req := make([]interface{}, 0) + req = append(req, key) + req = append(req, cursor) + req = append(req, "match", pattern) + req = append(req, "count", count) + + reply, err := redis.MultiBulk(ez.conn.Do(cmd, req...)) + lastCursor, _ := redis.String(reply[0], err) + strs, _ := redis.Strings(reply[1], err) + fmt.Println(lastCursor, strs) + if expected != "" { + expectedStrs := strings.Split(expected, " ") + assert.Equal(t, expectedStrs[0], lastCursor) + assert.Equal(t, expectedStrs[1:], strs) + } else { + assert.Equal(t, "0", lastCursor) + } + assert.Nil(t, err) +} + +func (ez *ExampleZSet) ZCountEqual(t *testing.T, key string, start string, stop string, expected int64) { + cmd := "zcount" + req := make([]interface{}, 0) + req = append(req, key) + req = append(req, start) + req = append(req, stop) + + reply, err := redis.Int64(ez.conn.Do(cmd, req...)) + assert.Equal(t, expected, reply) + assert.Nil(t, err) +} + +func (ez *ExampleZSet) ZCountEqualErr(t *testing.T, errValue string, args ...interface{}) { + _, err := ez.conn.Do("zcount", args...) + assert.EqualError(t, err, errValue) +} + func (ez *ExampleZSet) ZRangeByScoreEqual(t *testing.T, key string, start string, stop string, withScores bool, limit string, expected string) { ez.ZAnyOrderRangeByScoreEqual(t, key, start, stop, withScores, true, limit, expected) }