diff --git a/go/mysql/collations/colldata/uca_contraction_test.go b/go/mysql/collations/colldata/uca_contraction_test.go index d17ff21e255..d09b2a6d982 100644 --- a/go/mysql/collations/colldata/uca_contraction_test.go +++ b/go/mysql/collations/colldata/uca_contraction_test.go @@ -21,7 +21,6 @@ import ( "fmt" "math/rand" "os" - "reflect" "sort" "testing" "unicode/utf8" @@ -36,7 +35,6 @@ type CollationWithContractions struct { Collation Collation Contractions []uca.Contraction ContractFast uca.Contractor - ContractTrie uca.Contractor } func findContractedCollations(t testing.TB, unique bool) (result []CollationWithContractions) { @@ -58,7 +56,7 @@ func findContractedCollations(t testing.TB, unique bool) (result []CollationWith continue } - rf, err := os.Open(fmt.Sprintf("testdata/mysqldata/%s.json", collation.Name())) + rf, err := os.Open(fmt.Sprintf("../testdata/mysqldata/%s.json", collation.Name())) if err != nil { t.Skipf("failed to open JSON metadata (%v). did you run colldump?", err) } @@ -91,14 +89,13 @@ func findContractedCollations(t testing.TB, unique bool) (result []CollationWith Collation: collation, Contractions: meta.Contractions, ContractFast: contract, - ContractTrie: uca.NewTrieContractor(meta.Contractions), }) } return } func testMatch(t *testing.T, name string, cnt uca.Contraction, result []uint16, remainder []byte, skip int) { - assert.True(t, reflect.DeepEqual(cnt.Weights, result), "%s didn't match: expected %#v, got %#v", name, cnt.Weights, result) + assert.Equal(t, result, cnt.Weights, "%s didn't match: expected %#v, got %#v", name, cnt.Weights, result) assert.Equal(t, 0, len(remainder), "%s bad remainder: %#v", name, remainder) assert.Equal(t, len(cnt.Path), skip, "%s bad skipped length %d for %#v", name, skip, cnt.Path) @@ -112,10 +109,7 @@ func TestUCAContractions(t *testing.T) { head := cnt.Path[0] tail := cnt.Path[1] - result := cwc.ContractTrie.FindContextual(head, tail) - testMatch(t, "ContractTrie", cnt, result, nil, 2) - - result = cwc.ContractFast.FindContextual(head, tail) + result := cwc.ContractFast.FindContextual(head, tail) testMatch(t, "ContractFast", cnt, result, nil, 2) continue } @@ -123,10 +117,7 @@ func TestUCAContractions(t *testing.T) { head := cnt.Path[0] tail := string(cnt.Path[1:]) - result, remainder, skip := cwc.ContractTrie.Find(charset.Charset_utf8mb4{}, head, []byte(tail)) - testMatch(t, "ContractTrie", cnt, result, remainder, skip) - - result, remainder, skip = cwc.ContractFast.Find(charset.Charset_utf8mb4{}, head, []byte(tail)) + result, remainder, skip := cwc.ContractFast.Find(charset.Charset_utf8mb4{}, head, []byte(tail)) testMatch(t, "ContractFast", cnt, result, remainder, skip) } }) @@ -239,10 +230,6 @@ func BenchmarkUCAContractions(b *testing.B) { b.Run(fmt.Sprintf("%s-%.02f-fast", cwc.Collation.Name(), frequency), func(b *testing.B) { benchmarkFind(b, input, cwc.ContractFast) }) - - b.Run(fmt.Sprintf("%s-%.02f-trie", cwc.Collation.Name(), frequency), func(b *testing.B) { - benchmarkFind(b, input, cwc.ContractTrie) - }) } } @@ -259,9 +246,5 @@ func BenchmarkUCAContractionsJA(b *testing.B) { b.Run(fmt.Sprintf("%s-%.02f-fast", cwc.Collation.Name(), frequency), func(b *testing.B) { benchmarkFindJA(b, input, cwc.ContractFast) }) - - b.Run(fmt.Sprintf("%s-%.02f-trie", cwc.Collation.Name(), frequency), func(b *testing.B) { - benchmarkFindJA(b, input, cwc.ContractTrie) - }) } } diff --git a/go/mysql/collations/colldata/uca_tables_test.go b/go/mysql/collations/colldata/uca_tables_test.go index 40c2f3bbed3..2071e0630cd 100644 --- a/go/mysql/collations/colldata/uca_tables_test.go +++ b/go/mysql/collations/colldata/uca_tables_test.go @@ -65,7 +65,7 @@ func verifyAllCodepoints(t *testing.T, expected map[rune][]uint16, weights uca.W } func loadExpectedWeights(t *testing.T, weights string) map[rune][]uint16 { - fullpath := fmt.Sprintf("testdata/mysqldata/%s.json", weights) + fullpath := fmt.Sprintf("../testdata/mysqldata/%s.json", weights) weightsMysqlFile, err := os.Open(fullpath) if err != nil { t.Skipf("failed to load %q (did you run 'colldump' locally?)", fullpath) diff --git a/go/mysql/collations/internal/uca/contractions.go b/go/mysql/collations/internal/uca/contractions.go index d894b0e206e..5866cf5bf53 100644 --- a/go/mysql/collations/internal/uca/contractions.go +++ b/go/mysql/collations/internal/uca/contractions.go @@ -17,93 +17,9 @@ limitations under the License. package uca import ( - "fmt" - "vitess.io/vitess/go/mysql/collations/charset" ) -type trie struct { - children map[rune]*trie - weights []uint16 -} - -func (t *trie) walkCharset(cs charset.Charset, remainder []byte, depth int) ([]uint16, []byte, int) { - if len(remainder) > 0 { - cp, width := cs.DecodeRune(remainder) - if cp == charset.RuneError && width < 3 { - return nil, nil, 0 - } - if ch := t.children[cp]; ch != nil { - return ch.walkCharset(cs, remainder[width:], depth+1) - } - } - return t.weights, remainder, depth + 1 -} - -func (t *trie) insert(path []rune, weights []uint16) { - if len(path) == 0 { - if t.weights != nil { - panic("duplicate contraction") - } - t.weights = weights - return - } - - if t.children == nil { - t.children = make(map[rune]*trie) - } - ch := t.children[path[0]] - if ch == nil { - ch = &trie{} - t.children[path[0]] = ch - } - ch.insert(path[1:], weights) -} - -type trieContractor struct { - tr trie -} - -func (ctr *trieContractor) insert(c *Contraction) { - if len(c.Path) < 2 { - panic("contraction is too short") - } - if len(c.Weights)%3 != 0 { - panic(fmt.Sprintf("weights are not well-formed: %#v has len=%d", c.Weights, len(c.Weights))) - } - if c.Contextual && len(c.Path) != 2 { - panic("contextual contractions can only span 2 codepoints") - } - ctr.tr.insert(c.Path, c.Weights) -} - -func (ctr *trieContractor) Find(cs charset.Charset, cp rune, remainder []byte) ([]uint16, []byte, int) { - if tr := ctr.tr.children[cp]; tr != nil { - return tr.walkCharset(cs, remainder, 0) - } - return nil, nil, 0 -} - -func (ctr *trieContractor) FindContextual(cp, prev rune) []uint16 { - if tr := ctr.tr.children[cp]; tr != nil { - if trc := tr.children[prev]; trc != nil { - return trc.weights - } - } - return nil -} - -func NewTrieContractor(all []Contraction) Contractor { - if len(all) == 0 { - return nil - } - ctr := &trieContractor{} - for _, c := range all { - ctr.insert(&c) - } - return ctr -} - type Contraction struct { Path []rune Weights []uint16 diff --git a/go/mysql/collations/tools/makecolldata/codegen/tablegen.go b/go/mysql/collations/tools/makecolldata/codegen/tablegen.go index b12d32f59d7..e1549c23bff 100644 --- a/go/mysql/collations/tools/makecolldata/codegen/tablegen.go +++ b/go/mysql/collations/tools/makecolldata/codegen/tablegen.go @@ -224,20 +224,6 @@ func (tg *TableGenerator) entryForCodepoint(codepoint rune) (*page, *entry) { return page, entry } -func (tg *TableGenerator) Add900(codepoint rune, rhs [][3]uint16) { - page, entry := tg.entryForCodepoint(codepoint) - page.entryCount++ - - for i, weights := range rhs { - if i >= uca.MaxCollationElementsPerCodepoint { - break - } - for _, we := range weights { - entry.weights = append(entry.weights, we) - } - } -} - func (tg *TableGenerator) Add(codepoint rune, weights []uint16) { page, entry := tg.entryForCodepoint(codepoint) page.entryCount++ @@ -248,22 +234,6 @@ func (tg *TableGenerator) Add(codepoint rune, weights []uint16) { entry.weights = append(entry.weights, weights...) } -func (tg *TableGenerator) AddFromAllkeys(lhs []rune, rhs [][]int, vars []int) { - if len(lhs) > 1 || lhs[0] > tg.maxChar { - // TODO: support contractions - return - } - - var weights [][3]uint16 - for _, we := range rhs { - if len(we) != 3 { - panic("non-triplet weight in allkeys.txt") - } - weights = append(weights, [3]uint16{uint16(we[0]), uint16(we[1]), uint16(we[2])}) - } - tg.Add900(lhs[0], weights) -} - func (tg *TableGenerator) writePage(g *Generator, p *page, layout uca.Layout) string { var weights []uint16 diff --git a/go/vt/logutil/logger.go b/go/vt/logutil/logger.go index 087c310011c..47c3f124238 100644 --- a/go/vt/logutil/logger.go +++ b/go/vt/logutil/logger.go @@ -206,27 +206,6 @@ func (cl *CallbackLogger) Printf(format string, v ...any) { }) } -// ChannelLogger is a Logger that sends the logging events through a channel for -// consumption. -type ChannelLogger struct { - CallbackLogger - C chan *logutilpb.Event -} - -// NewChannelLogger returns a CallbackLogger which will write the data -// on a channel -func NewChannelLogger(size int) *ChannelLogger { - c := make(chan *logutilpb.Event, size) - return &ChannelLogger{ - CallbackLogger: CallbackLogger{ - f: func(e *logutilpb.Event) { - c <- e - }, - }, - C: c, - } -} - // MemoryLogger keeps the logging events in memory. // All protected by a mutex. type MemoryLogger struct { diff --git a/go/vt/logutil/logger_test.go b/go/vt/logutil/logger_test.go index 0eb4edb2b93..ce25543da5f 100644 --- a/go/vt/logutil/logger_test.go +++ b/go/vt/logutil/logger_test.go @@ -112,44 +112,15 @@ func TestMemoryLogger(t *testing.T) { } } -func TestChannelLogger(t *testing.T) { - cl := NewChannelLogger(10) - cl.Infof("test %v", 123) - cl.Warningf("test %v", 123) - cl.Errorf("test %v", 123) - cl.Printf("test %v", 123) - close(cl.C) - - count := 0 - for e := range cl.C { - if got, want := e.Value, "test 123"; got != want { - t.Errorf("e.Value = %q, want %q", got, want) - } - if e.File != "logger_test.go" { - t.Errorf("Invalid file name: %v", e.File) - } - count++ - } - if got, want := count, 4; got != want { - t.Errorf("count = %v, want %v", got, want) - } -} - func TestTeeLogger(t *testing.T) { - ml := NewMemoryLogger() - cl := NewChannelLogger(10) - tl := NewTeeLogger(ml, cl) + ml1 := NewMemoryLogger() + ml2 := NewMemoryLogger() + tl := NewTeeLogger(ml1, ml2) tl.Infof("test infof %v %v", 1, 2) tl.Warningf("test warningf %v %v", 2, 3) tl.Errorf("test errorf %v %v", 3, 4) tl.Printf("test printf %v %v", 4, 5) - close(cl.C) - - clEvents := []*logutilpb.Event{} - for e := range cl.C { - clEvents = append(clEvents, e) - } wantEvents := []*logutilpb.Event{ {Level: logutilpb.Level_INFO, Value: "test infof 1 2"}, @@ -159,7 +130,7 @@ func TestTeeLogger(t *testing.T) { } wantFile := "logger_test.go" - for i, events := range [][]*logutilpb.Event{ml.Events, clEvents} { + for i, events := range [][]*logutilpb.Event{ml1.Events, ml2.Events} { if got, want := len(events), len(wantEvents); got != want { t.Fatalf("[%v] len(events) = %v, want %v", i, got, want) } diff --git a/go/vt/topotools/split.go b/go/vt/topotools/split.go index 0671c2c5cb8..9da6b99878f 100644 --- a/go/vt/topotools/split.go +++ b/go/vt/topotools/split.go @@ -17,10 +17,8 @@ limitations under the License. package topotools import ( - "context" "errors" "fmt" - "sort" "vitess.io/vitess/go/vt/key" topodatapb "vitess.io/vitess/go/vt/proto/topodata" @@ -76,185 +74,3 @@ func combineKeyRanges(shards []*topo.ShardInfo) (*topodatapb.KeyRange, error) { } return result, nil } - -// OverlappingShards contains sets of shards that overlap which each-other. -// With this library, there is no guarantee of which set will be left or right. -type OverlappingShards struct { - Left []*topo.ShardInfo - Right []*topo.ShardInfo -} - -// ContainsShard returns true if either Left or Right lists contain -// the provided Shard. -func (os *OverlappingShards) ContainsShard(shardName string) bool { - for _, l := range os.Left { - if l.ShardName() == shardName { - return true - } - } - for _, r := range os.Right { - if r.ShardName() == shardName { - return true - } - } - return false -} - -// OverlappingShardsForShard returns the OverlappingShards object -// from the list that has he provided shard, or nil -func OverlappingShardsForShard(os []*OverlappingShards, shardName string) *OverlappingShards { - for _, o := range os { - if o.ContainsShard(shardName) { - return o - } - } - return nil -} - -// FindOverlappingShards will return an array of OverlappingShards -// for the provided keyspace. -// We do not support more than two overlapping shards (for instance, -// having 40-80, 40-60 and 40-50 in the same keyspace is not supported and -// will return an error). -// If shards don't perfectly overlap, they are not returned. -func FindOverlappingShards(ctx context.Context, ts *topo.Server, keyspace string) ([]*OverlappingShards, error) { - shardMap, err := ts.FindAllShardsInKeyspace(ctx, keyspace, nil) - if err != nil { - return nil, err - } - - return findOverlappingShards(shardMap) -} - -// findOverlappingShards does the work for FindOverlappingShards but -// can be called on test data too. -func findOverlappingShards(shardMap map[string]*topo.ShardInfo) ([]*OverlappingShards, error) { - - var result []*OverlappingShards - - for len(shardMap) > 0 { - var left []*topo.ShardInfo - var right []*topo.ShardInfo - - // get the first value from the map, seed our left array with it - var name string - var si *topo.ShardInfo - for name, si = range shardMap { - break - } - left = append(left, si) - delete(shardMap, name) - - // keep adding entries until we have no more to add - for { - foundOne := false - - // try left to right - si := findIntersectingShard(shardMap, left) - if si != nil { - if intersect(si, right) { - return nil, fmt.Errorf("shard %v intersects with more than one shard, this is not supported", si.ShardName()) - } - foundOne = true - right = append(right, si) - } - - // try right to left - si = findIntersectingShard(shardMap, right) - if si != nil { - if intersect(si, left) { - return nil, fmt.Errorf("shard %v intersects with more than one shard, this is not supported", si.ShardName()) - } - foundOne = true - left = append(left, si) - } - - // we haven't found anything new, we're done - if !foundOne { - break - } - } - - // save what we found if it's good - if len(right) > 0 { - // sort both lists - sort.Sort(shardInfoList(left)) - sort.Sort(shardInfoList(right)) - - // we should not have holes on either side - hasHoles := false - for i := 0; i < len(left)-1; i++ { - if string(left[i].KeyRange.End) != string(left[i+1].KeyRange.Start) { - hasHoles = true - } - } - for i := 0; i < len(right)-1; i++ { - if string(right[i].KeyRange.End) != string(right[i+1].KeyRange.Start) { - hasHoles = true - } - } - if hasHoles { - continue - } - - // the two sides should match - if !key.KeyRangeStartEqual(left[0].KeyRange, right[0].KeyRange) { - continue - } - if !key.KeyRangeEndEqual(left[len(left)-1].KeyRange, right[len(right)-1].KeyRange) { - continue - } - - // all good, we have a valid overlap - result = append(result, &OverlappingShards{ - Left: left, - Right: right, - }) - } - } - return result, nil -} - -// findIntersectingShard will go through the map and take the first -// entry in there that intersect with the source array, remove it from -// the map, and return it -func findIntersectingShard(shardMap map[string]*topo.ShardInfo, sourceArray []*topo.ShardInfo) *topo.ShardInfo { - for name, si := range shardMap { - for _, sourceShardInfo := range sourceArray { - if si.KeyRange == nil || sourceShardInfo.KeyRange == nil || key.KeyRangeIntersect(si.KeyRange, sourceShardInfo.KeyRange) { - delete(shardMap, name) - return si - } - } - } - return nil -} - -// intersect returns true if the provided shard intersect with any shard -// in the destination array -func intersect(si *topo.ShardInfo, allShards []*topo.ShardInfo) bool { - for _, shard := range allShards { - if key.KeyRangeIntersect(si.KeyRange, shard.KeyRange) { - return true - } - } - return false -} - -// shardInfoList is a helper type to sort ShardInfo array by keyrange -type shardInfoList []*topo.ShardInfo - -// Len is part of sort.Interface -func (sil shardInfoList) Len() int { - return len(sil) -} - -// Less is part of sort.Interface -func (sil shardInfoList) Less(i, j int) bool { - return string(sil[i].KeyRange.Start) < string(sil[j].KeyRange.Start) -} - -// Swap is part of sort.Interface -func (sil shardInfoList) Swap(i, j int) { - sil[i], sil[j] = sil[j], sil[i] -} diff --git a/go/vt/topotools/split_test.go b/go/vt/topotools/split_test.go index 003dc767317..6e93ee345d3 100644 --- a/go/vt/topotools/split_test.go +++ b/go/vt/topotools/split_test.go @@ -17,7 +17,6 @@ limitations under the License. package topotools import ( - "encoding/hex" "testing" "github.com/stretchr/testify/assert" @@ -27,75 +26,6 @@ import ( topodatapb "vitess.io/vitess/go/vt/proto/topodata" ) -// helper methods for tests to be shorter - -func hki(hexValue string) []byte { - k, err := hex.DecodeString(hexValue) - if err != nil { - panic(err) - } - return k -} - -func si(start, end string) *topo.ShardInfo { - s := hki(start) - e := hki(end) - return topo.NewShardInfo("keyspace", start+"-"+end, &topodatapb.Shard{ - KeyRange: &topodatapb.KeyRange{ - Start: s, - End: e, - }, - }, nil) -} - -type expectedOverlappingShard struct { - left []string - right []string -} - -func overlappingShardMatch(ol []*topo.ShardInfo, or []*topo.ShardInfo, e expectedOverlappingShard) bool { - if len(ol)+1 != len(e.left) { - return false - } - if len(or)+1 != len(e.right) { - return false - } - for i, l := range ol { - if l.ShardName() != e.left[i]+"-"+e.left[i+1] { - return false - } - } - for i, r := range or { - if r.ShardName() != e.right[i]+"-"+e.right[i+1] { - return false - } - } - return true -} - -func compareResultLists(t *testing.T, os []*OverlappingShards, expected []expectedOverlappingShard) { - if len(os) != len(expected) { - t.Errorf("Unexpected result length, got %v, want %v", len(os), len(expected)) - return - } - - for _, o := range os { - found := false - for _, e := range expected { - if overlappingShardMatch(o.Left, o.Right, e) { - found = true - } - if overlappingShardMatch(o.Right, o.Left, e) { - found = true - } - } - if !found { - t.Errorf("OverlappingShard %v not found in expected %v", o, expected) - return - } - } -} - func TestValidateForReshard(t *testing.T) { testcases := []struct { sources []string @@ -169,191 +99,3 @@ func TestValidateForReshard(t *testing.T) { } } } - -func TestFindOverlappingShardsNoOverlap(t *testing.T) { - var shardMap map[string]*topo.ShardInfo - var os []*OverlappingShards - var err error - - // no shards - shardMap = map[string]*topo.ShardInfo{} - os, err = findOverlappingShards(shardMap) - if len(os) != 0 || err != nil { - t.Errorf("empty shard map: %v %v", os, err) - } - - // just one shard, full keyrange - shardMap = map[string]*topo.ShardInfo{ - "0": {}, - } - os, err = findOverlappingShards(shardMap) - if len(os) != 0 || err != nil { - t.Errorf("just one shard, full keyrange: %v %v", os, err) - } - - // just one shard, partial keyrange - shardMap = map[string]*topo.ShardInfo{ - "-80": si("", "80"), - } - os, err = findOverlappingShards(shardMap) - if len(os) != 0 || err != nil { - t.Errorf("just one shard, partial keyrange: %v %v", os, err) - } - - // two non-overlapping shards - shardMap = map[string]*topo.ShardInfo{ - "-80": si("", "80"), - "80": si("80", ""), - } - os, err = findOverlappingShards(shardMap) - if len(os) != 0 || err != nil { - t.Errorf("two non-overlapping shards: %v %v", os, err) - } - - // shards with holes - shardMap = map[string]*topo.ShardInfo{ - "-80": si("", "80"), - "80": si("80", ""), - "-20": si("", "20"), - // HOLE: "20-40": si("20", "40"), - "40-60": si("40", "60"), - "60-80": si("60", "80"), - } - os, err = findOverlappingShards(shardMap) - if len(os) != 0 || err != nil { - t.Errorf("shards with holes: %v %v", os, err) - } - - // shards not overlapping - shardMap = map[string]*topo.ShardInfo{ - "-80": si("", "80"), - "80": si("80", ""), - // MISSING: "-20": si("", "20"), - "20-40": si("20", "40"), - "40-60": si("40", "60"), - "60-80": si("60", "80"), - } - os, err = findOverlappingShards(shardMap) - if len(os) != 0 || err != nil { - t.Errorf("shards not overlapping: %v %v", os, err) - } -} - -func TestFindOverlappingShardsOverlap(t *testing.T) { - var shardMap map[string]*topo.ShardInfo - var os []*OverlappingShards - var err error - - // split in progress - shardMap = map[string]*topo.ShardInfo{ - "-80": si("", "80"), - "80": si("80", ""), - "-40": si("", "40"), - "40-80": si("40", "80"), - } - os, err = findOverlappingShards(shardMap) - if len(os) != 1 || err != nil { - t.Errorf("split in progress: %v %v", os, err) - } - compareResultLists(t, os, []expectedOverlappingShard{ - { - left: []string{"", "80"}, - right: []string{"", "40", "80"}, - }, - }) - - // 1 to 4 split - shardMap = map[string]*topo.ShardInfo{ - "-": si("", ""), - "-40": si("", "40"), - "40-80": si("40", "80"), - "80-c0": si("80", "c0"), - "c0-": si("c0", ""), - } - os, err = findOverlappingShards(shardMap) - if len(os) != 1 || err != nil { - t.Errorf("1 to 4 split: %v %v", os, err) - } - compareResultLists(t, os, []expectedOverlappingShard{ - { - left: []string{"", ""}, - right: []string{"", "40", "80", "c0", ""}, - }, - }) - - // 2 to 3 split - shardMap = map[string]*topo.ShardInfo{ - "-40": si("", "40"), - "40-80": si("40", "80"), - "80-": si("80", ""), - "-30": si("", "30"), - "30-60": si("30", "60"), - "60-80": si("60", "80"), - } - os, err = findOverlappingShards(shardMap) - if len(os) != 1 || err != nil { - t.Errorf("2 to 3 split: %v %v", os, err) - } - compareResultLists(t, os, []expectedOverlappingShard{ - { - left: []string{"", "40", "80"}, - right: []string{"", "30", "60", "80"}, - }, - }) - - // multiple concurrent splits - shardMap = map[string]*topo.ShardInfo{ - "-80": si("", "80"), - "80-": si("80", ""), - "-40": si("", "40"), - "40-80": si("40", "80"), - "80-c0": si("80", "c0"), - "c0-": si("c0", ""), - } - os, err = findOverlappingShards(shardMap) - if len(os) != 2 || err != nil { - t.Errorf("2 to 3 split: %v %v", os, err) - } - compareResultLists(t, os, []expectedOverlappingShard{ - { - left: []string{"", "80"}, - right: []string{"", "40", "80"}, - }, - { - left: []string{"80", ""}, - right: []string{"80", "c0", ""}, - }, - }) - - // find a shard in there - if o := OverlappingShardsForShard(os, "-60"); o != nil { - t.Errorf("Found a shard where I shouldn't have!") - } - if o := OverlappingShardsForShard(os, "-40"); o == nil { - t.Errorf("Found no shard where I should have!") - } else { - compareResultLists(t, []*OverlappingShards{o}, - []expectedOverlappingShard{ - { - left: []string{"", "80"}, - right: []string{"", "40", "80"}, - }, - }) - } -} - -func TestFindOverlappingShardsErrors(t *testing.T) { - var shardMap map[string]*topo.ShardInfo - var err error - - // 3 overlapping shards - shardMap = map[string]*topo.ShardInfo{ - "-20": si("", "20"), - "-40": si("", "40"), - "-80": si("", "80"), - } - _, err = findOverlappingShards(shardMap) - if err == nil { - t.Errorf("3 overlapping shards with no error") - } -} diff --git a/go/vt/topotools/utils.go b/go/vt/topotools/utils.go index 6d1522e04e7..ae70b299bdd 100644 --- a/go/vt/topotools/utils.go +++ b/go/vt/topotools/utils.go @@ -17,10 +17,8 @@ limitations under the License. package topotools import ( - "reflect" - "sync" - "context" + "sync" "vitess.io/vitess/go/vt/topo" @@ -101,37 +99,3 @@ func SortedTabletMap(tabletMap map[string]*topo.TabletInfo) (map[string]*topo.Ta } return replicaMap, primaryMap } - -// CopyMapKeys copies keys from map m into a new slice with the -// type specified by typeHint. Reflection can't make a new slice type -// just based on the key type AFAICT. -func CopyMapKeys(m any, typeHint any) any { - mapVal := reflect.ValueOf(m) - keys := reflect.MakeSlice(reflect.TypeOf(typeHint), 0, mapVal.Len()) - for _, k := range mapVal.MapKeys() { - keys = reflect.Append(keys, k) - } - return keys.Interface() -} - -// CopyMapValues copies values from map m into a new slice with the -// type specified by typeHint. Reflection can't make a new slice type -// just based on the key type AFAICT. -func CopyMapValues(m any, typeHint any) any { - mapVal := reflect.ValueOf(m) - vals := reflect.MakeSlice(reflect.TypeOf(typeHint), 0, mapVal.Len()) - for _, k := range mapVal.MapKeys() { - vals = reflect.Append(vals, mapVal.MapIndex(k)) - } - return vals.Interface() -} - -// MapKeys returns an array with th provided map keys. -func MapKeys(m any) []any { - keys := make([]any, 0, 16) - mapVal := reflect.ValueOf(m) - for _, kv := range mapVal.MapKeys() { - keys = append(keys, kv.Interface()) - } - return keys -} diff --git a/go/vt/vtctl/reparentutil/util.go b/go/vt/vtctl/reparentutil/util.go index db260cc36b6..d049ab9b05a 100644 --- a/go/vt/vtctl/reparentutil/util.go +++ b/go/vt/vtctl/reparentutil/util.go @@ -22,6 +22,7 @@ import ( "sync" "time" + "golang.org/x/exp/maps" "golang.org/x/sync/errgroup" "vitess.io/vitess/go/mysql/replication" @@ -32,7 +33,6 @@ import ( "vitess.io/vitess/go/vt/logutil" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/topo/topoproto" - "vitess.io/vitess/go/vt/topotools" "vitess.io/vitess/go/vt/vtctl/reparentutil/promotionrule" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vttablet/tmclient" @@ -219,7 +219,7 @@ func ShardReplicationStatuses(ctx context.Context, ts *topo.Server, tmc tmclient if err != nil { return nil, nil, err } - tablets := topotools.CopyMapValues(tabletMap, []*topo.TabletInfo{}).([]*topo.TabletInfo) + tablets := maps.Values(tabletMap) log.Infof("Gathering tablet replication status for: %v", tablets) wg := sync.WaitGroup{}