Skip to content

Commit

Permalink
part: Add YAML support to Set and Map
Browse files Browse the repository at this point in the history
To allow using part.Set[] and part.Map[] types with the YAML marshalling/unmarshalling
in the StateDB script commands, add the custom YAML marshalling and unmarshalling methods
for these types.

Signed-off-by: Jussi Maki <[email protected]>
  • Loading branch information
joamaki committed Oct 10, 2024
1 parent 0f24f74 commit 70dced9
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 2 deletions.
32 changes: 30 additions & 2 deletions part/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"fmt"
"iter"
"reflect"

"gopkg.in/yaml.v3"
)

// Map of key-value pairs. The zero value is ready for use, provided
Expand All @@ -22,8 +24,8 @@ type Map[K, V any] struct {
}

type mapKVPair[K, V any] struct {
Key K `json:"k"`
Value V `json:"v"`
Key K `json:"k" yaml:"k"`
Value V `json:"v" yaml:"v"`
}

// FromMap copies values from the hash map into the given Map.
Expand Down Expand Up @@ -238,3 +240,29 @@ func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
m.tree = txn.CommitOnly()
return nil
}

func (m Map[K, V]) MarshalYAML() (any, error) {
kvs := make([]mapKVPair[K, V], 0, m.Len())
iter := m.tree.Iterator()
for _, kv, ok := iter.Next(); ok; _, kv, ok = iter.Next() {
kvs = append(kvs, kv)
}
return kvs, nil
}

func (m *Map[K, V]) UnmarshalYAML(value *yaml.Node) error {
if value.Kind != yaml.SequenceNode {
return fmt.Errorf("%T.UnmarshalYAML: expected sequence", m)
}
m.ensureTree()
txn := m.tree.Txn()
for _, e := range value.Content {
var kv mapKVPair[K, V]
if err := e.Decode(&kv); err != nil {
return err
}
txn.Insert(m.bytesFromKey(kv.Key), mapKVPair[K, V]{kv.Key, kv.Value})
}
m.tree = txn.CommitOnly()
return nil
}
35 changes: 35 additions & 0 deletions part/map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ package part_test

import (
"encoding/json"
"fmt"
"iter"
"math/rand/v2"
"testing"

"github.com/cilium/statedb/part"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)

func TestStringMap(t *testing.T) {
Expand Down Expand Up @@ -172,6 +174,39 @@ func TestMapJSON(t *testing.T) {
require.True(t, m.SlowEqual(m2), "SlowEqual")
}

func TestMapYAMLStringKey(t *testing.T) {
var m part.Map[string, int]
m = m.Set("foo", 1).Set("bar", 2).Set("baz", 3)

bs, err := yaml.Marshal(m)
require.NoError(t, err, "Marshal")

var m2 part.Map[string, int]
err = yaml.Unmarshal(bs, &m2)
require.NoError(t, err, "Unmarshal")
require.True(t, m.SlowEqual(m2), "SlowEqual")
}

func TestMapYAMLStructKey(t *testing.T) {
type key struct {
A int `yaml:"a"`
B string `yaml:"b"`
}
part.RegisterKeyType[key](func(k key) []byte {
return []byte(fmt.Sprintf("%d-%s", k.A, k.B))
})
var m part.Map[key, int]
m = m.Set(key{1, "one"}, 1).Set(key{2, "two"}, 2).Set(key{3, "three"}, 3)

bs, err := yaml.Marshal(m)
require.NoError(t, err, "Marshal")

var m2 part.Map[key, int]
err = yaml.Unmarshal(bs, &m2)
require.NoError(t, err, "Unmarshal")
require.True(t, m.SlowEqual(m2), "SlowEqual")
}

func Benchmark_Uint64Map_Random(b *testing.B) {
numItems := 1000
keys := map[uint64]int{}
Expand Down
29 changes: 29 additions & 0 deletions part/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import (
"encoding/json"
"fmt"
"iter"
"slices"

"gopkg.in/yaml.v3"
)

// Set is a persistent (immutable) set of values. A Set can be
Expand Down Expand Up @@ -208,6 +211,32 @@ func (s *Set[T]) UnmarshalJSON(data []byte) error {
return nil
}

func (s Set[T]) MarshalYAML() (any, error) {
// TODO: Once yaml.v3 supports iter.Seq, drop the Collect().
return slices.Collect(s.All()), nil
}

func (s *Set[T]) UnmarshalYAML(value *yaml.Node) error {
if value.Kind != yaml.SequenceNode {
return fmt.Errorf("%T.UnmarshalYAML: expected sequence", s)
}

if s.tree == nil {
*s = NewSet[T]()
}
txn := s.tree.Txn()

for _, e := range value.Content {
var v T
if err := e.Decode(&v); err != nil {
return err
}
txn.Insert(s.toBytes(v), v)
}
s.tree = txn.CommitOnly()
return nil
}

func toSeq[T any](iter *Iterator[T]) iter.Seq[T] {
return func(yield func(T) bool) {
if iter == nil {
Expand Down
16 changes: 16 additions & 0 deletions part/set_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

package part_test

import (
Expand All @@ -8,6 +11,7 @@ import (
"github.com/cilium/statedb/part"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)

func TestStringSet(t *testing.T) {
Expand Down Expand Up @@ -65,3 +69,15 @@ func TestSetJSON(t *testing.T) {
require.NoError(t, err, "Unmarshal")
require.True(t, s.Equal(s2), "Equal")
}

func TestSetYAML(t *testing.T) {
s := part.NewSet("foo", "bar", "baz")

bs, err := yaml.Marshal(s)
require.NoError(t, err, "Marshal")

var s2 part.Set[string]
err = yaml.Unmarshal(bs, &s2)
require.NoError(t, err, "Unmarshal")
require.True(t, s.Equal(s2), "Equal")
}

0 comments on commit 70dced9

Please sign in to comment.