diff --git a/part/map.go b/part/map.go index 9548a14..70d0b07 100644 --- a/part/map.go +++ b/part/map.go @@ -9,6 +9,8 @@ import ( "fmt" "iter" "reflect" + + "gopkg.in/yaml.v3" ) // Map of key-value pairs. The zero value is ready for use, provided @@ -22,8 +24,8 @@ type Map[K, V any] struct { } type mapKVPair[K, V any] struct { - Key K `json:"k"` - Value V `json:"v"` + Key K `json:"k" yaml:"k"` + Value V `json:"v" yaml:"v"` } // FromMap copies values from the hash map into the given Map. @@ -238,3 +240,29 @@ func (m *Map[K, V]) UnmarshalJSON(data []byte) error { m.tree = txn.CommitOnly() return nil } + +func (m Map[K, V]) MarshalYAML() (any, error) { + kvs := make([]mapKVPair[K, V], 0, m.Len()) + iter := m.tree.Iterator() + for _, kv, ok := iter.Next(); ok; _, kv, ok = iter.Next() { + kvs = append(kvs, kv) + } + return kvs, nil +} + +func (m *Map[K, V]) UnmarshalYAML(value *yaml.Node) error { + if value.Kind != yaml.SequenceNode { + return fmt.Errorf("%T.UnmarshalYAML: expected sequence", m) + } + m.ensureTree() + txn := m.tree.Txn() + for _, e := range value.Content { + var kv mapKVPair[K, V] + if err := e.Decode(&kv); err != nil { + return err + } + txn.Insert(m.bytesFromKey(kv.Key), mapKVPair[K, V]{kv.Key, kv.Value}) + } + m.tree = txn.CommitOnly() + return nil +} diff --git a/part/map_test.go b/part/map_test.go index be2b917..4c61c38 100644 --- a/part/map_test.go +++ b/part/map_test.go @@ -5,6 +5,7 @@ package part_test import ( "encoding/json" + "fmt" "iter" "math/rand/v2" "testing" @@ -12,6 +13,7 @@ import ( "github.com/cilium/statedb/part" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) func TestStringMap(t *testing.T) { @@ -172,6 +174,39 @@ func TestMapJSON(t *testing.T) { require.True(t, m.SlowEqual(m2), "SlowEqual") } +func TestMapYAMLStringKey(t *testing.T) { + var m part.Map[string, int] + m = m.Set("foo", 1).Set("bar", 2).Set("baz", 3) + + bs, err := yaml.Marshal(m) + require.NoError(t, err, "Marshal") + + var m2 part.Map[string, int] + err = yaml.Unmarshal(bs, &m2) + require.NoError(t, err, "Unmarshal") + require.True(t, m.SlowEqual(m2), "SlowEqual") +} + +func TestMapYAMLStructKey(t *testing.T) { + type key struct { + A int `yaml:"a"` + B string `yaml:"b"` + } + part.RegisterKeyType[key](func(k key) []byte { + return []byte(fmt.Sprintf("%d-%s", k.A, k.B)) + }) + var m part.Map[key, int] + m = m.Set(key{1, "one"}, 1).Set(key{2, "two"}, 2).Set(key{3, "three"}, 3) + + bs, err := yaml.Marshal(m) + require.NoError(t, err, "Marshal") + + var m2 part.Map[key, int] + err = yaml.Unmarshal(bs, &m2) + require.NoError(t, err, "Unmarshal") + require.True(t, m.SlowEqual(m2), "SlowEqual") +} + func Benchmark_Uint64Map_Random(b *testing.B) { numItems := 1000 keys := map[uint64]int{} diff --git a/part/set.go b/part/set.go index 8c677bc..89a91f0 100644 --- a/part/set.go +++ b/part/set.go @@ -8,6 +8,9 @@ import ( "encoding/json" "fmt" "iter" + "slices" + + "gopkg.in/yaml.v3" ) // Set is a persistent (immutable) set of values. A Set can be @@ -208,6 +211,32 @@ func (s *Set[T]) UnmarshalJSON(data []byte) error { return nil } +func (s Set[T]) MarshalYAML() (any, error) { + // TODO: Once yaml.v3 supports iter.Seq, drop the Collect(). + return slices.Collect(s.All()), nil +} + +func (s *Set[T]) UnmarshalYAML(value *yaml.Node) error { + if value.Kind != yaml.SequenceNode { + return fmt.Errorf("%T.UnmarshalYAML: expected sequence", s) + } + + if s.tree == nil { + *s = NewSet[T]() + } + txn := s.tree.Txn() + + for _, e := range value.Content { + var v T + if err := e.Decode(&v); err != nil { + return err + } + txn.Insert(s.toBytes(v), v) + } + s.tree = txn.CommitOnly() + return nil +} + func toSeq[T any](iter *Iterator[T]) iter.Seq[T] { return func(yield func(T) bool) { if iter == nil { diff --git a/part/set_test.go b/part/set_test.go index 4ca853b..eff33b5 100644 --- a/part/set_test.go +++ b/part/set_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + package part_test import ( @@ -8,6 +11,7 @@ import ( "github.com/cilium/statedb/part" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) func TestStringSet(t *testing.T) { @@ -65,3 +69,15 @@ func TestSetJSON(t *testing.T) { require.NoError(t, err, "Unmarshal") require.True(t, s.Equal(s2), "Equal") } + +func TestSetYAML(t *testing.T) { + s := part.NewSet("foo", "bar", "baz") + + bs, err := yaml.Marshal(s) + require.NoError(t, err, "Marshal") + + var s2 part.Set[string] + err = yaml.Unmarshal(bs, &s2) + require.NoError(t, err, "Unmarshal") + require.True(t, s.Equal(s2), "Equal") +}