diff --git a/flag.go b/flag.go index ac14b816..55594df4 100644 --- a/flag.go +++ b/flag.go @@ -120,9 +120,9 @@ const ( PanicOnError ) -// normalizedName is a flag name that has been normalized according to rules +// NormalizedName is a flag name that has been normalized according to rules // for the FlagSet (e.g. making '-' and '_' equivalent). -type normalizedName string +type NormalizedName string // A FlagSet represents a set of defined flags. type FlagSet struct { @@ -131,17 +131,17 @@ type FlagSet struct { // a custom error handler. Usage func() - name string - parsed bool - actual map[normalizedName]*Flag - formal map[normalizedName]*Flag - shorthands map[byte]*Flag - args []string // arguments after flags - exitOnError bool // does the program exit if there's an error? - errorHandling ErrorHandling - output io.Writer // nil means stderr; use out() accessor - interspersed bool // allow interspersed option/non-option args - wordSeparators []string + name string + parsed bool + actual map[NormalizedName]*Flag + formal map[NormalizedName]*Flag + shorthands map[byte]*Flag + args []string // arguments after flags + exitOnError bool // does the program exit if there's an error? + errorHandling ErrorHandling + output io.Writer // nil means stderr; use out() accessor + interspersed bool // allow interspersed option/non-option args + normalizeNameFunc func(f *FlagSet, name string) NormalizedName } // A Flag represents the state of a flag. @@ -165,7 +165,7 @@ type Value interface { } // sortFlags returns the flags as a slice in lexicographical sorted order. -func sortFlags(flags map[normalizedName]*Flag) []*Flag { +func sortFlags(flags map[NormalizedName]*Flag) []*Flag { list := make(sort.StringSlice, len(flags)) i := 0 for k := range flags { @@ -175,18 +175,29 @@ func sortFlags(flags map[normalizedName]*Flag) []*Flag { list.Sort() result := make([]*Flag, len(list)) for i, name := range list { - result[i] = flags[normalizedName(name)] + result[i] = flags[NormalizedName(name)] } return result } -func (f *FlagSet) normalizeFlagName(name string) normalizedName { - result := name - for _, sep := range f.wordSeparators { - result = strings.Replace(result, sep, "-", -1) +func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) { + f.normalizeNameFunc = n + for k, v := range f.formal { + delete(f.formal, k) + f.formal[f.normalizeFlagName(string(k))] = v } - // Type convert to indicate normalization has been done. - return normalizedName(result) +} + +func (f *FlagSet) GetNormalizeFunc() func(f *FlagSet, name string) NormalizedName { + if f.normalizeNameFunc != nil { + return f.normalizeNameFunc + } + return func(f *FlagSet, name string) NormalizedName { return NormalizedName(name) } +} + +func (f *FlagSet) normalizeFlagName(name string) NormalizedName { + n := f.GetNormalizeFunc() + return n(f, name) } func (f *FlagSet) out() io.Writer { @@ -240,7 +251,7 @@ func (f *FlagSet) Lookup(name string) *Flag { } // lookup returns the Flag structure of the named flag, returning nil if none exists. -func (f *FlagSet) lookup(name normalizedName) *Flag { +func (f *FlagSet) lookup(name NormalizedName) *Flag { return f.formal[name] } @@ -272,7 +283,7 @@ func (f *FlagSet) Set(name, value string) error { return err } if f.actual == nil { - f.actual = make(map[normalizedName]*Flag) + f.actual = make(map[NormalizedName]*Flag) } f.actual[normalName] = flag flag.Changed = true @@ -417,7 +428,7 @@ func (f *FlagSet) AddFlag(flag *Flag) { panic(msg) // Happens only if flags are declared with identical names } if f.formal == nil { - f.formal = make(map[normalizedName]*Flag) + f.formal = make(map[NormalizedName]*Flag) } f.formal[f.normalizeFlagName(flag.Name)] = flag @@ -482,7 +493,7 @@ func (f *FlagSet) setFlag(flag *Flag, value string, origArg string) error { } // mark as visited for Visit() if f.actual == nil { - f.actual = make(map[normalizedName]*Flag) + f.actual = make(map[NormalizedName]*Flag) } f.actual[f.normalizeFlagName(flag.Name)] = flag flag.Changed = true @@ -648,19 +659,6 @@ func SetInterspersed(interspersed bool) { CommandLine.SetInterspersed(interspersed) } -// SetWordSeparators sets a list of strings to be considerered as word -// separators and normalized for the pruposes of lookups. For example, if this -// is set to {"-", "_", "."} then --foo_bar, --foo-bar, and --foo.bar are -// considered equivalent flags. This must be called before flags are parsed, -// and may only be called once. -func (f *FlagSet) SetWordSeparators(separators []string) { - f.wordSeparators = separators - for k, v := range f.formal { - delete(f.formal, k) - f.formal[f.normalizeFlagName(string(k))] = v - } -} - // Parsed returns true if the command-line flags have been parsed. func Parsed() bool { return CommandLine.Parsed() diff --git a/flag_test.go b/flag_test.go index a1478e26..f552a2f5 100644 --- a/flag_test.go +++ b/flag_test.go @@ -239,14 +239,29 @@ func TestFlagSetParse(t *testing.T) { testParse(NewFlagSet("test", ContinueOnError), t) } -func testNormalizedNames(args []string, t *testing.T) { +func replaceSeparators(name string, from []string, to string) string { + result := name + for _, sep := range from { + result = strings.Replace(result, sep, to, -1) + } + // Type convert to indicate normalization has been done. + return result +} + +func wordSepNormalizeFunc(f *FlagSet, name string) NormalizedName { + seps := []string{"-", "_"} + name = replaceSeparators(name, seps, ".") + return NormalizedName(name) +} + +func testWordSepNormalizedNames(args []string, t *testing.T) { f := NewFlagSet("normalized", ContinueOnError) if f.Parsed() { t.Error("f.Parse() = true before Parse") } withDashFlag := f.Bool("with-dash-flag", false, "bool value") // Set this after some flags have been added and before others. - f.SetWordSeparators([]string{"-", "_"}) + f.SetNormalizeFunc(wordSepNormalizeFunc) withUnderFlag := f.Bool("with_under_flag", false, "bool value") withBothFlag := f.Bool("with-both_flag", false, "bool value") if err := f.Parse(args); err != nil { @@ -266,27 +281,66 @@ func testNormalizedNames(args []string, t *testing.T) { } } -func TestNormalizedNames(t *testing.T) { +func TestWordSepNormalizedNames(t *testing.T) { args := []string{ "--with-dash-flag", "--with-under-flag", "--with-both-flag", } - testNormalizedNames(args, t) + testWordSepNormalizedNames(args, t) args = []string{ "--with_dash_flag", "--with_under_flag", "--with_both_flag", } - testNormalizedNames(args, t) + testWordSepNormalizedNames(args, t) args = []string{ "--with-dash_flag", "--with-under_flag", "--with-both_flag", } - testNormalizedNames(args, t) + testWordSepNormalizedNames(args, t) +} + +func aliasAndWordSepFlagNames(f *FlagSet, name string) NormalizedName { + seps := []string{"-", "_"} + + oldName := replaceSeparators("old-valid_flag", seps, ".") + newName := replaceSeparators("valid-flag", seps, ".") + + name = replaceSeparators(name, seps, ".") + switch name { + case oldName: + name = newName + break + } + + return NormalizedName(name) +} + +func TestCustomNormalizedNames(t *testing.T) { + f := NewFlagSet("normalized", ContinueOnError) + if f.Parsed() { + t.Error("f.Parse() = true before Parse") + } + + validFlag := f.Bool("valid-flag", false, "bool value") + f.SetNormalizeFunc(aliasAndWordSepFlagNames) + someOtherFlag := f.Bool("some-other-flag", false, "bool value") + + args := []string{"--old_valid_flag", "--some-other_flag"} + if err := f.Parse(args); err != nil { + t.Fatal(err) + } + + if *validFlag != true { + t.Errorf("validFlag is %v even though we set the alias --old_valid_falg", *validFlag) + } + if *someOtherFlag != true { + t.Error("someOtherFlag should be true, is ", *someOtherFlag) + } } // Declare a user-defined flag type. @@ -503,7 +557,7 @@ func TestDeprecatedFlagUsage(t *testing.T) { func TestDeprecatedFlagUsageNormalized(t *testing.T) { f := NewFlagSet("bob", ContinueOnError) f.Bool("bad-double_flag", true, "always true") - f.SetWordSeparators([]string{"-", "_"}) + f.SetNormalizeFunc(wordSepNormalizeFunc) usageMsg := "use --good-flag instead" f.MarkDeprecated("bad_double-flag", usageMsg)