diff --git a/clients/goredis.go b/clients/goredis.go index 267e443..f7cfb63 100644 --- a/clients/goredis.go +++ b/clients/goredis.go @@ -9,12 +9,25 @@ import ( "github.com/nitishm/go-rejson/rjs" ) -var ctx = context.Background() - // GoRedis implements ReJSON interface for Go-Redis/Redis Redis client // Link: https://github.com/go-redis/redis type GoRedis struct { Conn *goredis.Client // import goredis "github.com/go-redis/redis/v8" + + // ctx defines context for the provided connection + ctx context.Context +} + +// NewGoRedisClient returns a new GoRedis ReJSON client with the provided context +// and connection, if ctx is nil default context.Background will be used +func NewGoRedisClient(ctx context.Context, conn *goredis.Client) *GoRedis { + if ctx == nil { + ctx = context.Background() + } + return &GoRedis{ + ctx: ctx, + Conn: conn, + } } // JSONSet used to set a json object @@ -39,7 +52,7 @@ func (r *GoRedis) JSONSet(key string, path string, obj interface{}, opts ...rjs. return nil, err } args = append([]interface{}{name}, args...) - res, err = r.Conn.Do(ctx, args...).Result() + res, err = r.Conn.Do(r.ctx, args...).Result() if err != nil && err.Error() == rjs.ErrGoRedisNil.Error() { err = nil @@ -75,7 +88,7 @@ func (r *GoRedis) JSONGet(key, path string, opts ...rjs.GetOption) (res interfac } args = append([]interface{}{name}, args...) - res, err = r.Conn.Do(ctx, args...).Result() + res, err = r.Conn.Do(r.ctx, args...).Result() if err != nil { return } @@ -103,7 +116,7 @@ func (r *GoRedis) JSONMGet(path string, keys ...string) (res interface{}, err er return nil, err } args = append([]interface{}{name}, args...) - res, err = r.Conn.Do(ctx, args...).Result() + res, err = r.Conn.Do(r.ctx, args...).Result() if err != nil { return } @@ -131,7 +144,7 @@ func (r *GoRedis) JSONDel(key string, path string) (res interface{}, err error) return nil, err } args = append([]interface{}{name}, args...) - return r.Conn.Do(ctx, args...).Result() + return r.Conn.Do(r.ctx, args...).Result() } // JSONType to get the type of key or member at path. @@ -146,7 +159,7 @@ func (r *GoRedis) JSONType(key, path string) (res interface{}, err error) { return nil, err } args = append([]interface{}{name}, args...) - res, err = r.Conn.Do(ctx, args...).Result() + res, err = r.Conn.Do(r.ctx, args...).Result() if err != nil && err.Error() == rjs.ErrGoRedisNil.Error() { err = nil @@ -166,7 +179,7 @@ func (r *GoRedis) JSONNumIncrBy(key, path string, number int) (res interface{}, return nil, err } args = append([]interface{}{name}, args...) - res, err = r.Conn.Do(ctx, args...).Result() + res, err = r.Conn.Do(r.ctx, args...).Result() if err != nil { return } @@ -185,7 +198,7 @@ func (r *GoRedis) JSONNumMultBy(key, path string, number int) (res interface{}, return nil, err } args = append([]interface{}{name}, args...) - res, err = r.Conn.Do(ctx, args...).Result() + res, err = r.Conn.Do(r.ctx, args...).Result() if err != nil { return } @@ -204,7 +217,7 @@ func (r *GoRedis) JSONStrAppend(key, path, jsonstring string) (res interface{}, return nil, err } args = append([]interface{}{name}, args...) - return r.Conn.Do(ctx, args...).Result() + return r.Conn.Do(r.ctx, args...).Result() } // JSONStrLen to return the length of a string member @@ -219,7 +232,7 @@ func (r *GoRedis) JSONStrLen(key, path string) (res interface{}, err error) { return nil, err } args = append([]interface{}{name}, args...) - return r.Conn.Do(ctx, args...).Result() + return r.Conn.Do(r.ctx, args...).Result() } // JSONArrAppend to append json value into array at path @@ -241,7 +254,7 @@ func (r *GoRedis) JSONArrAppend(key, path string, values ...interface{}) (res in return nil, err } args = append([]interface{}{name}, args...) - return r.Conn.Do(ctx, args...).Result() + return r.Conn.Do(r.ctx, args...).Result() } // JSONArrLen returns the length of the json array at path @@ -256,7 +269,7 @@ func (r *GoRedis) JSONArrLen(key, path string) (res interface{}, err error) { return nil, err } args = append([]interface{}{name}, args...) - return r.Conn.Do(ctx, args...).Result() + return r.Conn.Do(r.ctx, args...).Result() } // JSONArrPop removes and returns element from the index in the array @@ -273,7 +286,7 @@ func (r *GoRedis) JSONArrPop(key, path string, index int) (res interface{}, err } args = append([]interface{}{name}, args...) - res, err = r.Conn.Do(ctx, args...).Result() + res, err = r.Conn.Do(r.ctx, args...).Result() if err != nil { return } @@ -303,7 +316,7 @@ func (r *GoRedis) JSONArrIndex(key, path string, jsonValue interface{}, optional return nil, err } args = append([]interface{}{name}, args...) - return r.Conn.Do(ctx, args...).Result() + return r.Conn.Do(r.ctx, args...).Result() } // JSONArrTrim trims an array so that it contains only the specified inclusive range of elements @@ -318,7 +331,7 @@ func (r *GoRedis) JSONArrTrim(key, path string, start, end int) (res interface{} return nil, err } args = append([]interface{}{name}, args...) - return r.Conn.Do(ctx, args...).Result() + return r.Conn.Do(r.ctx, args...).Result() } // JSONArrInsert inserts the json value(s) into the array at path before the index (shifts to the right). @@ -340,7 +353,7 @@ func (r *GoRedis) JSONArrInsert(key, path string, index int, values ...interface return nil, err } args = append([]interface{}{name}, args...) - return r.Conn.Do(ctx, args...).Result() + return r.Conn.Do(r.ctx, args...).Result() } // JSONObjKeys returns the keys in the object that's referenced by path @@ -355,7 +368,7 @@ func (r *GoRedis) JSONObjKeys(key, path string) (res interface{}, err error) { return nil, err } args = append([]interface{}{name}, args...) - res, err = r.Conn.Do(ctx, args...).Result() + res, err = r.Conn.Do(r.ctx, args...).Result() if err != nil { return } @@ -380,7 +393,7 @@ func (r *GoRedis) JSONObjLen(key, path string) (res interface{}, err error) { return nil, err } args = append([]interface{}{name}, args...) - return r.Conn.Do(ctx, args...).Result() + return r.Conn.Do(r.ctx, args...).Result() } // JSONDebug reports information @@ -401,7 +414,7 @@ func (r *GoRedis) JSONDebug(subcommand rjs.DebugSubCommand, key, path string) (r return nil, err } args = append([]interface{}{name}, args...) - res, err = r.Conn.Do(ctx, args...).Result() + res, err = r.Conn.Do(r.ctx, args...).Result() if err != nil { return } @@ -430,7 +443,7 @@ func (r *GoRedis) JSONForget(key, path string) (res interface{}, err error) { return nil, err } args = append([]interface{}{name}, args...) - return r.Conn.Do(ctx, args...).Result() + return r.Conn.Do(r.ctx, args...).Result() } // JSONResp returns the JSON in key in Redis Serialization Protocol (RESP). @@ -445,5 +458,5 @@ func (r *GoRedis) JSONResp(key, path string) (res interface{}, err error) { return nil, err } args = append([]interface{}{name}, args...) - return r.Conn.Do(ctx, args...).Result() + return r.Conn.Do(r.ctx, args...).Result() } diff --git a/context.go b/context.go new file mode 100644 index 0000000..8007f40 --- /dev/null +++ b/context.go @@ -0,0 +1,30 @@ +package rejson + +import ( + "context" + "github.com/nitishm/go-rejson/clients" + "github.com/nitishm/go-rejson/rjs" +) + +// SetContext helps redis-clients, provide use of command level context +// in the ReJSON commands. +// Currently, only go-redis@v8 supports command level context, therefore +// a separate method is added to support it, maintaining the support for +// other clients and for backward compatibility. (nitishm/go-rejson#46) +func (r *Handler) SetContext(ctx context.Context) *Handler { + if r == nil { + return r // nil + } + + if r.clientName == rjs.ClientGoRedis { + if old, ok := r.implementation.(*clients.GoRedis); ok { + return &Handler{ + clientName: r.clientName, + implementation: clients.NewGoRedisClient(ctx, old.Conn), + } + } + } + + // for other clients, context is of no use, hence return same + return r +} diff --git a/rejson_test.go b/rejson_test.go index 8965cda..06c4ffa 100644 --- a/rejson_test.go +++ b/rejson_test.go @@ -12,8 +12,6 @@ import ( redigo "github.com/gomodule/redigo/redis" ) -var ctx = context.Background() - func TestUnsupportedCommand(t *testing.T) { _, _, err := rjs.CommandBuilder(1234, nil) if err == nil { @@ -23,53 +21,34 @@ func TestUnsupportedCommand(t *testing.T) { } type TestClient struct { + *testing.T name string conn interface{} rh *Handler } -func (t *TestClient) init() { +type helper struct { + cli interface{} + name string + closeFunc func() +} + +func (t *TestClient) init() []helper { t.name = "-" t.conn = "inactive" t.rh = NewReJSONHandler() -} - -func (t *TestClient) SetTestingClient(conn interface{}) { - t.conn = conn - - switch conn := conn.(type) { - case redigo.Conn: - t.name = "Redigo-" - t.rh.SetRedigoClient(conn) - case *goredis.Client: - t.name = "GoRedis-" - t.rh.SetGoRedisClient(conn) - default: - t.name = "-" - t.conn = "inactive" - t.rh.SetClientInactive() - } -} - -func TestReJSON(t *testing.T) { - test := TestClient{} - test.init() // Redigo Test Client redigoCli, err := redigo.Dial("tcp", ":6379") if err != nil { t.Fatalf("redigo - could not connect to redigo: %v", err) - return + return nil } // GoRedis Test Client goredisCli := goredis.NewClient(&goredis.Options{Addr: "localhost:6379"}) - clientsObj := []struct { - cli interface{} - name string - closeFunc func() - }{ + return []helper{ {cli: redigoCli, name: "Redigo ", closeFunc: func() { _, err = redigoCli.Do("FLUSHALL") if err != nil { @@ -81,7 +60,7 @@ func TestReJSON(t *testing.T) { } }}, {cli: goredisCli, name: "GoRedis ", closeFunc: func() { - if err := goredisCli.FlushAll(ctx).Err(); err != nil { + if err := goredisCli.FlushAll(context.Background()).Err(); err != nil { t.Fatalf("goredis - failed to flush: %v", err) } if err := goredisCli.Close(); err != nil { @@ -89,8 +68,29 @@ func TestReJSON(t *testing.T) { } }}, } +} - for _, obj := range clientsObj { +func (t *TestClient) SetTestingClient(conn interface{}) { + t.conn = conn + + switch conn := conn.(type) { + case redigo.Conn: + t.name = "Redigo-" + t.rh.SetRedigoClient(conn) + case *goredis.Client: + t.name = "GoRedis-" + t.rh.SetGoRedisClient(conn) + default: + t.name = "-" + t.conn = "inactive" + t.rh.SetClientInactive() + } +} + +func TestReJSON(t *testing.T) { + test := TestClient{T: t} + list := test.init() + for _, obj := range list { t.Run(obj.name+"TestJSONSet", func(t *testing.T) { test.SetTestingClient(obj.cli) testJSONSet(test.rh, t) @@ -176,6 +176,75 @@ func TestReJSON(t *testing.T) { } +func TestReJSONWithContext(t *testing.T) { + ctx := context.Background() + ctxCn, cancel := context.WithCancel(ctx) + cancel() + + testObj := TestObject{ + Name: "itemName", + Number: 1, + } + res := []byte("{\"name\":\"itemName\",\"number\":1}") + + test := TestClient{T: t} + list := test.init() + for _, obj := range list { + test.SetTestingClient(obj.cli) + rh := test.rh + + // check with canceled context + ok, err := rh.SetContext(ctxCn).JSONSet("testObj#1", ".", testObj) + if rh.clientName == rjs.ClientGoRedis { + if err == nil || ok == "OK" { + t.Errorf("JSONSet() got = %v %v, want nil, error: context.Canceled", ok, err) + } + got, err := rh.JSONGet("testObj#1", ".") + if err == nil || reflect.DeepEqual(got, res) { + t.Errorf("JSONGet() got = %v %v, want: no key found error", got, err) + } + } else { + if err != nil || ok != "OK" { + t.Errorf("JSONSet() got = %v %v, want OK, nil", ok, err) + } + got, err := rh.JSONGet("testObj#1", ".") + if err != nil || !reflect.DeepEqual(got, res) { + t.Errorf("JSONGet() got = %v %v, want: %v", got, err, res) + } + } + + // check with normal context + ok, err = rh.SetContext(ctx).JSONSet("testObj#2", ".", testObj) + if err != nil || ok != "OK" { + t.Errorf("JSONSet() got = %v %v, want OK, nil", ok, err) + } + got, err := rh.JSONGet("testObj#2", ".") + if err != nil || !reflect.DeepEqual(got, res) { + t.Errorf("JSONGet() got = %v %v, want: %v", got, err, res) + } + got, err = rh.SetContext(ctx).JSONGet("testObj#2", ".") + if err != nil || !reflect.DeepEqual(got, res) { + t.Errorf("JSONGet() got = %v %v, want: %v", got, err, res) + } + + // check without context + ok, err = rh.JSONSet("testObj#3", ".", testObj) + if err != nil || ok != "OK" { + t.Errorf("JSONSet() got = %v %v, want OK, nil", ok, err) + } + got, err = rh.JSONGet("testObj#3", ".") + if err != nil || !reflect.DeepEqual(got, res) { + t.Errorf("JSONGet() got = %v %v, want: %v", got, err, res) + } + got, err = rh.SetContext(ctx).JSONGet("testObj#3", ".") + if err != nil || !reflect.DeepEqual(got, res) { + t.Errorf("JSONGet() got = %v %v, want: %v", got, err, res) + } + + obj.closeFunc() + } +} + type TestObject struct { Name string `json:"name"` Number int `json:"number"` diff --git a/rjs/constants.go b/rjs/constants.go index 1665fcc..f06a3a3 100644 --- a/rjs/constants.go +++ b/rjs/constants.go @@ -17,6 +17,12 @@ const ( // ClientInactive signifies that the client is inactive in Handler ClientInactive = "inactive" + // ClientRedigo signifies that the current client is redigo + ClientRedigo = "redigo" + + // ClientGoRedis signifies that the current client is go-redis + ClientGoRedis = "goredis" + // PopArrLast gives index of the last element for JSONArrPop PopArrLast = -1 diff --git a/set_client.go b/set_client.go index 9edba97..2f04962 100644 --- a/set_client.go +++ b/set_client.go @@ -1,6 +1,7 @@ package rejson import ( + "context" goredis "github.com/go-redis/redis/v8" redigo "github.com/gomodule/redigo/redis" "github.com/nitishm/go-rejson/clients" @@ -29,8 +30,14 @@ func (r *Handler) SetRedigoClient(conn redigo.Conn) { } // SetGoRedisClient sets Go-Redis (https://github.com/go-redis/redis) client to -// the handler +// the handler. It is left for backward compatibility. func (r *Handler) SetGoRedisClient(conn *goredis.Client) { + r.SetGoRedisClientWithContext(nil, conn) +} + +// SetGoRedisClientWithContext sets Go-Redis (https://github.com/go-redis/redis) client to +// the handler with a global context for the connection +func (r *Handler) SetGoRedisClientWithContext(ctx context.Context, conn *goredis.Client) { r.clientName = "goredis" - r.implementation = &clients.GoRedis{Conn: conn} + r.implementation = clients.NewGoRedisClient(ctx, conn) }