From 1461afea1ad084b5b77b951ea835d06c7b0979fa Mon Sep 17 00:00:00 2001 From: Sergey Kamardin Date: Tue, 21 Sep 2021 13:35:20 +0300 Subject: [PATCH] flagutil: make LinkFlags() to be unidirectional LinkFlag() --- flagutil.go | 89 ++++++++++++++++-------------------------------- flagutil_test.go | 88 +++++++++++++++++++++++++---------------------- 2 files changed, 78 insertions(+), 99 deletions(-) diff --git a/flagutil.go b/flagutil.go index 3cbb6a7..cf02b3c 100644 --- a/flagutil.go +++ b/flagutil.go @@ -553,68 +553,39 @@ func SetActual(fs *flag.FlagSet, name string) { } } -// LinkFlags links flags named as n0 and n1 in existing flag set fs. -// If any of the flags doesn't exist LinkFlags() will create one. -func LinkFlags(fs *flag.FlagSet, n0, n1 string) { - var ( - u0 string - u1 string - v0 flag.Value - v1 flag.Value - ) - f0 := fs.Lookup(n0) - if f0 != nil { - v0 = f0.Value - u0 = f0.Usage +// LinkFlag links dst to be updated when src value is set. +// It panics if any of the given names doesn't exist in fs. +// +// Note that it caches the both src and dst flag.Value pointers internally, so +// it is possible to link src to dst and dst to src without infinite recursion. +// However, if any of the src or dst flag value get overwritten after +// LinkFlag() call, created link will not work properly anymore. +func LinkFlag(fs *flag.FlagSet, src, dst string) { + srcFlag := fs.Lookup(src) + if srcFlag == nil { + panic(fmt.Sprintf( + "flagutil: link flag: source flag %q must exist", + src, + )) } - f1 := fs.Lookup(n1) - if f1 != nil { - v1 = f1.Value - u1 = f1.Usage + dstFlag := fs.Lookup(dst) + if dstFlag == nil { + panic(fmt.Sprintf( + "flagutil: link flag: destination flag %q must exist", + src, + )) } - - usage := mergeUsage(n0+","+n1, u0, u1) - - v := value{ - doSet: func(s string) (err error) { - if err == nil && v0 != nil { - err = v0.Set(s) - } - if err == nil && v1 != nil { - err = v1.Set(s) - } + var ( + srcValue = srcFlag.Value + dstValue = dstFlag.Value + ) + srcFlag.Value = OverrideSet(srcFlag.Value, func(s string) error { + err := srcValue.Set(s) + if err != nil { return err - }, - doIsBoolFlag: func() bool { - if v0 == nil || v1 == nil { - // Can't guess in advance. - return false - } - return isBoolValue(v0) && isBoolValue(v1) - }, - doString: func() string { - if v0 == nil || v1 == nil { - // Can't guess in advance. - return "" - } - s0 := v0.String() - s1 := v1.String() - if s0 == s1 { - return s0 - } - return "" - }, - } - if f0 != nil { - f0.Value = v - } else { - fs.Var(v, n0, usage) - } - if f1 != nil { - f1.Value = v - } else { - fs.Var(v, n1, usage) - } + } + return dstValue.Set(s) + }) } func mergeUsage(name, s0, s1 string) string { diff --git a/flagutil_test.go b/flagutil_test.go index 3f62e6f..db4efd5 100644 --- a/flagutil_test.go +++ b/flagutil_test.go @@ -101,11 +101,11 @@ func TestUnquoteUsage(t *testing.T) { type expMode map[UnquoteUsageMode][2]string for _, test := range []struct { name string - flag flag.Flag + flag *flag.Flag modes expMode }{ { - flag: flag.Flag{ + flag: &flag.Flag{ Usage: "foo `bar` baz", }, modes: expMode{ @@ -149,7 +149,7 @@ func TestUnquoteUsage(t *testing.T) { t.Run(test.name, func(t *testing.T) { for mode, exp := range test.modes { t.Run(mode.String(), func(t *testing.T) { - actName, actUsage := unquoteUsage(mode, &test.flag) + actName, actUsage := unquoteUsage(mode, test.flag) if expName := exp[0]; actName != expName { t.Errorf("unexpected name:\n%s", cmp.Diff(expName, actName)) } @@ -362,13 +362,13 @@ func isActual(fs *flag.FlagSet, name string) (actual bool) { func TestCombineFlags(t *testing.T) { for _, test := range []struct { name string - flags [2]flag.Flag - exp flag.Flag + flags [2]*flag.Flag + exp *flag.Flag panic bool }{ { name: "different names", - flags: [2]flag.Flag{ + flags: [2]*flag.Flag{ stringFlag("foo", "def", "desc#0"), stringFlag("bar", "def", "desc#1"), }, @@ -376,7 +376,7 @@ func TestCombineFlags(t *testing.T) { }, { name: "different default values", - flags: [2]flag.Flag{ + flags: [2]*flag.Flag{ stringFlag("foo", "def#0", "desc#0"), stringFlag("foo", "def#1", "desc#1"), }, @@ -384,7 +384,7 @@ func TestCombineFlags(t *testing.T) { }, { name: "basic", - flags: [2]flag.Flag{ + flags: [2]*flag.Flag{ stringFlag("foo", "def", "desc#0"), stringFlag("foo", "def", "desc#1"), }, @@ -392,7 +392,7 @@ func TestCombineFlags(t *testing.T) { }, { name: "basic", - flags: [2]flag.Flag{ + flags: [2]*flag.Flag{ stringFlag("foo", "def", "desc#0"), stringFlag("foo", "", "desc#1"), }, @@ -414,7 +414,7 @@ func TestCombineFlags(t *testing.T) { } }() done <- flagOrPanic{ - flag: CombineFlags(&test.flags[0], &test.flags[1]), + flag: CombineFlags(test.flags[0], test.flags[1]), } }() x := <-done @@ -432,7 +432,7 @@ func TestCombineFlags(t *testing.T) { return v.String() }), } - if act, exp := x.flag, &test.exp; !cmp.Equal(act, exp, opts...) { + if act, exp := x.flag, test.exp; !cmp.Equal(act, exp, opts...) { t.Errorf("unexpected flag:\n%s", cmp.Diff(exp, act, opts...)) } exp := fmt.Sprintf("%x", rand.Int63()) @@ -440,57 +440,65 @@ func TestCombineFlags(t *testing.T) { t.Fatalf("unexpected Set() error: %v", err) } for _, f := range test.flags { - if act := f.Value.String(); act != exp { - t.Errorf("unexpected flag value: %s; want %s", act, exp) - } + assertEquals(t, f, exp) } }) } } -func TestLinkFlags(t *testing.T) { +func TestLinkFlag(t *testing.T) { for _, test := range []struct { name string - flags [2]flag.Flag + flags [2]*flag.Flag + links [2]string }{ { name: "basic", - flags: [2]flag.Flag{ + flags: [2]*flag.Flag{ stringFlag("foo", "def#0", "desc#0"), stringFlag("bar", "def#1", "desc#1"), }, + links: [2]string{"foo", "bar"}, }, } { - for i := 0; i < 2; i++ { - t.Run(test.name, func(t *testing.T) { - fs := flag.NewFlagSet("", flag.PanicOnError) - for _, f := range test.flags { + t.Run(test.name, func(t *testing.T) { + fs := flag.NewFlagSet("", flag.PanicOnError) + for _, f := range test.flags { + if f != nil { fs.Var(f.Value, f.Name, f.Usage) } - LinkFlags(fs, - test.flags[0].Name, - test.flags[1].Name, - ) - - exp := fmt.Sprintf("%x", rand.Int63()) - fs.Set(test.flags[i].Name, exp) - - for _, f := range test.flags { - if act := f.Value.String(); act != exp { - t.Errorf( - "unexpected flag %q value: %s; want %s", - f.Name, act, exp, - ) - } + } + LinkFlag(fs, test.links[0], test.links[1]) + + // First, test that setting for src flag affects dst flag. + exp := fmt.Sprintf("%x", rand.Int63()) + fs.Set(test.links[0], exp) + for _, n := range test.links { + if f := fs.Lookup(n); f != nil { + assertEquals(t, f, exp) } - }) - } + } + // Second, test that setting dst flag doesn't affect src flag. + nonExp := fmt.Sprintf("%x", rand.Int63()) + fs.Set(test.flags[1].Name, nonExp) + assertEquals(t, test.flags[0], exp) // Still the same. + assertEquals(t, test.flags[1], nonExp) // Updated. + }) + } +} + +func assertEquals(t *testing.T, f *flag.Flag, exp string) { + if act := f.Value.String(); act != exp { + t.Errorf( + "unexpected flag %q value: %s; want %s", + f.Name, act, exp, + ) } } -func stringFlag(name, def, desc string) flag.Flag { +func stringFlag(name, def, desc string) *flag.Flag { fs := flag.NewFlagSet("", flag.PanicOnError) fs.String(name, def, desc) f := fs.Lookup(name) - return *f + return f }