diff --git a/any_table.go b/any_table.go index e629507..fb60806 100644 --- a/any_table.go +++ b/any_table.go @@ -1,6 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + package statedb import ( + "fmt" "iter" ) @@ -12,8 +16,13 @@ type AnyTable struct { } func (t AnyTable) All(txn ReadTxn) iter.Seq2[any, Revision] { + all, _ := t.AllWatch(txn) + return all +} + +func (t AnyTable) AllWatch(txn ReadTxn) (iter.Seq2[any, Revision], <-chan struct{}) { indexTxn := txn.getTxn().mustIndexReadTxn(t.Meta, PrimaryIndexPos) - return partSeq[any](indexTxn.Iterator()) + return partSeq[any](indexTxn.Iterator()), indexTxn.RootWatch() } func (t AnyTable) UnmarshalYAML(data []byte) (any, error) { @@ -38,22 +47,85 @@ func (t AnyTable) Delete(txn WriteTxn, obj any) (old any, hadOld bool, err error return } -func (t AnyTable) Prefix(txn ReadTxn, key string) iter.Seq2[any, Revision] { - indexTxn := txn.getTxn().mustIndexReadTxn(t.Meta, PrimaryIndexPos) - iter, _ := indexTxn.Prefix([]byte(key)) - if indexTxn.unique { - return partSeq[any](iter) +func (t AnyTable) Get(txn ReadTxn, index string, key string) (any, Revision, bool, error) { + itxn, rawKey, err := t.queryIndex(txn, index, key) + if err != nil { + return nil, 0, false, err + } + if itxn.unique { + obj, _, ok := itxn.Get(rawKey) + return obj.data, obj.revision, ok, nil } - return nonUniqueSeq[any](iter, true, []byte(key)) + // For non-unique indexes we need to prefix search and make sure to fully + // match the secondary key. + iter, _ := itxn.Prefix(rawKey) + for { + k, obj, ok := iter.Next() + if !ok { + break + } + secondary, _ := decodeNonUniqueKey(k) + if len(secondary) == len(rawKey) { + return obj.data, obj.revision, true, nil + } + } + return nil, 0, false, nil } -func (t AnyTable) LowerBound(txn ReadTxn, key string) iter.Seq2[any, Revision] { - indexTxn := txn.getTxn().mustIndexReadTxn(t.Meta, PrimaryIndexPos) - iter := indexTxn.LowerBound([]byte(key)) - if indexTxn.unique { - return partSeq[any](iter) +func (t AnyTable) Prefix(txn ReadTxn, index string, key string) (iter.Seq2[any, Revision], error) { + itxn, rawKey, err := t.queryIndex(txn, index, key) + if err != nil { + return nil, err + } + iter, _ := itxn.Prefix(rawKey) + if itxn.unique { + return partSeq[any](iter), nil + } + return nonUniqueSeq[any](iter, true, rawKey), nil +} + +func (t AnyTable) LowerBound(txn ReadTxn, index string, key string) (iter.Seq2[any, Revision], error) { + itxn, rawKey, err := t.queryIndex(txn, index, key) + if err != nil { + return nil, err + } + iter := itxn.LowerBound(rawKey) + if itxn.unique { + return partSeq[any](iter), nil + } + return nonUniqueLowerBoundSeq[any](iter, rawKey), nil +} + +func (t AnyTable) List(txn ReadTxn, index string, key string) (iter.Seq2[any, Revision], error) { + itxn, rawKey, err := t.queryIndex(txn, index, key) + if err != nil { + return nil, err + } + iter, _ := itxn.Prefix(rawKey) + if itxn.unique { + // Unique index means that there can be only a single matching object. + // Doing a Get() is more efficient than constructing an iterator. + value, _, ok := itxn.Get(rawKey) + return func(yield func(any, Revision) bool) { + if ok { + yield(value.data, value.revision) + } + }, nil + } + return nonUniqueSeq[any](iter, false, rawKey), nil +} + +func (t AnyTable) queryIndex(txn ReadTxn, index string, key string) (indexReadTxn, []byte, error) { + indexer := t.Meta.getIndexer(index) + if indexer == nil { + return indexReadTxn{}, nil, fmt.Errorf("invalid index %q", index) + } + rawKey, err := indexer.fromString(key) + if err != nil { + return indexReadTxn{}, nil, err } - return nonUniqueLowerBoundSeq[any](iter, []byte(key)) + itxn, err := txn.getTxn().indexReadTxn(t.Meta, indexer.pos) + return itxn, rawKey, err } func (t AnyTable) TableHeader() []string { diff --git a/cell.go b/cell.go index 38b0658..9a23395 100644 --- a/cell.go +++ b/cell.go @@ -16,6 +16,7 @@ var Cell = cell.Module( cell.Provide( newHiveDB, + ScriptCommands, ), ) diff --git a/db.go b/db.go index 9d8b0cb..be6f18e 100644 --- a/db.go +++ b/db.go @@ -208,9 +208,20 @@ func (db *DB) WriteTxn(table TableMeta, tables ...TableMeta) WriteTxn { lockAt := time.Now() smus.Lock() acquiredAt := time.Now() - root := *db.root.Load() tableEntries := make([]*tableEntry, len(root)) + + txn := &txn{ + db: db, + root: root, + handle: db.handleName, + acquiredAt: time.Now(), + writeTxn: writeTxn{ + modifiedTables: tableEntries, + smus: smus, + }, + } + var tableNames []string for _, table := range allTables { tableEntry := root[table.tablePos()] @@ -223,10 +234,12 @@ func (db *DB) WriteTxn(table TableMeta, tables ...TableMeta) WriteTxn { table.Name(), table.sortableMutex().AcquireDuration(), ) + table.acquired(txn) } // Sort the table names so they always appear ordered in metrics. sort.Strings(tableNames) + txn.tableNames = tableNames db.metrics.WriteTxnTotalAcquisition( db.handleName, @@ -234,15 +247,6 @@ func (db *DB) WriteTxn(table TableMeta, tables ...TableMeta) WriteTxn { acquiredAt.Sub(lockAt), ) - txn := &txn{ - db: db, - root: root, - modifiedTables: tableEntries, - smus: smus, - acquiredAt: acquiredAt, - tableNames: tableNames, - handle: db.handleName, - } runtime.SetFinalizer(txn, txnFinalizer) return txn } diff --git a/db_test.go b/db_test.go index 4459c8d..a12d105 100644 --- a/db_test.go +++ b/db_test.go @@ -11,6 +11,8 @@ import ( "log/slog" "runtime" "slices" + "strconv" + "strings" "testing" "time" @@ -47,6 +49,17 @@ func (t testObject) String() string { return fmt.Sprintf("testObject{ID: %d, Tags: %v}", t.ID, t.Tags) } +func (t testObject) TableHeader() []string { + return []string{"ID", "Tags"} +} + +func (t testObject) TableRow() []string { + return []string{ + strconv.FormatUint(uint64(t.ID), 10), + strings.Join(slices.Collect(t.Tags.All()), ", "), + } +} + var ( idIndex = Index[testObject, uint64]{ Name: "id", @@ -54,7 +67,11 @@ var ( return index.NewKeySet(index.Uint64(t.ID)) }, FromKey: index.Uint64, - Unique: true, + FromString: func(key string) (index.Key, error) { + v, err := strconv.ParseUint(key, 10, 64) + return index.Uint64(v), err + }, + Unique: true, } tagsIndex = Index[testObject, string]{ @@ -62,11 +79,22 @@ var ( FromObject: func(t testObject) index.KeySet { return index.Set(t.Tags) }, - FromKey: index.String, - Unique: false, + FromKey: index.String, + FromString: index.FromString, + Unique: false, } ) +func newTestObjectTable(t testing.TB, name string, secondaryIndexers ...Indexer[testObject]) RWTable[testObject] { + table, err := NewTable( + name, + idIndex, + secondaryIndexers..., + ) + require.NoError(t, err, "NewTable[testObject]") + return table +} + const ( INDEX_TAGS = true NO_INDEX_TAGS = false @@ -82,12 +110,7 @@ func newTestDBWithMetrics(t testing.TB, metrics Metrics, secondaryIndexers ...In var ( db *DB ) - table, err := NewTable( - "test", - idIndex, - secondaryIndexers..., - ) - require.NoError(t, err, "NewTable[testObject]") + table := newTestObjectTable(t, "test", secondaryIndexers...) h := hive.New( cell.Provide(func() Metrics { return metrics }), @@ -237,7 +260,7 @@ func TestDB_Prefix(t *testing.T) { txn := db.ReadTxn() iter, watch := table.PrefixWatch(txn, tagsIndex.Query("ab")) - require.Equal(t, Collect(Map(iter, testObject.getID)), []uint64{71, 82}) + require.Equal(t, []uint64{71, 82}, Collect(Map(iter, testObject.getID))) select { case <-watch: diff --git a/go.mod b/go.mod index 0a8fe05..7878538 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,11 @@ go 1.23 require ( github.com/cilium/hive v0.0.0-20241009102328-2ab688845f23 github.com/cilium/stream v0.0.0-20240209152734-a0792b51812d - github.com/rogpeppe/go-internal v1.11.0 github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 go.uber.org/goleak v1.3.0 + golang.org/x/term v0.16.0 golang.org/x/time v0.5.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -34,7 +34,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect golang.org/x/sys v0.17.0 // indirect - golang.org/x/term v0.16.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.17.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/index/string.go b/index/string.go index 9a678f0..99430bc 100644 --- a/index/string.go +++ b/index/string.go @@ -12,6 +12,10 @@ func String(s string) Key { return []byte(s) } +func FromString(s string) (Key, error) { + return String(s), nil +} + func Stringer[T fmt.Stringer](s T) Key { return String(s.String()) } diff --git a/internal/time.go b/internal/time.go new file mode 100644 index 0000000..5463ea0 --- /dev/null +++ b/internal/time.go @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package internal + +import ( + "fmt" + "time" +) + +func PrettySince(t time.Time) string { + return PrettyDuration(time.Since(t)) +} + +func PrettyDuration(d time.Duration) string { + ago := float64(d) / float64(time.Microsecond) + + // micros + if ago < 1000.0 { + return fmt.Sprintf("%.1fus", ago) + } + + // millis + ago /= 1000.0 + if ago < 1000.0 { + return fmt.Sprintf("%.1fms", ago) + } + // secs + ago /= 1000.0 + if ago < 60.0 { + return fmt.Sprintf("%.1fs", ago) + } + // mins + ago /= 60.0 + if ago < 60.0 { + return fmt.Sprintf("%.1fm", ago) + } + // hours + ago /= 60.0 + return fmt.Sprintf("%.1fh", ago) +} diff --git a/iterator.go b/iterator.go index d84bb68..301ce33 100644 --- a/iterator.go +++ b/iterator.go @@ -115,13 +115,12 @@ func nonUniqueSeq[Obj any](iter *part.Iterator[object], prefixSearch bool, searc secondary, primary := decodeNonUniqueKey(key) - // The secondary key is shorter than what we're looking for, e.g. - // we match into the primary key. Keep searching for matching secondary - // keys. switch { case !prefixSearch && len(secondary) != len(searchKey): + // This a List(), thus secondary key must match length exactly. continue case prefixSearch && len(secondary) < len(searchKey): + // This is Prefix(), thus key must be equal or longer to search key. continue } diff --git a/quick_test.go b/quick_test.go index 00374fa..d3b7d04 100644 --- a/quick_test.go +++ b/quick_test.go @@ -36,8 +36,9 @@ var ( FromObject: func(t quickObj) index.KeySet { return index.NewKeySet(index.String(t.A)) }, - FromKey: index.String, - Unique: true, + FromKey: index.String, + FromString: index.FromString, + Unique: true, } bIndex = Index[quickObj, string]{ @@ -45,8 +46,9 @@ var ( FromObject: func(t quickObj) index.KeySet { return index.NewKeySet(index.String(t.B)) }, - FromKey: index.String, - Unique: false, + FromKey: index.String, + FromString: index.FromString, + Unique: false, } ) @@ -128,7 +130,9 @@ func TestDB_Quick(t *testing.T) { return false } } - for anyObj := range anyTable.Prefix(rtxn, a) { + anyObjs, err := anyTable.Prefix(rtxn, "a", a) + require.NoError(t, err, "AnyTable.Prefix") + for anyObj := range anyObjs { obj := anyObj.(quickObj) if !strings.HasPrefix(obj.A, a) { t.Logf("AnyTable.Prefix() returned object with wrong prefix via aIndex") @@ -142,7 +146,9 @@ func TestDB_Quick(t *testing.T) { return false } } - for anyObj := range anyTable.LowerBound(rtxn, a) { + anyObjs, err = anyTable.LowerBound(rtxn, "a", a) + require.NoError(t, err, "AnyTable.LowerBound") + for anyObj := range anyObjs { obj := anyObj.(quickObj) if cmp.Compare(obj.A, a) < 0 { t.Logf("AnyTable.LowerBound() order wrong") @@ -213,6 +219,16 @@ func TestDB_Quick(t *testing.T) { visited[obj.A] = struct{}{} } + anyObjs, err = anyTable.Prefix(rtxn, "b", b) + require.NoError(t, err, "AnyTable.Prefix") + for anyObj := range anyObjs { + obj := anyObj.(quickObj) + if !strings.HasPrefix(obj.B, b) { + t.Logf("AnyTable.Prefix() via bIndex has wrong prefix") + return false + } + } + visited = map[string]struct{}{} for obj := range table.LowerBound(rtxn, bIndex.Query(b)) { if cmp.Compare(obj.B, b) < 0 { @@ -226,6 +242,16 @@ func TestDB_Quick(t *testing.T) { visited[obj.A] = struct{}{} } + anyObjs, err = anyTable.LowerBound(rtxn, "b", b) + require.NoError(t, err, "AnyTable.LowerBound") + for anyObj := range anyObjs { + obj := anyObj.(quickObj) + if cmp.Compare(obj.B, b) < 0 { + t.Logf("AnyTable.LowerBound() via bIndex has wrong objects, expected %v >= %v", []byte(obj.B), []byte(b)) + return false + } + } + // Iterating over the secondary index returns the objects in order // defined by the "B" key. if !isOrdered(t, Map(table.Prefix(rtxn, bIndex.Query("")), quickObj.getB)) { diff --git a/reconciler/status_test.go b/reconciler/status_test.go index 6365fbf..936f927 100644 --- a/reconciler/status_test.go +++ b/reconciler/status_test.go @@ -21,7 +21,7 @@ func TestStatusString(t *testing.T) { UpdatedAt: now, Error: "", } - assert.Regexp(t, `Pending \([0-9]+\.[0-9]+m?s ago\)`, s.String()) + assert.Regexp(t, `Pending \([0-9]+\.[0-9]+.+s ago\)`, s.String()) s.UpdatedAt = now.Add(-time.Hour) assert.Regexp(t, `Pending \([0-9]+\.[0-9]+h ago\)`, s.String()) @@ -30,14 +30,14 @@ func TestStatusString(t *testing.T) { UpdatedAt: now, Error: "", } - assert.Regexp(t, `Done \([0-9]+\.[0-9]+m?s ago\)`, s.String()) + assert.Regexp(t, `Done \([0-9]+\.[0-9]+.+s ago\)`, s.String()) s = Status{ Kind: StatusKindError, UpdatedAt: now, Error: "hey I'm an error", } - assert.Regexp(t, `Error: hey I'm an error \([0-9]+\.[0-9]+m?s ago\)`, s.String()) + assert.Regexp(t, `Error: hey I'm an error \([0-9]+\.[0-9]+.+s ago\)`, s.String()) } func sanitizeAgo(s string) string { diff --git a/reconciler/types.go b/reconciler/types.go index aa69beb..6d6342d 100644 --- a/reconciler/types.go +++ b/reconciler/types.go @@ -19,6 +19,7 @@ import ( "github.com/cilium/hive/job" "github.com/cilium/statedb" "github.com/cilium/statedb/index" + "github.com/cilium/statedb/internal" ) type Reconciler[Obj any] interface { @@ -148,30 +149,9 @@ func (s Status) IsPendingOrRefreshing() bool { func (s Status) String() string { if s.Kind == StatusKindError { - return fmt.Sprintf("Error: %s (%s ago)", s.Error, prettySince(s.UpdatedAt)) + return fmt.Sprintf("Error: %s (%s ago)", s.Error, internal.PrettySince(s.UpdatedAt)) } - return fmt.Sprintf("%s (%s ago)", s.Kind, prettySince(s.UpdatedAt)) -} - -func prettySince(t time.Time) string { - ago := float64(time.Now().Sub(t)) / float64(time.Millisecond) - // millis - if ago < 1000.0 { - return fmt.Sprintf("%.1fms", ago) - } - // secs - ago /= 1000.0 - if ago < 60.0 { - return fmt.Sprintf("%.1fs", ago) - } - // mins - ago /= 60.0 - if ago < 60.0 { - return fmt.Sprintf("%.1fm", ago) - } - // hours - ago /= 60.0 - return fmt.Sprintf("%.1fh", ago) + return fmt.Sprintf("%s (%s ago)", s.Kind, internal.PrettySince(s.UpdatedAt)) } var idGen atomic.Uint64 @@ -314,7 +294,7 @@ func (s StatusSet) String() string { b.WriteString(strings.Join(done, " ")) } b.WriteString(" (") - b.WriteString(prettySince(updatedAt)) + b.WriteString(internal.PrettySince(updatedAt)) b.WriteString(" ago)") return b.String() } diff --git a/regression_test.go b/regression_test.go index 1f26e70..567f4ac 100644 --- a/regression_test.go +++ b/regression_test.go @@ -237,5 +237,4 @@ func Test_Regression_Prefix_NonUnique(t *testing.T) { assert.EqualValues(t, "z", items[0].ID) assert.EqualValues(t, "b", items[1].ID) } - } diff --git a/script.go b/script.go new file mode 100644 index 0000000..609efc6 --- /dev/null +++ b/script.go @@ -0,0 +1,684 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package statedb + +import ( + "bytes" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "iter" + "maps" + "os" + "regexp" + "slices" + "strings" + "text/tabwriter" + "time" + + "github.com/cilium/hive" + "github.com/cilium/hive/script" + "gopkg.in/yaml.v3" +) + +func ScriptCommands(db *DB) hive.ScriptCmdOut { + subCmds := map[string]script.Cmd{ + "tables": TablesCmd(db), + "show": ShowCmd(db), + "cmp": CompareCmd(db), + "insert": InsertCmd(db), + "delete": DeleteCmd(db), + "get": GetCmd(db), + "prefix": PrefixCmd(db), + "list": ListCmd(db), + "lowerbound": LowerBoundCmd(db), + "initialized": InitializedCmd(db), + } + subCmdsList := strings.Join(slices.Collect(maps.Keys(subCmds)), ", ") + return hive.NewScriptCmd( + "db", + script.Command( + script.CmdUsage{ + Summary: "Inspect and manipulate StateDB", + Args: "cmd args...", + Detail: []string{ + "Supported commands: " + subCmdsList, + }, + }, + func(s *script.State, args ...string) (script.WaitFunc, error) { + if len(args) < 1 { + return nil, fmt.Errorf("expected command (%s)", subCmdsList) + } + cmd, ok := subCmds[args[0]] + if !ok { + return nil, fmt.Errorf("command not found, expected one of %s", subCmdsList) + } + wf, err := cmd.Run(s, args[1:]...) + if errors.Is(err, errUsage) { + s.Logf("usage: db %s %s\n", args[0], cmd.Usage().Args) + } + return wf, err + }, + ), + ) +} + +var errUsage = errors.New("bad arguments") + +func TablesCmd(db *DB) script.Cmd { + return script.Command( + script.CmdUsage{ + Summary: "Show StateDB tables", + Args: "table", + }, + func(s *script.State, args ...string) (script.WaitFunc, error) { + txn := db.ReadTxn() + tbls := db.GetTables(txn) + w := tabwriter.NewWriter(s.LogWriter(), 5, 4, 3, ' ', 0) + fmt.Fprintf(w, "Name\tObject count\tDeleted objects\tIndexes\tInitializers\tGo type\tLast WriteTxn\n") + for _, tbl := range tbls { + idxs := strings.Join(tbl.Indexes(), ", ") + fmt.Fprintf(w, "%s\t%d\t%d\t%s\t%v\t%T\t%s\n", + tbl.Name(), tbl.NumObjects(txn), tbl.numDeletedObjects(txn), idxs, tbl.PendingInitializers(txn), tbl.proto(), tbl.getAcquiredInfo()) + } + w.Flush() + return nil, nil + }, + ) +} + +func newCmdFlagSet() *flag.FlagSet { + return &flag.FlagSet{ + // Disable showing the normal usage. + Usage: func() {}, + } +} + +func InitializedCmd(db *DB) script.Cmd { + return script.Command( + script.CmdUsage{ + Summary: "Wait until all or specific tables have been initialized", + Args: "(-timeout=) table...", + }, + func(s *script.State, args ...string) (script.WaitFunc, error) { + txn := db.ReadTxn() + allTbls := db.GetTables(txn) + tbls := allTbls + + flags := newCmdFlagSet() + timeout := flags.Duration("timeout", 5*time.Second, "Maximum amount of time to wait for the table contents to match") + if err := flags.Parse(args); err != nil { + return nil, fmt.Errorf("%w: %s", errUsage, err) + } + timeoutChan := time.After(*timeout) + args = flags.Args() + + if len(args) > 0 { + // Specific tables requested, look them up. + tbls = make([]TableMeta, 0, len(args)) + for _, tableName := range args { + found := false + for _, tbl := range allTbls { + if tableName == tbl.Name() { + tbls = append(tbls, tbl) + found = true + break + } + } + if !found { + return nil, fmt.Errorf("table %q not found", tableName) + } + } + } + + for _, tbl := range tbls { + init, watch := tbl.Initialized(txn) + if init { + s.Logf("%s initialized\n", tbl.Name()) + continue + } + s.Logf("Waiting for %s to initialize (%v)...\n", tbl.Name(), tbl.PendingInitializers(txn)) + select { + case <-s.Context().Done(): + return nil, s.Context().Err() + case <-timeoutChan: + return nil, fmt.Errorf("timed out") + case <-watch: + s.Logf("%s initialized\n", tbl.Name()) + } + } + return nil, nil + }, + ) +} + +func ShowCmd(db *DB) script.Cmd { + return script.Command( + script.CmdUsage{ + Summary: "Show table", + Args: "(-o=) (-columns=col1,...) (-format={table,yaml,json}) table", + }, + func(s *script.State, args ...string) (script.WaitFunc, error) { + flags := newCmdFlagSet() + file := flags.String("o", "", "File to write to instead of stdout") + columns := flags.String("columns", "", "Comma-separated list of columns to write") + format := flags.String("format", "table", "Format to write in (table, yaml, json)") + if err := flags.Parse(args); err != nil { + return nil, fmt.Errorf("%w: %s", errUsage, err) + } + + var cols []string + if len(*columns) > 0 { + cols = strings.Split(*columns, ",") + } + + args = flags.Args() + if len(args) < 1 { + return nil, fmt.Errorf("%w: missing table name", errUsage) + } + tableName := args[0] + return func(*script.State) (stdout, stderr string, err error) { + var buf strings.Builder + var w io.Writer + if *file == "" { + w = &buf + } else { + f, err := os.OpenFile(s.Path(*file), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + if err != nil { + return "", "", fmt.Errorf("OpenFile(%s): %w", *file, err) + } + defer f.Close() + w = f + } + tbl, txn, err := getTable(db, tableName) + if err != nil { + return "", "", err + } + err = writeObjects(tbl, tbl.All(txn), w, cols, *format) + return buf.String(), "", err + }, nil + }) +} + +func CompareCmd(db *DB) script.Cmd { + return script.Command( + script.CmdUsage{ + Summary: "Compare table", + Args: "table file (-timeout=) (-grep=)", + }, + func(s *script.State, args ...string) (script.WaitFunc, error) { + flags := newCmdFlagSet() + timeout := flags.Duration("timeout", time.Second, "Maximum amount of time to wait for the table contents to match") + grep := flags.String("grep", "", "Grep the result rows and only compare matching ones") + err := flags.Parse(args) + args = flags.Args() + if err != nil || len(args) != 2 { + return nil, fmt.Errorf("%w: %s", errUsage, err) + } + + var grepRe *regexp.Regexp + if *grep != "" { + grepRe, err = regexp.Compile(*grep) + if err != nil { + return nil, fmt.Errorf("bad grep: %w", err) + } + } + + tableName := args[0] + + txn := db.ReadTxn() + meta := db.GetTable(txn, tableName) + if meta == nil { + return nil, fmt.Errorf("table %q not found", tableName) + } + tbl := AnyTable{Meta: meta} + header := tbl.TableHeader() + + data, err := os.ReadFile(s.Path(args[1])) + if err != nil { + return nil, fmt.Errorf("ReadFile(%s): %w", args[1], err) + } + lines := strings.Split(string(data), "\n") + lines = slices.DeleteFunc(lines, func(line string) bool { + return strings.TrimSpace(line) == "" + }) + if len(lines) < 1 { + return nil, fmt.Errorf("%q missing header line, e.g. %q", args[1], strings.Join(header, " ")) + } + + columnNames, columnPositions := splitHeaderLine(lines[0]) + columnIndexes, err := getColumnIndexes(columnNames, header) + if err != nil { + return nil, err + } + lines = lines[1:] + origLines := lines + timeoutChan := time.After(*timeout) + + for { + lines = origLines + + // Create the diff between 'lines' and the rows in the table. + equal := true + var diff bytes.Buffer + w := tabwriter.NewWriter(&diff, 5, 4, 3, ' ', 0) + fmt.Fprintf(w, " %s\n", joinByPositions(columnNames, columnPositions)) + + objs, watch := tbl.AllWatch(db.ReadTxn()) + for obj := range objs { + rowRaw := takeColumns(obj.(TableWritable).TableRow(), columnIndexes) + row := joinByPositions(rowRaw, columnPositions) + if grepRe != nil && !grepRe.Match([]byte(row)) { + continue + } + + if len(lines) == 0 { + equal = false + fmt.Fprintf(w, "- %s\n", row) + continue + } + line := lines[0] + splitLine := splitByPositions(line, columnPositions) + + if slices.Equal(rowRaw, splitLine) { + fmt.Fprintf(w, " %s\n", row) + } else { + fmt.Fprintf(w, "- %s\n", row) + fmt.Fprintf(w, "+ %s\n", line) + equal = false + } + lines = lines[1:] + } + for _, line := range lines { + fmt.Fprintf(w, "+ %s\n", line) + equal = false + } + if equal { + return nil, nil + } + w.Flush() + + select { + case <-s.Context().Done(): + return nil, s.Context().Err() + + case <-timeoutChan: + return nil, fmt.Errorf("table mismatch:\n%s", diff.String()) + + case <-watch: + } + } + }) +} + +func InsertCmd(db *DB) script.Cmd { + return script.Command( + script.CmdUsage{ + Summary: "Insert object into a table", + Args: "table path...", + }, + func(s *script.State, args ...string) (script.WaitFunc, error) { + return insertOrDelete(true, db, s, args...) + }, + ) +} + +func DeleteCmd(db *DB) script.Cmd { + return script.Command( + script.CmdUsage{ + Summary: "Delete an object from the table", + Args: "table path...", + }, + func(s *script.State, args ...string) (script.WaitFunc, error) { + return insertOrDelete(false, db, s, args...) + }, + ) +} + +func getTable(db *DB, tableName string) (*AnyTable, ReadTxn, error) { + txn := db.ReadTxn() + meta := db.GetTable(txn, tableName) + if meta == nil { + return nil, nil, fmt.Errorf("table %q not found", tableName) + } + return &AnyTable{Meta: meta}, txn, nil +} + +func insertOrDelete(insert bool, db *DB, s *script.State, args ...string) (script.WaitFunc, error) { + if len(args) < 2 { + return nil, fmt.Errorf("%w: expected table and path(s)", errUsage) + } + + tbl, _, err := getTable(db, args[0]) + if err != nil { + return nil, err + } + + wtxn := db.WriteTxn(tbl.Meta) + defer wtxn.Commit() + + for _, arg := range args[1:] { + data, err := os.ReadFile(s.Path(arg)) + if err != nil { + return nil, fmt.Errorf("ReadFile(%s): %w", arg, err) + } + parts := strings.Split(string(data), "---") + for _, part := range parts { + obj, err := tbl.UnmarshalYAML([]byte(part)) + if err != nil { + return nil, fmt.Errorf("Unmarshal(%s): %w", arg, err) + } + if insert { + _, _, err = tbl.Insert(wtxn, obj) + if err != nil { + return nil, fmt.Errorf("Insert(%s): %w", arg, err) + } + } else { + _, _, err = tbl.Delete(wtxn, obj) + if err != nil { + return nil, fmt.Errorf("Delete(%s): %w", arg, err) + } + + } + } + } + return nil, nil +} + +func PrefixCmd(db *DB) script.Cmd { + return queryCmd(db, queryCmdPrefix, "Query table by prefix") +} + +func LowerBoundCmd(db *DB) script.Cmd { + return queryCmd(db, queryCmdLowerBound, "Query table by lower bound search") +} + +func ListCmd(db *DB) script.Cmd { + return queryCmd(db, queryCmdList, "List objects in the table") +} + +func GetCmd(db *DB) script.Cmd { + return queryCmd(db, queryCmdGet, "Get the first matching object") +} + +const ( + queryCmdList = iota + queryCmdPrefix + queryCmdLowerBound + queryCmdGet +) + +func queryCmd(db *DB, query int, summary string) script.Cmd { + return script.Command( + script.CmdUsage{ + Summary: summary, + Args: "(-o=) (-columns=col1,...) (-format={table*,yaml,json}) (-index=) table key", + }, + func(s *script.State, args ...string) (script.WaitFunc, error) { + return runQueryCmd(query, db, s, args) + }, + ) +} + +func runQueryCmd(query int, db *DB, s *script.State, args []string) (script.WaitFunc, error) { + flags := newCmdFlagSet() + file := flags.String("o", "", "File to write results to instead of stdout") + index := flags.String("index", "", "Index to query") + format := flags.String("format", "table", "Format to write in (table, yaml, json)") + columns := flags.String("columns", "", "Comma-separated list of columns to write") + delete := flags.Bool("delete", false, "Delete all matching objects") + if err := flags.Parse(args); err != nil { + return nil, fmt.Errorf("%w: %s", errUsage, err) + } + + var cols []string + if len(*columns) > 0 { + cols = strings.Split(*columns, ",") + } + + args = flags.Args() + if len(args) < 2 { + return nil, fmt.Errorf("%w: expected table and key", errUsage) + } + + return func(*script.State) (stdout, stderr string, err error) { + tbl, txn, err := getTable(db, args[0]) + if err != nil { + return "", "", err + } + + var buf strings.Builder + var w io.Writer + if *file == "" { + w = &buf + } else { + f, err := os.OpenFile(s.Path(*file), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + if err != nil { + return "", "", fmt.Errorf("OpenFile(%s): %s", *file, err) + } + defer f.Close() + w = f + } + + var it iter.Seq2[any, uint64] + switch query { + case queryCmdList: + it, err = tbl.List(txn, *index, args[1]) + case queryCmdLowerBound: + it, err = tbl.LowerBound(txn, *index, args[1]) + case queryCmdPrefix: + it, err = tbl.Prefix(txn, *index, args[1]) + case queryCmdGet: + it, err = tbl.List(txn, *index, args[1]) + if err == nil { + it = firstOfSeq2(it) + } + default: + panic("unknown query enum") + } + if err != nil { + return "", "", fmt.Errorf("query: %w", err) + } + + err = writeObjects(tbl, it, w, cols, *format) + if err != nil { + return "", "", err + } + + if *delete { + wtxn := db.WriteTxn(tbl.Meta) + count := 0 + for obj := range it { + _, hadOld, err := tbl.Delete(wtxn, obj) + if err != nil { + wtxn.Abort() + return "", "", err + } + if hadOld { + count++ + } + } + s.Logf("Deleted %d objects\n", count) + wtxn.Commit() + } + + return buf.String(), "", err + }, nil +} + +func firstOfSeq2[A, B any](it iter.Seq2[A, B]) iter.Seq2[A, B] { + return func(yield func(a A, b B) bool) { + for a, b := range it { + yield(a, b) + break + } + } +} + +func writeObjects(tbl *AnyTable, it iter.Seq2[any, Revision], w io.Writer, columns []string, format string) error { + if len(columns) > 0 && format != "table" { + return fmt.Errorf("-columns not supported with non-table formats") + } + switch format { + case "yaml": + sep := []byte("---\n") + first := true + for obj := range it { + if !first { + w.Write(sep) + } + first = false + + out, err := yaml.Marshal(obj) + if err != nil { + return fmt.Errorf("yaml.Marshal: %w", err) + } + if _, err := w.Write(out); err != nil { + return err + } + } + return nil + case "json": + sep := []byte("\n") + first := true + for obj := range it { + if !first { + w.Write(sep) + } + first = false + + out, err := json.Marshal(obj) + if err != nil { + return fmt.Errorf("json.Marshal: %w", err) + } + if _, err := w.Write(out); err != nil { + return err + } + } + return nil + case "table": + header := tbl.TableHeader() + if header == nil { + return fmt.Errorf("objects in table %q not TableWritable", tbl.Meta.Name()) + } + + var idxs []int + var err error + if len(columns) > 0 { + idxs, err = getColumnIndexes(columns, header) + header = columns + } else { + idxs, err = getColumnIndexes(header, header) + } + if err != nil { + return err + } + tw := tabwriter.NewWriter(w, 5, 4, 3, ' ', 0) + fmt.Fprintf(tw, "%s\n", strings.Join(header, "\t")) + + for obj := range it { + row := takeColumns(obj.(TableWritable).TableRow(), idxs) + fmt.Fprintf(tw, "%s\n", strings.Join(row, "\t")) + } + return tw.Flush() + } + return fmt.Errorf("unknown format %q, expected table, yaml or json", format) +} + +func takeColumns[T any](xs []T, idxs []int) []T { + // Invariant: idxs is sorted so can set in-place. + for i, idx := range idxs { + xs[i] = xs[idx] + } + return xs[:len(idxs)] +} + +func getColumnIndexes(names []string, header []string) ([]int, error) { + columnIndexes := make([]int, 0, len(header)) +loop: + for _, name := range names { + for i, name2 := range header { + if strings.EqualFold(name, name2) { + columnIndexes = append(columnIndexes, i) + continue loop + } + } + return nil, fmt.Errorf("column %q not part of %v", name, header) + } + return columnIndexes, nil +} + +// splitHeaderLine takes a header of column names separated by any +// number of whitespaces and returns the names and their starting positions. +// e.g. "Foo Bar Baz" would result in ([Foo,Bar,Baz],[0,5,9]). +// With this information we can take a row in the database and format it +// the same way as our test data. +func splitHeaderLine(line string) (names []string, pos []int) { + start := 0 + skip := true + for i, r := range line { + switch r { + case ' ', '\t': + if !skip { + names = append(names, line[start:i]) + pos = append(pos, start) + start = -1 + } + skip = true + default: + skip = false + if start == -1 { + start = i + } + } + } + if start >= 0 && start < len(line) { + names = append(names, line[start:]) + pos = append(pos, start) + } + return +} + +// splitByPositions takes a "row" line and the positions of the header columns +// and extracts the values. +// e.g. if we have the positions [0,5,9] (from header "Foo Bar Baz") and +// line is "1 a b", then we'd extract [1,a,b]. +// The whitespace on the right of the start position (e.g. "1 \t") is trimmed. +// This of course requires that the table is properly formatted in a way that the +// header columns are indented to fit the data exactly. +func splitByPositions(line string, positions []int) []string { + out := make([]string, 0, len(positions)) + start := 0 + for _, pos := range positions[1:] { + if start >= len(line) { + out = append(out, "") + start = len(line) + continue + } + out = append(out, strings.TrimRight(line[start:min(pos, len(line))], " \t")) + start = pos + } + out = append(out, strings.TrimRight(line[min(start, len(line)):], " \t")) + return out +} + +// joinByPositions is the reverse of splitByPositions, it takes the columns of a +// row and the starting positions of each and joins into a single line. +// e.g. [1,a,b] and positions [0,5,9] expands to "1 a b". +// NOTE: This does not deal well with mixing tabs and spaces. The test input +// data should preferably just use spaces. +func joinByPositions(row []string, positions []int) string { + var w strings.Builder + prev := 0 + for i, pos := range positions { + for pad := pos - prev; pad > 0; pad-- { + w.WriteByte(' ') + } + w.WriteString(row[i]) + prev = pos + len(row[i]) + } + return w.String() +} diff --git a/script_test.go b/script_test.go new file mode 100644 index 0000000..12e87e9 --- /dev/null +++ b/script_test.go @@ -0,0 +1,42 @@ +package statedb + +import ( + "context" + "maps" + "testing" + + "github.com/cilium/hive" + "github.com/cilium/hive/cell" + "github.com/cilium/hive/hivetest" + "github.com/cilium/hive/script" + "github.com/cilium/hive/script/scripttest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScript(t *testing.T) { + log := hivetest.Logger(t) + h := hive.New( + Cell, // DB + cell.Invoke(func(db *DB) { + t1 := newTestObjectTable(t, "test1", tagsIndex) + require.NoError(t, db.RegisterTable(t1), "RegisterTable") + t2 := newTestObjectTable(t, "test2", tagsIndex) + require.NoError(t, db.RegisterTable(t2), "RegisterTable") + }), + ) + t.Cleanup(func() { + assert.NoError(t, h.Stop(log, context.TODO())) + }) + cmds, err := h.ScriptCommands(log) + require.NoError(t, err, "ScriptCommands") + maps.Insert(cmds, maps.All(script.DefaultCmds())) + engine := &script.Engine{ + Cmds: cmds, + } + scripttest.Test(t, + context.Background(), func() *script.Engine { + return engine + }, []string{}, "testdata/*.txtar") + +} diff --git a/table.go b/table.go index 3b838aa..1d75f07 100644 --- a/table.go +++ b/table.go @@ -9,8 +9,10 @@ import ( "regexp" "runtime" "slices" + "sort" "strings" "sync" + "sync/atomic" "github.com/cilium/statedb/internal" "github.com/cilium/statedb/part" @@ -46,7 +48,8 @@ func NewTable[Obj any]( fromObject: func(iobj object) index.KeySet { return idx.fromObject(iobj.data.(Obj)) }, - unique: idx.isUnique(), + fromString: idx.fromString, + unique: idx.isUnique(), } } @@ -129,6 +132,15 @@ type genTable[Obj any] struct { primaryAnyIndexer anyIndexer secondaryAnyIndexers map[string]anyIndexer indexPositions map[string]int + lastWriteTxn atomic.Pointer[txn] +} + +func (t *genTable[Obj]) acquired(txn *txn) { + t.lastWriteTxn.Store(txn) +} + +func (t *genTable[Obj]) getAcquiredInfo() string { + return t.lastWriteTxn.Load().acquiredInfo() } func (t *genTable[Obj]) tableEntry() tableEntry { @@ -169,6 +181,16 @@ func (t *genTable[Obj]) indexPos(name string) int { return t.indexPositions[name] } +func (t *genTable[Obj]) getIndexer(name string) *anyIndexer { + if name == "" || t.primaryAnyIndexer.name == name { + return &t.primaryAnyIndexer + } + if indexer, ok := t.secondaryAnyIndexers[name]; ok { + return &indexer + } + return nil +} + func (t *genTable[Obj]) PrimaryIndexer() Indexer[Obj] { return t.primaryIndexer } @@ -191,6 +213,7 @@ func (t *genTable[Obj]) Indexes() []string { for k := range t.secondaryAnyIndexers { idxs = append(idxs, k) } + sort.Strings(idxs) return idxs } @@ -243,6 +266,11 @@ func (t *genTable[Obj]) NumObjects(txn ReadTxn) int { return table.numObjects() } +func (t *genTable[Obj]) numDeletedObjects(txn ReadTxn) int { + table := txn.getTxn().getTableEntry(t) + return table.numDeletedObjects() +} + func (t *genTable[Obj]) Get(txn ReadTxn, q Query[Obj]) (obj Obj, revision uint64, ok bool) { obj, revision, _, ok = t.GetWatch(txn, q) return diff --git a/testdata/db.txtar b/testdata/db.txtar new file mode 100644 index 0000000..675ba02 --- /dev/null +++ b/testdata/db.txtar @@ -0,0 +1,178 @@ +# +# This file is invoked by 'script_test.go' and tests the StateDB script commands +# defined in 'script.go'. +# + +hive start + +# Show the registered tables +db tables + +# Initialized +db initialized +db initialized test1 +db initialized test1 test2 + +# Show (empty) +db show test1 +db show test2 + +# Insert +db insert test1 obj1.yaml +db insert test1 obj2.yaml +db insert test2 obj2.yaml + +# Show (non-empty) +db show test1 +grep ^ID.*Tags +grep 1.*bar +grep 2.*baz +db show test2 + +db show -format=table test1 +grep ^ID.*Tags +grep 1.*bar +grep 2.*baz + +db show -format=table -columns=Tags test1 +grep ^Tags$ +grep '^bar, foo$' +grep '^baz, foo$' + +db show -format=json test1 +grep ID.:1.*bar +grep ID.:2.*baz + +db show -format=yaml test1 +grep 'id: 1' +grep 'id: 2' + +db show -format=yaml -o=test1_export.yaml test1 +cmp test1.yaml test1_export.yaml + +# Get +db get test2 2 +db get -format=table test2 2 +grep '^ID.*Tags$' +grep ^2.*baz +db get -format=table -columns=Tags test2 2 +grep ^Tags$ +grep '^baz, foo$' +db get -format=json test2 2 +db get -format=yaml test2 2 +db get -format=yaml -o=obj2_get.yaml test2 2 +cmp obj2.yaml obj2_get.yaml + +db get -index=tags -format=yaml -o=obj1_get.yaml test1 bar +cmp obj1.yaml obj1_get.yaml + +# List +db list -o=list.table test1 1 +cmp obj1.table list.table +db list -o=list.table test1 2 +cmp obj2.table list.table + +db list -o=list.table -index=tags test1 bar +cmp obj1.table list.table +db list -o=list.table -index=tags test1 baz +cmp obj2.table list.table +db list -o=list.table -index=tags test1 foo +cmp objs.table list.table + +db list -format=table -index=tags -columns=Tags test1 foo +grep ^Tags$ +grep '^bar, foo$' +grep '^baz, foo$' + +# Prefix +# uint64 so can't really prefix search meaningfully, unless +# FromString() accomodates partial keys. +db prefix test1 1 + +db prefix -o=prefix.table -index=tags test1 ba +cmp objs.table prefix.table + +# LowerBound +db lowerbound -o=lb.table test1 0 +cmp objs.table lb.table +db lowerbound -o=lb.table test1 1 +cmp objs.table lb.table +db lowerbound -o=lb.table test1 2 +cmp obj2.table lb.table +db lowerbound -o=lb.table test1 3 +cmp empty.table lb.table + +# Compare +db cmp test1 objs.table +db cmp test1 objs_ids.table +db cmp -grep=bar test1 obj1.table +db cmp -grep=baz test1 obj2.table + +# Delete +db delete test1 obj1.yaml +db cmp test1 obj2.table + +db delete test1 obj2.yaml +db cmp test1 empty.table + +# Delete with get +db insert test1 obj1.yaml +db cmp test1 obj1.table +db get -delete test1 1 +db cmp test1 empty.table + +# Delete with prefix +db insert test1 obj1.yaml +db insert test1 obj2.yaml +db cmp test1 objs.table +db prefix -index=tags -delete test1 fo +db cmp test1 empty.table + +# Delete with lowerbound +db insert test1 obj1.yaml +db insert test1 obj2.yaml +db cmp test1 objs.table +db lowerbound -index=id -delete test1 2 +db cmp test1 obj1.table + +# Tables +db tables + +# --------------------- + +-- obj1.yaml -- +id: 1 +tags: + - bar + - foo +-- obj2.yaml -- +id: 2 +tags: + - baz + - foo +-- test1.yaml -- +id: 1 +tags: + - bar + - foo +--- +id: 2 +tags: + - baz + - foo +-- objs.table -- +ID Tags +1 bar, foo +2 baz, foo +-- objs_ids.table -- +ID +1 +2 +-- obj1.table -- +ID Tags +1 bar, foo +-- obj2.table -- +ID Tags +2 baz, foo +-- empty.table -- +ID Tags diff --git a/testutils/script.go b/testutils/script.go deleted file mode 100644 index 0120979..0000000 --- a/testutils/script.go +++ /dev/null @@ -1,494 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright Authors of Cilium - -package testutils - -import ( - "bytes" - "cmp" - "encoding/json" - "flag" - "fmt" - "iter" - "maps" - "os" - "regexp" - "slices" - "strings" - "text/tabwriter" - "time" - - "github.com/cilium/statedb" - "github.com/rogpeppe/go-internal/testscript" - "gopkg.in/yaml.v3" -) - -type Cmd = func(ts *testscript.TestScript, neg bool, args []string) - -const tsDBKey = "statedb" - -func Setup(e *testscript.Env, db *statedb.DB) { - e.Values[tsDBKey] = db -} - -func getDB(ts *testscript.TestScript) *statedb.DB { - v := ts.Value(tsDBKey) - if v == nil { - ts.Fatalf("%q not set, call testutils.Setup()", tsDBKey) - } - return v.(*statedb.DB) -} - -func getTable(ts *testscript.TestScript, tableName string) (*statedb.DB, statedb.ReadTxn, statedb.AnyTable) { - db := getDB(ts) - txn := db.ReadTxn() - meta := db.GetTable(txn, tableName) - if meta == nil { - ts.Fatalf("table %q not found", tableName) - } - tbl := statedb.AnyTable{Meta: meta} - return db, txn, tbl -} - -var ( - Commands = map[string]Cmd{ - "db": DBCmd, - } - SubCommands = map[string]Cmd{ - "tables": TablesCmd, - "show": ShowTableCmd, - "write": WriteTableCmd, - "cmp": CompareTableCmd, - "insert": InsertCmd, - "delete": DeleteCmd, - "prefix": PrefixCmd, - "lowerbound": LowerBoundCmd, - } -) - -func DBCmd(ts *testscript.TestScript, neg bool, args []string) { - if len(args) < 1 { - ts.Fatalf("usage: db args...\n is one of %v", maps.Keys(SubCommands)) - } - if cmd, ok := SubCommands[args[0]]; ok { - cmd(ts, neg, args[1:]) - } else { - ts.Fatalf("unknown db command %q, should be one of %v", args[0], maps.Keys(SubCommands)) - } -} - -func TablesCmd(ts *testscript.TestScript, neg bool, args []string) { - db := getDB(ts) - txn := db.ReadTxn() - tbls := db.GetTables(txn) - var buf bytes.Buffer - w := tabwriter.NewWriter(&buf, 5, 4, 3, ' ', 0) - fmt.Fprintf(w, "Name\tObject count\tIndexes\n") - for _, tbl := range tbls { - idxs := strings.Join(tbl.Indexes(), ", ") - fmt.Fprintf(w, "%s\t%d\t%s\n", tbl.Name(), tbl.NumObjects(txn), idxs) - } - w.Flush() - ts.Logf("%s", buf.String()) -} - -func ShowTableCmd(ts *testscript.TestScript, neg bool, args []string) { - if len(args) != 1 { - ts.Fatalf("usage: show_table ") - } - ts.Logf("%s", showTable(ts, args[0]).String()) -} - -func WriteTableCmd(ts *testscript.TestScript, neg bool, args []string) { - if len(args) < 1 || len(args) > 5 { - ts.Fatalf("usage: write_table
(-to=) (-columns=) (-format={table*,yaml})") - } - var flags flag.FlagSet - file := flags.String("to", "", "File to write to instead of stdout") - columns := flags.String("columns", "", "Comma-separated list of columns to write") - format := flags.String("format", "table", "Format to write in") - - // Sort the args to allow the table name at any position. - slices.SortFunc(args, func(a, b string) int { - switch { - case a[0] == '-': - return 1 - case b[0] == '-': - return -1 - default: - return cmp.Compare(a, b) - } - }) - - if err := flags.Parse(args[1:]); err != nil { - ts.Fatalf("bad args: %s", err) - } - tableName := args[0] - - switch *format { - case "yaml", "json": - if len(*columns) > 0 { - ts.Fatalf("-columns not supported with -format=yaml/json") - } - - _, txn, tbl := getTable(ts, tableName) - var buf bytes.Buffer - count := tbl.Meta.NumObjects(txn) - for obj := range tbl.All(txn) { - if *format == "yaml" { - out, err := yaml.Marshal(obj) - if err != nil { - ts.Fatalf("yaml.Marshal: %s", err) - } - buf.Write(out) - if count > 1 { - buf.WriteString("---\n") - } - } else { - out, err := json.Marshal(obj) - if err != nil { - ts.Fatalf("json.Marshal: %s", err) - } - buf.Write(out) - buf.WriteByte('\n') - } - count-- - } - if *file == "" { - ts.Logf("%s", buf.String()) - } else if err := os.WriteFile(ts.MkAbs(*file), buf.Bytes(), 0644); err != nil { - ts.Fatalf("WriteFile(%s): %s", *file, err) - } - default: - var cols []string - if len(*columns) > 0 { - cols = strings.Split(*columns, ",") - } - buf := showTable(ts, tableName, cols...) - if *file == "" { - ts.Logf("%s", buf.String()) - } else if err := os.WriteFile(ts.MkAbs(*file), buf.Bytes(), 0644); err != nil { - ts.Fatalf("WriteFile(%s): %s", *file, err) - } - } -} - -func CompareTableCmd(ts *testscript.TestScript, neg bool, args []string) { - var flags flag.FlagSet - timeout := flags.Duration("timeout", time.Second, "Maximum amount of time to wait for the table contents to match") - grep := flags.String("grep", "", "Grep the result rows and only compare matching ones") - - err := flags.Parse(args) - args = args[len(args)-flags.NArg():] - if err != nil || len(args) != 2 { - ts.Fatalf("usage: cmp (-timeout=) (-grep=)
") - } - - var grepRe *regexp.Regexp - if *grep != "" { - grepRe, err = regexp.Compile(*grep) - if err != nil { - ts.Fatalf("bad grep: %s", err) - } - } - - tableName := args[0] - db, _, tbl := getTable(ts, tableName) - header := tbl.TableHeader() - - data := ts.ReadFile(args[1]) - lines := strings.Split(data, "\n") - lines = slices.DeleteFunc(lines, func(line string) bool { - return strings.TrimSpace(line) == "" - }) - if len(lines) < 1 { - ts.Fatalf("%q missing header line, e.g. %q", args[1], strings.Join(header, " ")) - } - - columnNames, columnPositions := splitHeaderLine(lines[0]) - columnIndexes, err := getColumnIndexes(columnNames, header) - if err != nil { - ts.Fatalf("%s", err) - } - lines = lines[1:] - origLines := lines - tryUntil := time.Now().Add(*timeout) - - for { - lines = origLines - - // Create the diff between 'lines' and the rows in the table. - equal := true - var diff bytes.Buffer - w := tabwriter.NewWriter(&diff, 5, 4, 3, ' ', 0) - fmt.Fprintf(w, " %s\n", joinByPositions(columnNames, columnPositions)) - - for obj := range tbl.All(db.ReadTxn()) { - rowRaw := takeColumns(obj.(statedb.TableWritable).TableRow(), columnIndexes) - row := joinByPositions(rowRaw, columnPositions) - if grepRe != nil && !grepRe.Match([]byte(row)) { - continue - } - - if len(lines) == 0 { - equal = false - fmt.Fprintf(w, "- %s\n", row) - continue - } - line := lines[0] - splitLine := splitByPositions(line, columnPositions) - - if slices.Equal(rowRaw, splitLine) { - fmt.Fprintf(w, " %s\n", row) - } else { - fmt.Fprintf(w, "- %s\n", row) - fmt.Fprintf(w, "+ %s\n", line) - equal = false - } - lines = lines[1:] - } - for _, line := range lines { - fmt.Fprintf(w, "+ %s\n", line) - equal = false - } - if equal { - return - } - w.Flush() - - if time.Now().After(tryUntil) { - ts.Fatalf("table mismatch:\n%s", diff.String()) - } - time.Sleep(10 * time.Millisecond) - } -} - -func InsertCmd(ts *testscript.TestScript, neg bool, args []string) { - insertOrDeleteCmd(ts, true, args) -} - -func DeleteCmd(ts *testscript.TestScript, neg bool, args []string) { - insertOrDeleteCmd(ts, false, args) -} - -func insertOrDeleteCmd(ts *testscript.TestScript, insert bool, args []string) { - if len(args) < 2 { - if insert { - ts.Fatalf("usage: insert
path...") - } else { - ts.Fatalf("usage: delete
path...") - } - } - - db, _, tbl := getTable(ts, args[0]) - wtxn := db.WriteTxn(tbl.Meta) - defer wtxn.Commit() - - for _, arg := range args[1:] { - data := ts.ReadFile(arg) - parts := strings.Split(data, "---") - for _, part := range parts { - obj, err := tbl.UnmarshalYAML([]byte(part)) - if err != nil { - ts.Fatalf("Unmarshal(%s): %s", arg, err) - } - if insert { - _, _, err = tbl.Insert(wtxn, obj) - if err != nil { - ts.Fatalf("Insert(%s): %s", arg, err) - } - } else { - _, _, err = tbl.Delete(wtxn, obj) - if err != nil { - ts.Fatalf("Delete(%s): %s", arg, err) - } - - } - } - } -} - -func PrefixCmd(ts *testscript.TestScript, neg bool, args []string) { - prefixOrLowerboundCmd(ts, false, args) -} - -func LowerBoundCmd(ts *testscript.TestScript, neg bool, args []string) { - prefixOrLowerboundCmd(ts, true, args) -} - -func prefixOrLowerboundCmd(ts *testscript.TestScript, lowerbound bool, args []string) { - db := getDB(ts) - if len(args) < 2 { - if lowerbound { - ts.Fatalf("usage: lowerbound
(-to=)") - } else { - ts.Fatalf("usage: prefix
(-to=)") - } - } - - var flags flag.FlagSet - file := flags.String("to", "", "File to write to instead of stdout") - if err := flags.Parse(args[2:]); err != nil { - ts.Fatalf("bad args: %s", err) - } - - txn := db.ReadTxn() - meta := db.GetTable(txn, args[0]) - if meta == nil { - ts.Fatalf("table %q not found", args[0]) - } - tbl := statedb.AnyTable{Meta: meta} - var buf bytes.Buffer - w := tabwriter.NewWriter(&buf, 5, 4, 3, ' ', 0) - header := tbl.TableHeader() - fmt.Fprintf(w, "%s\n", strings.Join(header, "\t")) - - var it iter.Seq2[any, uint64] - if lowerbound { - it = tbl.LowerBound(txn, args[1]) - } else { - it = tbl.Prefix(txn, args[1]) - } - - for obj := range it { - row := obj.(statedb.TableWritable).TableRow() - fmt.Fprintf(w, "%s\n", strings.Join(row, "\t")) - } - w.Flush() - if *file == "" { - ts.Logf("%s", buf.String()) - } else if err := os.WriteFile(ts.MkAbs(*file), buf.Bytes(), 0644); err != nil { - ts.Fatalf("WriteFile(%s): %s", *file, err) - } -} - -// splitHeaderLine takes a header of column names separated by any -// number of whitespaces and returns the names and their starting positions. -// e.g. "Foo Bar Baz" would result in ([Foo,Bar,Baz],[0,5,9]). -// With this information we can take a row in the database and format it -// the same way as our test data. -func splitHeaderLine(line string) (names []string, pos []int) { - start := 0 - skip := true - for i, r := range line { - switch r { - case ' ', '\t': - if !skip { - names = append(names, line[start:i]) - pos = append(pos, start) - start = -1 - } - skip = true - default: - skip = false - if start == -1 { - start = i - } - } - } - if start >= 0 && start < len(line) { - names = append(names, line[start:]) - pos = append(pos, start) - } - return -} - -// splitByPositions takes a "row" line and the positions of the header columns -// and extracts the values. -// e.g. if we have the positions [0,5,9] (from header "Foo Bar Baz") and -// line is "1 a b", then we'd extract [1,a,b]. -// The whitespace on the right of the start position (e.g. "1 \t") is trimmed. -// This of course requires that the table is properly formatted in a way that the -// header columns are indented to fit the data exactly. -func splitByPositions(line string, positions []int) []string { - out := make([]string, 0, len(positions)) - start := 0 - for _, pos := range positions[1:] { - if start >= len(line) { - out = append(out, "") - start = len(line) - continue - } - out = append(out, strings.TrimRight(line[start:min(pos, len(line))], " \t")) - start = pos - } - out = append(out, strings.TrimRight(line[min(start, len(line)):], " \t")) - return out -} - -// joinByPositions is the reverse of splitByPositions, it takes the columns of a -// row and the starting positions of each and joins into a single line. -// e.g. [1,a,b] and positions [0,5,9] expands to "1 a b". -// NOTE: This does not deal well with mixing tabs and spaces. The test input -// data should preferably just use spaces. -func joinByPositions(row []string, positions []int) string { - var w strings.Builder - prev := 0 - for i, pos := range positions { - for pad := pos - prev; pad > 0; pad-- { - w.WriteByte(' ') - } - w.WriteString(row[i]) - prev = pos + len(row[i]) - } - return w.String() -} - -func showTable(ts *testscript.TestScript, tableName string, columns ...string) *bytes.Buffer { - db := getDB(ts) - txn := db.ReadTxn() - meta := db.GetTable(txn, tableName) - if meta == nil { - ts.Fatalf("table %q not found", tableName) - } - tbl := statedb.AnyTable{Meta: meta} - - header := tbl.TableHeader() - if header == nil { - ts.Fatalf("objects in table %q not TableWritable", meta.Name()) - } - var idxs []int - var err error - if len(columns) > 0 { - idxs, err = getColumnIndexes(columns, header) - header = columns - } else { - idxs, err = getColumnIndexes(header, header) - } - if err != nil { - ts.Fatalf("%s", err) - } - - var buf bytes.Buffer - w := tabwriter.NewWriter(&buf, 5, 4, 3, ' ', 0) - fmt.Fprintf(w, "%s\n", strings.Join(header, "\t")) - for obj := range tbl.All(db.ReadTxn()) { - row := takeColumns(obj.(statedb.TableWritable).TableRow(), idxs) - fmt.Fprintf(w, "%s\n", strings.Join(row, "\t")) - } - w.Flush() - return &buf -} - -func takeColumns[T any](xs []T, idxs []int) []T { - // Invariant: idxs is sorted so can set in-place. - for i, idx := range idxs { - xs[i] = xs[idx] - } - return xs[:len(idxs)] -} - -func getColumnIndexes(names []string, header []string) ([]int, error) { - columnIndexes := make([]int, 0, len(header)) -loop: - for _, name := range names { - for i, name2 := range header { - if strings.EqualFold(name, name2) { - columnIndexes = append(columnIndexes, i) - continue loop - } - } - return nil, fmt.Errorf("column %q not part of %v", name, header) - } - return columnIndexes, nil -} diff --git a/testutils/script_test.go b/testutils/script_test.go deleted file mode 100644 index a99106f..0000000 --- a/testutils/script_test.go +++ /dev/null @@ -1,62 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright Authors of Cilium - -package testutils_test - -import ( - "flag" - "strings" - "testing" - - "github.com/cilium/statedb" - "github.com/cilium/statedb/index" - "github.com/cilium/statedb/testutils" - "github.com/rogpeppe/go-internal/testscript" -) - -type object struct { - Name string - Tags []string -} - -func (o object) TableHeader() []string { - return []string{"Name", "Tags"} -} - -func (o object) TableRow() []string { - return []string{ - o.Name, - strings.Join(o.Tags, ", "), - } -} - -var nameIdx = statedb.Index[object, string]{ - Name: "name", - FromObject: func(obj object) index.KeySet { - return index.NewKeySet(index.String(obj.Name)) - }, - FromKey: index.String, - Unique: true, -} - -var update = flag.Bool("update", false, "update the txtar files") - -func TestScriptCommands(t *testing.T) { - testscript.Run(t, testscript.Params{ - Dir: "testdata", - Setup: func(e *testscript.Env) error { - db := statedb.New() - tbl, err := statedb.NewTable("names", nameIdx) - if err != nil { - t.Fatalf("NewTable: %s", err) - } - if err := db.RegisterTable(tbl); err != nil { - t.Fatalf("RegisterTable: %s", err) - } - testutils.Setup(e, db) - return nil - }, - Cmds: testutils.Commands, - UpdateScripts: *update, - }) -} diff --git a/testutils/testdata/test.txtar b/testutils/testdata/test.txtar deleted file mode 100644 index e5e4e9e..0000000 --- a/testutils/testdata/test.txtar +++ /dev/null @@ -1,87 +0,0 @@ -db tables -db show names - -db insert names data.yaml -db show names - -# Compare the contents of a table -db cmp names names.table - -# Compare against subset of the columns -db cmp names names_name.table - -# Compare the table with retries up to 10s (1s is default) -db cmp -timeout=10s names names.table - -# Compare only rows that match the grep pattern -db cmp -grep=^baz names baz.table - -# Write the table to a file with specific columns -db write names -to=out.table -columns=Name,Tags - -# Use the plain 'cmp'. You'll want to use 'UpdateScript' -# to create and update the expected output. -cmp out.table out_expected.table - -# Write the table to a file as yaml -db write names -to=out.yaml -format=yaml -cmp out.yaml out_expected.yaml - -# Prefix search the table with the primary key. Only useful -# for stringy primary keys. -db prefix names q -db prefix names ba -to=out_prefix_ba.table - -# LowerBound searches -db lowerbound names a -to=out_lb_a.table -cmp out_lb_a.table out_expected.table -db lowerbound names z -to=out_lb_z.table -cmp out_lb_z.table empty.table - -# Delete and check that it's empty. -db delete names quux-name.yaml -db cmp names baz.table -db cmp names out_prefix_ba.table - -db delete names data.yaml -db cmp names empty.table - --- data.yaml -- -name: quux -tags: -- foo -- bar ---- -name: baz - --- quux-name.yaml -- -name: quux - --- names.table -- -Name Tags -baz -quux foo, bar - --- names_name.table -- -Name -baz -quux - --- baz.table -- -Name -baz - --- empty.table -- -Name Tags --- out_expected.table -- -Name Tags -baz -quux foo, bar --- out_expected.yaml -- -name: baz -tags: [] ---- -name: quux -tags: - - foo - - bar diff --git a/txn.go b/txn.go index f5115d8..4a7bc1c 100644 --- a/txn.go +++ b/txn.go @@ -12,6 +12,7 @@ import ( "reflect" "runtime" "slices" + "sync/atomic" "time" "github.com/cilium/statedb/index" @@ -20,12 +21,18 @@ import ( ) type txn struct { - db *DB - handle string - root dbRoot + db *DB + root dbRoot + + handle string + acquiredAt time.Time // the time at which the transaction acquired the locks + duration atomic.Uint64 // the transaction duration after it finished + writeTxn +} + +type writeTxn struct { modifiedTables []*tableEntry // table entries being modified smus internal.SortableMutexes // the (sorted) table locks - acquiredAt time.Time // the time at which the transaction acquired the locks tableNames []string } @@ -46,6 +53,23 @@ func (txn *txn) getTxn() *txn { return txn } +// acquiredInfo returns the information for the "Last WriteTxn" column +// in "db tables" command. The correctness of this relies on the following assumptions: +// - txn.handle and txn.acquiredAt are not modified +// - txn.duration is atomically updated on Commit or Abort +func (txn *txn) acquiredInfo() string { + if txn == nil { + return "" + } + since := internal.PrettySince(txn.acquiredAt) + dur := time.Duration(txn.duration.Load()) + if txn.duration.Load() == 0 { + // Still locked + return fmt.Sprintf("%s (locked for %s)", txn.handle, since) + } + return fmt.Sprintf("%s (%s ago, locked for %s)", txn.handle, since, internal.PrettyDuration(dur)) +} + // txnFinalizer is called when the GC frees *txn. It checks that a WriteTxn // has been Aborted or Committed. This is a safeguard against forgetting to // Abort/Commit which would cause the table to be locked forever. @@ -402,7 +426,7 @@ func decodeNonUniqueKey(key []byte) (secondary []byte, encPrimary []byte) { func (txn *txn) Abort() { runtime.SetFinalizer(txn, nil) - // If writeTxns is nil, this transaction has already been committed or aborted, and + // If modifiedTables is nil, this transaction has already been committed or aborted, and // thus there is nothing to do. We allow this without failure to allow for defer // pattern: // @@ -421,13 +445,15 @@ func (txn *txn) Abort() { return } + txn.duration.Store(uint64(time.Since(txn.acquiredAt))) + txn.smus.Unlock() txn.db.metrics.WriteTxnDuration( txn.handle, txn.tableNames, time.Since(txn.acquiredAt)) - *txn = zeroTxn + txn.writeTxn = writeTxn{} } // Commit the transaction. Returns a ReadTxn that is the snapshot of the database at the @@ -459,6 +485,8 @@ func (txn *txn) Commit() ReadTxn { return nil } + txn.duration.Store(uint64(time.Since(txn.acquiredAt))) + db := txn.db // Commit each individual changed index to each table. @@ -514,6 +542,7 @@ func (txn *txn) Commit() ReadTxn { // Commit the transaction to build the new root tree and then // atomically store it. + txn.root = root db.root.Store(&root) db.mu.Unlock() @@ -536,11 +565,8 @@ func (txn *txn) Commit() ReadTxn { txn.tableNames, time.Since(txn.acquiredAt)) - // Zero out the transaction to make it inert and - // convert it into a ReadTxn. - *txn = zeroTxn - txn.db = db - txn.root = root + // Convert into a ReadTxn + txn.writeTxn = writeTxn{} return txn } diff --git a/types.go b/types.go index ac28faa..5492e64 100644 --- a/types.go +++ b/types.go @@ -4,6 +4,7 @@ package statedb import ( + "errors" "io" "iter" @@ -230,13 +231,17 @@ type tableInternal interface { tablePos() int setTablePos(int) indexPos(string) int - tableKey() []byte // The radix key for the table in the root tree + tableKey() []byte // The radix key for the table in the root tree + getIndexer(name string) *anyIndexer primary() anyIndexer // The untyped primary indexer for the table secondary() map[string]anyIndexer // Secondary indexers (if any) sortableMutex() internal.SortableMutex // The sortable mutex for locking the table for writing anyChanges(txn WriteTxn) (anyChangeIterator, error) proto() any // Returns the zero value of 'Obj', e.g. the prototype unmarshalYAML(data []byte) (any, error) // Unmarshal the data into 'Obj' + numDeletedObjects(txn ReadTxn) int // Number of objects in graveyard + acquired(*txn) + getAcquiredInfo() string } type ReadTxn interface { @@ -282,10 +287,26 @@ func ByRevision[Obj any](rev uint64) Query[Obj] { // Index implements the indexing of objects (FromObjects) and querying of objects from the index (FromKey) type Index[Obj any, Key any] struct { - Name string + // Name of the index + Name string + + // FromObject extracts key(s) from the object. The key set + // can contain 0, 1 or more keys. FromObject func(obj Obj) index.KeySet - FromKey func(key Key) index.Key - Unique bool + + // FromKey converts the index key into a raw key. + // With this we can perform Query() against this index with + // the [Key] type. + FromKey func(key Key) index.Key + + // FromString is an optional conversion from string to a raw key. + // If implemented allows script commands to query with this index. + FromString func(key string) (index.Key, error) + + // Unique marks the index as unique. Primary index must always be + // unique. A secondary index may be non-unique in which case a single + // key may map to multiple objects. + Unique bool } var _ Indexer[struct{}] = &Index[struct{}, bool]{} @@ -303,6 +324,17 @@ func (i Index[Obj, Key]) fromObject(obj Obj) index.KeySet { return i.FromObject(obj) } +var errFromStringNil = errors.New("FromString not defined") + +//nolint:unused +func (i Index[Obj, Key]) fromString(s string) (index.Key, error) { + if i.FromString == nil { + return index.Key{}, errFromStringNil + } + k, err := i.FromString(s) + return k, err +} + //nolint:unused func (i Index[Obj, Key]) isUnique() bool { return i.Unique @@ -333,6 +365,7 @@ type Indexer[Obj any] interface { indexName() string isUnique() bool fromObject(Obj) index.KeySet + fromString(string) (index.Key, error) ObjectToKey(Obj) index.Key QueryFromObject(Obj) Query[Obj] @@ -383,6 +416,9 @@ type anyIndexer struct { // object with. fromObject func(object) index.KeySet + // fromString converts string into a key. Optional. + fromString func(string) (index.Key, error) + // unique if true will index the object solely on the // values returned by fromObject. If false the primary // key of the object will be appended to the key.