From 92d9bb2565be782805019c0548910f5d30a7ab1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Putra?= Date: Mon, 1 Aug 2022 18:53:15 +0200 Subject: [PATCH] session: prepare now prepares on all nodes Also made the driver fallback from token aware policy to round robin if both session keyspace and query keyspace are unspecified. Fixes #249 --- session.go | 47 +++++++++++++++++++---------- session_integration_test.go | 60 +++++++++++++++++++++++++++++++++++++ transport/cluster.go | 10 +++++-- transport/node.go | 4 +++ transport/policy.go | 6 ++-- transport/policy_test.go | 6 ++-- 6 files changed, 108 insertions(+), 25 deletions(-) diff --git a/session.go b/session.go index 8ff5be4e..4c4c2314 100644 --- a/session.go +++ b/session.go @@ -3,6 +3,7 @@ package scylla import ( "fmt" "log" + "sync" "github.com/mmatczuk/scylla-go-driver/frame" "github.com/mmatczuk/scylla-go-driver/transport" @@ -147,24 +148,38 @@ func (s *Session) Query(content string) Query { } func (s *Session) Prepare(content string) (Query, error) { - n := s.policy.Node(s.cluster.NewQueryInfo(), 0) - conn := n.LeastBusyConn() - if conn == nil { - return Query{}, errNoConnection - } - stmt := transport.Statement{Content: content, Consistency: frame.ALL} - res, err := conn.Prepare(stmt) - return Query{session: s, - stmt: res, - exec: func(conn *transport.Conn, stmt transport.Statement, pagingState frame.Bytes) (transport.QueryResult, error) { - return conn.Execute(stmt, pagingState) - }, - asyncExec: func(conn *transport.Conn, stmt transport.Statement, pagingState frame.Bytes, handler transport.ResponseHandler) { - conn.AsyncExecute(stmt, pagingState, handler) - }, - }, err + // Prepare on all nodes concurrently. + nodes := s.cluster.Topology().Nodes + resStmt := make([]transport.Statement, len(nodes)) + resErr := make([]error, len(nodes)) + var wg sync.WaitGroup + for i := range nodes { + wg.Add(1) + go func(idx int) { + defer wg.Done() + resStmt[idx], resErr[idx] = nodes[idx].Prepare(stmt) + }(i) + } + wg.Wait() + + // Find first result that succeeded. + for i := range nodes { + if resErr[i] == nil { + return Query{session: s, + stmt: resStmt[i], + exec: func(conn *transport.Conn, stmt transport.Statement, pagingState frame.Bytes) (transport.QueryResult, error) { + return conn.Execute(stmt, pagingState) + }, + asyncExec: func(conn *transport.Conn, stmt transport.Statement, pagingState frame.Bytes, handler transport.ResponseHandler) { + conn.AsyncExecute(stmt, pagingState, handler) + }, + }, nil + } + } + + return Query{}, fmt.Errorf("prepare failed on all nodes, details: %v", resErr) } func (s *Session) NewTokenAwarePolicy() transport.HostSelectionPolicy { diff --git a/session_integration_test.go b/session_integration_test.go index 410e0cf0..9933f017 100644 --- a/session_integration_test.go +++ b/session_integration_test.go @@ -6,8 +6,10 @@ import ( "crypto/tls" "crypto/x509" "errors" + "fmt" "io/ioutil" "testing" + "time" "go.uber.org/goleak" ) @@ -338,3 +340,61 @@ func TestTLSIntegration(t *testing.T) { }) } } + +func TestPrepareIntegration(t *testing.T) { + defer goleak.VerifyNone(t) + + cfg := DefaultSessionConfig("", "192.168.100.100:9042") + session, err := NewSession(cfg) + defer session.Close() + + if err != nil { + t.Fatal(err) + } + + initStmts := []string{ + "DROP KEYSPACE IF EXISTS testks", + "CREATE KEYSPACE IF NOT EXISTS testks WITH replication = {'class': 'SimpleStrategy', 'replication_factor' : 1}", + "CREATE TABLE IF NOT EXISTS testks.doubles (pk bigint PRIMARY KEY, v bigint)", + } + + for _, stmt := range initStmts { + q := session.Query(stmt) + if _, err := q.Exec(); err != nil { + t.Fatal(err) + } + time.Sleep(time.Second) + } + + q, err := session.Prepare("INSERT INTO testks.doubles (pk, v) VALUES (?, ?)") + if err != nil { + t.Fatal(err) + } + + for i := int64(0); i < 1000; i++ { + _, err := q.BindInt64(0, i).BindInt64(1, 2*i).Exec() + if err != nil { + t.Fatal(err) + } + } + + for i := int64(0); i < 1000; i++ { + q, err := session.Prepare("SELECT v FROM testks.doubles WHERE pk = " + fmt.Sprint(i)) + if err != nil { + t.Fatal(err) + } + + for rep := 0; rep < 3; rep++ { + res, err := q.Exec() + if err != nil { + t.Fatal(err) + } + + if v, err := res.Rows[0][0].AsInt64(); err != nil { + t.Fatal(err) + } else if v != 2*i { + t.Fatalf("expected %d, got %d", 2*i, v) + } + } + } +} diff --git a/transport/cluster.go b/transport/cluster.go index da3404e6..79768c2d 100644 --- a/transport/cluster.go +++ b/transport/cluster.go @@ -41,7 +41,7 @@ type topology struct { localDC string peers peerMap dcRacks dcRacksMap - nodes []*Node + Nodes []*Node policyInfo policyInfo keyspaces ksMap } @@ -94,6 +94,10 @@ func (c *Cluster) NewTokenAwareQueryInfo(t Token, ks string) (QueryInfo, error) top := c.Topology() // When keyspace is not specified, we take default keyspace from ConnConfig. if ks == "" { + if c.cfg.Keyspace == "" { + // We don't know anything about the keyspace, fallback to non-token aware query. + return c.NewQueryInfo(), nil + } ks = c.cfg.Keyspace } if stg, ok := top.keyspaces[ks]; ok { @@ -219,7 +223,7 @@ func (c *Cluster) refreshTopology() error { // Every encountered node becomes known host for future use. c.knownHosts[n.addr] = struct{}{} t.peers[n.addr] = n - t.nodes = append(t.nodes, n) + t.Nodes = append(t.Nodes, n) u[uniqueRack{dc: n.datacenter, rack: n.rack}] = struct{}{} if err := parseTokensFromRow(n, r, &t.policyInfo.ring); err != nil { return err @@ -251,7 +255,7 @@ func newTopology() *topology { return &topology{ peers: make(peerMap), dcRacks: make(dcRacksMap), - nodes: make([]*Node, 0), + Nodes: make([]*Node, 0), policyInfo: policyInfo{ ring: make(Ring, 0), }, diff --git a/transport/node.go b/transport/node.go index 81c97d5b..a9db47c8 100644 --- a/transport/node.go +++ b/transport/node.go @@ -37,6 +37,10 @@ func (n *Node) Conn(token Token) *Conn { return n.pool.Conn(token) } +func (n *Node) Prepare(s Statement) (Statement, error) { + return n.LeastBusyConn().Prepare(s) +} + type RingEntry struct { node *Node token Token diff --git a/transport/policy.go b/transport/policy.go index 19027ef0..937112a8 100644 --- a/transport/policy.go +++ b/transport/policy.go @@ -86,7 +86,7 @@ func (pi *policyInfo) Preprocess(t *topology, ks keyspace) { } func (pi *policyInfo) preprocessSimpleStrategy(t *topology, stg strategy) { - pi.localNodes = t.nodes + pi.localNodes = t.Nodes sort.Sort(pi.ring) trie := trieRoot() for i := range pi.ring { @@ -122,14 +122,14 @@ func (pi *policyInfo) preprocessSimpleStrategy(t *topology, stg strategy) { } func (pi *policyInfo) preprocessRoundRobinStrategy(t *topology) { - pi.localNodes = t.nodes + pi.localNodes = t.Nodes pi.remoteNodes = nil } func (pi *policyInfo) preprocessDCAwareRoundRobinStrategy(t *topology) { pi.localNodes = make([]*Node, 0) pi.remoteNodes = make([]*Node, 0) - for _, v := range t.nodes { + for _, v := range t.Nodes { if v.datacenter == t.localDC { pi.localNodes = append(pi.localNodes, v) } else { diff --git a/transport/policy_test.go b/transport/policy_test.go index b05c30a5..9c93f728 100644 --- a/transport/policy_test.go +++ b/transport/policy_test.go @@ -20,7 +20,7 @@ func mockTopologyRoundRobin() *topology { } return &topology{ - nodes: dummyNodes, + Nodes: dummyNodes, } } @@ -173,7 +173,7 @@ func mockTopologyTokenAwareSimpleStrategy() *topology { } return &topology{ - nodes: dummyNodes, + Nodes: dummyNodes, policyInfo: policyInfo{ ring: ring, }, @@ -288,7 +288,7 @@ func mockTopologyTokenAwareDCAwareStrategy() *topology { return &topology{ dcRacks: dcs, - nodes: dummyNodes, + Nodes: dummyNodes, policyInfo: policyInfo{ring: ring}, keyspaces: ks, }