diff --git a/flagutil.go b/flagutil.go index 5988f40..3cbb6a7 100644 --- a/flagutil.go +++ b/flagutil.go @@ -549,7 +549,71 @@ func SetActual(fs *flag.FlagSet, name string) { } fs.Set(name, "dummy") if !didSet { - panic("flagutil: make specified didn't work well") + panic("flagutil: set actual didn't work well") + } +} + +// LinkFlags links flags named as n0 and n1 in existing flag set fs. +// If any of the flags doesn't exist LinkFlags() will create one. +func LinkFlags(fs *flag.FlagSet, n0, n1 string) { + var ( + u0 string + u1 string + v0 flag.Value + v1 flag.Value + ) + f0 := fs.Lookup(n0) + if f0 != nil { + v0 = f0.Value + u0 = f0.Usage + } + f1 := fs.Lookup(n1) + if f1 != nil { + v1 = f1.Value + u1 = f1.Usage + } + + usage := mergeUsage(n0+","+n1, u0, u1) + + v := value{ + doSet: func(s string) (err error) { + if err == nil && v0 != nil { + err = v0.Set(s) + } + if err == nil && v1 != nil { + err = v1.Set(s) + } + return err + }, + doIsBoolFlag: func() bool { + if v0 == nil || v1 == nil { + // Can't guess in advance. + return false + } + return isBoolValue(v0) && isBoolValue(v1) + }, + doString: func() string { + if v0 == nil || v1 == nil { + // Can't guess in advance. + return "" + } + s0 := v0.String() + s1 := v1.String() + if s0 == s1 { + return s0 + } + return "" + }, + } + if f0 != nil { + f0.Value = v + } else { + fs.Var(v, n0, usage) + } + if f1 != nil { + f1.Value = v + } else { + fs.Var(v, n1, usage) } } diff --git a/flagutil_test.go b/flagutil_test.go index 18082d4..3f62e6f 100644 --- a/flagutil_test.go +++ b/flagutil_test.go @@ -448,6 +448,46 @@ func TestCombineFlags(t *testing.T) { } } +func TestLinkFlags(t *testing.T) { + for _, test := range []struct { + name string + flags [2]flag.Flag + }{ + { + name: "basic", + flags: [2]flag.Flag{ + stringFlag("foo", "def#0", "desc#0"), + stringFlag("bar", "def#1", "desc#1"), + }, + }, + } { + for i := 0; i < 2; i++ { + t.Run(test.name, func(t *testing.T) { + fs := flag.NewFlagSet("", flag.PanicOnError) + for _, f := range test.flags { + fs.Var(f.Value, f.Name, f.Usage) + } + LinkFlags(fs, + test.flags[0].Name, + test.flags[1].Name, + ) + + exp := fmt.Sprintf("%x", rand.Int63()) + fs.Set(test.flags[i].Name, exp) + + for _, f := range test.flags { + if act := f.Value.String(); act != exp { + t.Errorf( + "unexpected flag %q value: %s; want %s", + f.Name, act, exp, + ) + } + } + }) + } + } +} + func stringFlag(name, def, desc string) flag.Flag { fs := flag.NewFlagSet("", flag.PanicOnError) fs.String(name, def, desc)