From 589bfe58c995f262786c7c020acce2094a345dea Mon Sep 17 00:00:00 2001 From: Sergey Kamardin Date: Tue, 9 Jun 2020 11:47:19 +0300 Subject: [PATCH] ability to print usage per parser (#1) --- README.md | 12 +- flagutil.go | 317 ++++++++++++++++++++++++++++++++--- flagutil_test.go | 148 ++++++++++++++++ options.go | 119 +++++++++++++ parse/args/args.go | 6 + parse/env/env.go | 16 +- parse/env/env_test.go | 3 +- parse/file/file.go | 38 +++-- parse/file/file_test.go | 17 +- parse/file/toml/toml_test.go | 1 + parse/file/yaml/yaml_test.go | 3 +- parse/flagset.go | 41 ++++- parse/pargs/posix.go | 83 +++++++++ parse/pargs/posix_test.go | 63 ++++++- parse/testutil/testutil.go | 3 +- 15 files changed, 807 insertions(+), 63 deletions(-) create mode 100644 flagutil_test.go create mode 100644 options.go diff --git a/README.md b/README.md index 11ba286..f18b24d 100644 --- a/README.md +++ b/README.md @@ -68,10 +68,14 @@ func main() { // Then lookup for "config" flag value and try to parse its value as a // json configuration file. - flagutil.WithParser(&file.Parser{ - Lookup: file.FlagLookup("config"), - Syntax: &json.Syntax{}, - }), + flagutil.WithParser( + &file.Parser{ + Lookup: file.LookupFlag(flags, "config"), + Syntax: new(json.Syntax), + }, + // Don't allow to setup "config" flag from file. + flagutil.WithIgnoreByName("config"), + ), ) // Work with received values. diff --git a/flagutil.go b/flagutil.go index c176e66..8f040f7 100644 --- a/flagutil.go +++ b/flagutil.go @@ -1,9 +1,13 @@ package flagutil import ( + "bytes" "flag" "fmt" "os" + "reflect" + "strings" + "time" "github.com/gobwas/flagutil/parse" ) @@ -14,38 +18,56 @@ type Parser interface { Parse(parse.FlagSet) error } -type ParseOption func(*config) +type ParserFunc func(parse.FlagSet) error -func WithParser(p Parser) ParseOption { - return func(c *config) { - c.parsers = append(c.parsers, p) - } +func (fn ParserFunc) Parse(fs parse.FlagSet) error { + return fn(fs) } -func WithIgnoreUndefined() ParseOption { - return func(c *config) { - c.ignoreUndefined = true - } +type Printer interface { + Name(parse.FlagSet) func(*flag.Flag, func(string)) +} + +type PrinterFunc func(parse.FlagSet) func(*flag.Flag, func(string)) + +func (fn PrinterFunc) Name(fs parse.FlagSet) func(*flag.Flag, func(string)) { + return fn(fs) +} + +type parser struct { + Parser + ignore func(*flag.Flag) bool } type config struct { - parsers []Parser - ignoreUndefined bool + parsers []*parser + ignoreUndefined bool + ignoreUsage bool + unquoteUsageMode UnquoteUsageMode } -func Parse(flags *flag.FlagSet, opts ...ParseOption) (err error) { - var c config +func Parse(flags *flag.FlagSet, opts ...Option) (err error) { + c := config{ + unquoteUsageMode: UnquoteDefault, + } for _, opt := range opts { - opt(&c) + switch x := opt.(type) { + case ParseOption: + x.setupParseConfig(&c) + case PrintOption: + x.setupPrintConfig(&c) + } } fs := parse.NewFlagSet(flags, parse.WithIgnoreUndefined(c.ignoreUndefined), ) for _, p := range c.parsers { parse.NextLevel(fs) + parse.Ignore(fs, p.ignore) + if err = p.Parse(fs); err != nil { if err == flag.ErrHelp { - printUsage(flags) + printUsage(&c, flags) } switch flags.ErrorHandling() { case flag.ContinueOnError: @@ -63,17 +85,261 @@ func Parse(flags *flag.FlagSet, opts ...ParseOption) (err error) { return nil } -func printUsage(f *flag.FlagSet) { - if f.Usage != nil { - f.Usage() +// PrintDefaults prints parsers aware usage message to flags.Output(). +func PrintDefaults(flags *flag.FlagSet, opts ...PrintOption) { + c := config{ + unquoteUsageMode: UnquoteDefault, + } + for _, opt := range opts { + opt.setupPrintConfig(&c) + } + printDefaults(&c, flags) +} + +func printUsage(c *config, flags *flag.FlagSet) { + if !c.ignoreUsage && flags.Usage != nil { + flags.Usage() return } - if name := f.Name(); name == "" { - fmt.Fprintf(f.Output(), "Usage:\n") + if name := flags.Name(); name == "" { + fmt.Fprintf(flags.Output(), "Usage:\n") } else { - fmt.Fprintf(f.Output(), "Usage of %s:\n", name) + fmt.Fprintf(flags.Output(), "Usage of %s:\n", name) } - f.PrintDefaults() + printDefaults(c, flags) +} + +type UnquoteUsageMode uint8 + +const ( + UnquoteNothing UnquoteUsageMode = 1 << iota >> 1 + UnquoteQuoted + UnquoteInferType + UnquoteClean + + UnquoteDefault UnquoteUsageMode = UnquoteQuoted | UnquoteInferType +) + +func (m UnquoteUsageMode) String() string { + switch m { + case UnquoteNothing: + return "UnquoteNothing" + case UnquoteQuoted: + return "UnquoteQuoted" + case UnquoteInferType: + return "UnquoteInferType" + case UnquoteClean: + return "UnquoteClean" + case UnquoteDefault: + return "UnquoteDefault" + default: + return "" + } +} + +func (m UnquoteUsageMode) has(x UnquoteUsageMode) bool { + return m&x != 0 +} + +func printDefaults(c *config, flags *flag.FlagSet) { + fs := parse.NewFlagSet(flags) + + var hasNameFunc bool + nameFunc := make([]func(*flag.Flag, func(string)), len(c.parsers)) + for i := len(c.parsers) - 1; i >= 0; i-- { + if p, ok := c.parsers[i].Parser.(Printer); ok { + hasNameFunc = true + nameFunc[i] = p.Name(fs) + } + } + + var buf bytes.Buffer + flags.VisitAll(func(f *flag.Flag) { + n, _ := buf.WriteString(" ") + for i := len(c.parsers) - 1; i >= 0; i-- { + fn := nameFunc[i] + if fn == nil { + continue + } + if ignore := c.parsers[i].ignore; ignore != nil && ignore(f) { + continue + } + fn(f, func(name string) { + if buf.Len() > n { + buf.WriteString(", ") + } + buf.WriteString(name) + }) + } + if buf.Len() == n { + // No name has been given. + // Two cases are possible: no Printer implementation among parsers; + // or some parser intentionally filtered out this flag. + if hasNameFunc { + buf.Reset() + return + } + buf.WriteString(f.Name) + } + name, usage := unquoteUsage(c.unquoteUsageMode, f) + if len(name) > 0 { + buf.WriteString("\n \t") + buf.WriteString(name) + } + buf.WriteString("\n \t") + if len(usage) > 0 { + buf.WriteString(strings.ReplaceAll(usage, "\n", "\n \t")) + buf.WriteString(" (") + } + buf.WriteString("default ") + buf.WriteString(defValue(f)) + if len(usage) > 0 { + buf.WriteString(")") + } + + buf.WriteByte('\n') + buf.WriteByte('\n') + buf.WriteTo(flags.Output()) + }) +} + +func defValue(f *flag.Flag) string { + v := reflect.ValueOf(f.Value) +repeat: + if v.Kind() == reflect.Ptr { + v = v.Elem() + goto repeat + } + d := f.DefValue + if d == "" { + switch v.Kind() { + case reflect.String: + return `""` + case + reflect.Slice, + reflect.Array: + return `[]` + case + reflect.Struct, + reflect.Map: + return `{}` + } + return "?" + } + if v.Kind() == reflect.String { + return `"` + d + `"` + } + return d +} + +// unquoteUsage is the same as flag.UnquoteUsage() with exception that it does +// not infer type of the flag value. +func unquoteUsage(m UnquoteUsageMode, f *flag.Flag) (name, usage string) { + if m == UnquoteNothing { + return "", f.Usage + } + u := f.Usage + i := strings.IndexByte(u, '`') + if i == -1 { + if m.has(UnquoteInferType) { + return inferType(f), f.Usage + } + return "", u + } + j := strings.IndexByte(u[i+1:], '`') + if j == -1 { + if m.has(UnquoteInferType) { + return inferType(f), f.Usage + } + return "", u + } + j += i + 1 + + switch { + case m.has(UnquoteQuoted): + name = u[i+1 : j] + case m.has(UnquoteInferType): + name = inferType(f) + } + + prefix := u[:i] + suffix := u[j+1:] + switch { + case m.has(UnquoteClean): + usage = "" + + strings.TrimRight(prefix, " ") + + " " + + strings.TrimLeft(suffix, " ") + + case m.has(UnquoteQuoted): + usage = prefix + name + suffix + + default: + usage = f.Usage + } + + return +} + +func inferType(f *flag.Flag) string { + if f.Value == nil { + return "?" + } + if isBoolFlag(f) { + return "bool" + } + + var x interface{} + if g, ok := f.Value.(flag.Getter); ok { + x = g.Get() + } else { + x = f.Value + } + v := reflect.ValueOf(x) + +repeat: + switch v.Type() { + case reflect.TypeOf(time.Duration(0)): + return "duration" + } + switch v.Kind() { + case + reflect.Interface, + reflect.Ptr: + + v = v.Elem() + goto repeat + + case + reflect.String: + return "string" + case + reflect.Float32, + reflect.Float64: + return "float" + case + reflect.Int, + reflect.Int8, + reflect.Int16, + reflect.Int32, + reflect.Int64: + return "int" + case + reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64: + return "uint" + case + reflect.Slice, + reflect.Array: + return "list" + case + reflect.Map: + return "object" + } + return "" } // Subset registers new flag subset with given prefix within given flag @@ -95,3 +361,10 @@ func Subset(super *flag.FlagSet, prefix string, setup func(sub *flag.FlagSet)) ( }) return } + +func isBoolFlag(f *flag.Flag) bool { + x, ok := f.Value.(interface { + IsBoolFlag() bool + }) + return ok && x.IsBoolFlag() +} diff --git a/flagutil_test.go b/flagutil_test.go new file mode 100644 index 0000000..b98f61a --- /dev/null +++ b/flagutil_test.go @@ -0,0 +1,148 @@ +package flagutil + +import ( + "bytes" + "flag" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/gobwas/flagutil/parse" +) + +type fullParser struct { + Parser + Printer +} + +func TestPrintUsage(t *testing.T) { + var buf bytes.Buffer + fs := flag.NewFlagSet("test", flag.PanicOnError) + fs.SetOutput(&buf) + var ( + s string + n int + ) + fs.StringVar(&s, "foo", "", "`custom` description here") + fs.IntVar(&n, "bar", n, "description here") + + PrintDefaults(fs, + WithParser( + &fullParser{ + Parser: nil, + Printer: PrinterFunc(func(fs parse.FlagSet) func(*flag.Flag, func(string)) { + return func(f *flag.Flag, it func(string)) { + it("MUST-IGNORE-" + strings.ToUpper(f.Name)) + } + }), + }, + WithIgnoreByPrefix("f"), + WithIgnoreByName("bar"), + ), + WithParser( + &fullParser{ + Parser: nil, + Printer: PrinterFunc(func(fs parse.FlagSet) func(*flag.Flag, func(string)) { + return func(f *flag.Flag, it func(string)) { + it("-" + string(f.Name[0])) + it("-" + f.Name) + } + }), + }, + WithIgnoreByName("foo"), + ), + WithParser(&fullParser{ + Parser: nil, + Printer: PrinterFunc(func(fs parse.FlagSet) func(*flag.Flag, func(string)) { + return func(f *flag.Flag, it func(string)) { + it("$" + strings.ToUpper(f.Name)) + } + }), + }), + ) + exp := "" + + " $BAR, -b, -bar\n" + + " \tint\n" + + " \tdescription here (default 0)\n" + + "\n" + + " $FOO\n" + // -foo is ignored. + " \tcustom\n" + + " \tcustom description here (default \"\")\n" + + "\n" + if act := buf.String(); act != exp { + t.Error(cmp.Diff(exp, act)) + } +} + +func TestUnquoteUsage(t *testing.T) { + type expMode map[UnquoteUsageMode][2]string + for _, test := range []struct { + name string + flag flag.Flag + modes expMode + }{ + { + flag: flag.Flag{ + Usage: "foo `bar` baz", + }, + modes: expMode{ + UnquoteNothing: [2]string{ + "", "foo `bar` baz", + }, + UnquoteQuoted: [2]string{ + "bar", "foo bar baz", + }, + UnquoteClean: [2]string{ + "", "foo baz", + }, + }, + }, + { + flag: stringFlag("", "", "some kind of `hello` message"), + modes: expMode{ + UnquoteDefault: [2]string{ + "hello", "some kind of hello message", + }, + UnquoteInferType: [2]string{ + "string", "some kind of `hello` message", + }, + UnquoteInferType | UnquoteClean: [2]string{ + "string", "some kind of message", + }, + }, + }, + { + flag: stringFlag("", "", "no quoted info"), + modes: expMode{ + UnquoteQuoted: [2]string{ + "", "no quoted info", + }, + UnquoteInferType: [2]string{ + "string", "no quoted info", + }, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + for mode, exp := range test.modes { + t.Run(mode.String(), func(t *testing.T) { + actName, actUsage := unquoteUsage(mode, &test.flag) + if expName := exp[0]; actName != expName { + t.Errorf("unexpected name:\n%s", cmp.Diff(expName, actName)) + } + if expUsage := exp[1]; actUsage != expUsage { + t.Errorf("unexpected usage:\n%s", cmp.Diff(expUsage, actUsage)) + } + }) + } + }) + } +} + +func stringFlag(name, def, desc string) flag.Flag { + fs := flag.NewFlagSet("", flag.PanicOnError) + fs.String(name, def, desc) + f := fs.Lookup(name) + return *f +} diff --git a/options.go b/options.go new file mode 100644 index 0000000..04468f4 --- /dev/null +++ b/options.go @@ -0,0 +1,119 @@ +package flagutil + +import ( + "flag" + "strings" +) + +type Option interface { + isOption() +} + +// ParseOption is a generic option that can be passed to Parse(). +type ParseOption interface { + isOption() + setupParseConfig(*config) +} + +// PrintOption is a generic option that can be passed to PrintDefaults(). +type PrintOption interface { + isOption() + setupPrintConfig(*config) +} + +type ParsePrintOption interface { + isOption() + setupParseConfig(*config) + setupPrintConfig(*config) +} + +type parseOptionFunc func(*config) + +func (fn parseOptionFunc) isOption() {} +func (fn parseOptionFunc) setupParseConfig(c *config) { fn(c) } + +type printOptionFunc func(*config) + +func (fn printOptionFunc) isOption() {} +func (fn printOptionFunc) setupPrintConfig(c *config) { fn(c) } + +type parserOption struct { + p *parser +} + +func (p parserOption) isOption() {} +func (p parserOption) setupParseConfig(c *config) { c.parsers = append(c.parsers, p.p) } +func (p parserOption) setupPrintConfig(c *config) { c.parsers = append(c.parsers, p.p) } + +type ParserOption interface { + setupParserConfig(*parser) +} + +type parserOptionFunc func(*parser) + +func (fn parserOptionFunc) setupParserConfig(p *parser) { + fn(p) +} + +func WithIgnoreByName(names ...string) ParserOption { + m := make(map[string]bool, len(names)) + for _, name := range names { + m[name] = true + } + return parserOptionFunc(func(p *parser) { + prev := p.ignore + p.ignore = func(f *flag.Flag) bool { + if prev != nil && prev(f) { + return true + } + return m[f.Name] + } + }) +} + +func WithIgnoreByPrefix(prefix string) ParserOption { + return parserOptionFunc(func(p *parser) { + prev := p.ignore + p.ignore = func(f *flag.Flag) bool { + if prev != nil && prev(f) { + return true + } + return strings.HasPrefix(f.Name, prefix) + } + }) +} + +// WithParser returns a parse option and makes p to be used during Parse(). +func WithParser(p Parser, opts ...ParserOption) ParsePrintOption { + x := &parser{ + Parser: p, + } + for _, opt := range opts { + opt.setupParserConfig(x) + } + return parserOption{ + p: x, + } +} + +// WithIgnoreUndefined makes Parse() to not fail on setting undefined flag. +func WithIgnoreUndefined() ParseOption { + return parseOptionFunc(func(c *config) { + c.ignoreUndefined = true + }) +} + +// WithIgnoreUsage makes Parse() to ignore flag.FlagSet.Usage field when +// receiving flag.ErrHelp error from some parser and print results of +// flagutil.PrintDefaults() instead. +func WithIgnoreUsage() ParseOption { + return parseOptionFunc(func(c *config) { + c.ignoreUsage = true + }) +} + +func WithUnquoteUsageMode(m UnquoteUsageMode) PrintOption { + return printOptionFunc(func(c *config) { + c.unquoteUsageMode = m + }) +} diff --git a/parse/args/args.go b/parse/args/args.go index 74ce5db..677106c 100644 --- a/parse/args/args.go +++ b/parse/args/args.go @@ -22,3 +22,9 @@ func (p *Parser) Parse(fs parse.FlagSet) error { return dup.Parse(p.Args) } + +func (p *Parser) Name(fs parse.FlagSet) func(*flag.Flag, func(string)) { + return func(f *flag.Flag, it func(string)) { + it("-" + f.Name) + } +} diff --git a/parse/env/env.go b/parse/env/env.go index ff0e786..a1663ee 100644 --- a/parse/env/env.go +++ b/parse/env/env.go @@ -65,8 +65,7 @@ func (p *Parser) Parse(fs parse.FlagSet) (err error) { } } fs.VisitAll(func(f *flag.Flag) { - name := p.Prefix + strings.ToUpper(f.Name) - name = p.replacer.Replace(name) + name := p.name(f) value, has := p.lookupEnv(name) if !has { return @@ -83,6 +82,19 @@ func (p *Parser) Parse(fs parse.FlagSet) (err error) { return err } +func (p *Parser) Name(fs parse.FlagSet) func(*flag.Flag, func(string)) { + p.init() + return func(f *flag.Flag, it func(string)) { + it("$" + p.name(f)) + } +} + +func (p *Parser) name(f *flag.Flag) string { + name := p.Prefix + strings.ToUpper(f.Name) + name = p.replacer.Replace(name) + return name +} + func (p *Parser) lookupEnv(name string) (value string, has bool) { if f := p.LookupEnvFunc; f != nil { return f(name) diff --git a/parse/env/env_test.go b/parse/env/env_test.go index d7bfe73..0a751ee 100644 --- a/parse/env/env_test.go +++ b/parse/env/env_test.go @@ -5,9 +5,10 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" + "github.com/gobwas/flagutil/parse" "github.com/gobwas/flagutil/parse/testutil" - "github.com/google/go-cmp/cmp" ) func TestEnvParser(t *testing.T) { diff --git a/parse/file/file.go b/parse/file/file.go index 7c0857e..9073794 100644 --- a/parse/file/file.go +++ b/parse/file/file.go @@ -2,6 +2,7 @@ package file import ( "bytes" + "flag" "fmt" "io" "io/ioutil" @@ -17,7 +18,7 @@ type Syntax interface { // Lookup is an interface to search for syntax source. type Lookup interface { - Lookup(parse.FlagGetter) (io.ReadCloser, error) + Lookup() (io.ReadCloser, error) } // ErrNoFile is an returned by Lookup implementation to report that lookup @@ -36,9 +37,9 @@ func (f LookupFunc) Lookup() (io.ReadCloser, error) { type MultiLookup []Lookup // Lookup implements Lookup interface. -func (ls MultiLookup) Lookup(fs parse.FlagGetter) (io.ReadCloser, error) { +func (ls MultiLookup) Lookup() (io.ReadCloser, error) { for _, l := range ls { - rc, err := l.Lookup(fs) + rc, err := l.Lookup() if err == ErrNoFile { continue } @@ -52,15 +53,26 @@ func (ls MultiLookup) Lookup(fs parse.FlagGetter) (io.ReadCloser, error) { // FlagLookup search for flag with equal name and interprets it as filename to // open. -type FlagLookup string +type FlagLookup struct { + FlagSet *flag.FlagSet + Name string +} + +// LookupFlag is a shortcut to build up a FlagLookup structure. +func LookupFlag(fs *flag.FlagSet, name string) *FlagLookup { + return &FlagLookup{ + FlagSet: fs, + Name: name, + } +} // Lookup implements Lookup interface. -func (s FlagLookup) Lookup(fs parse.FlagGetter) (io.ReadCloser, error) { - f := fs.Lookup(string(s)) - if f == nil { +func (f *FlagLookup) Lookup() (io.ReadCloser, error) { + flag := f.FlagSet.Lookup(f.Name) + if flag == nil { return nil, ErrNoFile } - path := f.Value.String() + path := flag.Value.String() if path == "" { return nil, ErrNoFile } @@ -72,7 +84,7 @@ func (s FlagLookup) Lookup(fs parse.FlagGetter) (io.ReadCloser, error) { type PathLookup string // Lookup implements Lookup interface. -func (p PathLookup) Lookup(fs parse.FlagGetter) (io.ReadCloser, error) { +func (p PathLookup) Lookup() (io.ReadCloser, error) { info, err := os.Stat(string(p)) if os.IsNotExist(err) { return nil, ErrNoFile @@ -93,7 +105,7 @@ func (p PathLookup) Lookup(fs parse.FlagGetter) (io.ReadCloser, error) { type BytesLookup []byte // Lookup implements Lookup interface. -func (b BytesLookup) Lookup(fs parse.FlagGetter) (io.ReadCloser, error) { +func (b BytesLookup) Lookup() (io.ReadCloser, error) { return ioutil.NopCloser(bytes.NewReader(b)), nil } @@ -112,7 +124,7 @@ type Parser struct { // Parse implements flagutil.Parser interface. func (p *Parser) Parse(fs parse.FlagSet) error { - bts, err := p.readSource(fs) + bts, err := p.readSource() if err == ErrNoFile { if p.Required { err = fmt.Errorf("flagutil/file: source not found") @@ -140,8 +152,8 @@ func (p *Parser) Parse(fs parse.FlagSet) error { }) } -func (p *Parser) readSource(fs parse.FlagSet) ([]byte, error) { - src, err := p.Lookup.Lookup(fs) +func (p *Parser) readSource() ([]byte, error) { + src, err := p.Lookup.Lookup() if err != nil { return nil, err } diff --git a/parse/file/file_test.go b/parse/file/file_test.go index 6675a1f..0f8bc43 100644 --- a/parse/file/file_test.go +++ b/parse/file/file_test.go @@ -3,16 +3,15 @@ package file import ( "bytes" "crypto/rand" + "flag" "io/ioutil" "os" "testing" - - "github.com/gobwas/flagutil/parse/testutil" ) var ( _ Lookup = MultiLookup{} - _ Lookup = FlagLookup("") + _ Lookup = &FlagLookup{} _ Lookup = PathLookup("") _ Lookup = BytesLookup{} ) @@ -24,7 +23,7 @@ func TestPathLookupDir(t *testing.T) { } defer os.Remove(dir) lookup := PathLookup(dir) - if _, err := lookup.Lookup(nil); err == nil { + if _, err := lookup.Lookup(); err == nil { t.Fatal("want error; got nil") } } @@ -37,7 +36,7 @@ func TestPathLookup(t *testing.T) { defer os.Remove(file.Name()) lookup := PathLookup(file.Name()) - rc, err := lookup.Lookup(nil) + rc, err := lookup.Lookup() if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -59,12 +58,12 @@ func TestFlagLookup(t *testing.T) { } defer os.Remove(file.Name()) - lookup := FlagLookup("config") + fs := flag.NewFlagSet(t.Name(), flag.PanicOnError) + fs.String("config", file.Name(), "") - fs := new(testutil.StubFlagSet) - fs.AddFlag("config", file.Name()) + lookup := LookupFlag(fs, "config") - rc, err := lookup.Lookup(fs) + rc, err := lookup.Lookup() if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/parse/file/toml/toml_test.go b/parse/file/toml/toml_test.go index 91697ba..c935ae2 100644 --- a/parse/file/toml/toml_test.go +++ b/parse/file/toml/toml_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/BurntSushi/toml" + "github.com/gobwas/flagutil/parse" "github.com/gobwas/flagutil/parse/file" "github.com/gobwas/flagutil/parse/testutil" diff --git a/parse/file/yaml/yaml_test.go b/parse/file/yaml/yaml_test.go index c1c4092..d786b83 100644 --- a/parse/file/yaml/yaml_test.go +++ b/parse/file/yaml/yaml_test.go @@ -3,10 +3,11 @@ package yaml import ( "testing" + yaml "gopkg.in/yaml.v2" + "github.com/gobwas/flagutil/parse" "github.com/gobwas/flagutil/parse/file" "github.com/gobwas/flagutil/parse/testutil" - yaml "gopkg.in/yaml.v2" ) func TestYAML(t *testing.T) { diff --git a/parse/flagset.go b/parse/flagset.go index 72c15ea..cc3e6d2 100644 --- a/parse/flagset.go +++ b/parse/flagset.go @@ -1,6 +1,9 @@ package parse -import "flag" +import ( + "flag" + "fmt" +) type FlagGetter interface { Lookup(name string) *flag.Flag @@ -25,13 +28,21 @@ func WithIgnoreUndefined(v bool) FlagSetOption { } func NextLevel(fs FlagSet) { - fs.(*flagSet).update() + fset := fs.(*flagSet) + fset.ignore = nil + fset.update() +} + +func Ignore(fs FlagSet, fn func(*flag.Flag) bool) { + fset := fs.(*flagSet) + fset.ignore = fn } type flagSet struct { dest *flag.FlagSet ignoreUndefined bool provided map[string]bool + ignore func(*flag.Flag) bool } func NewFlagSet(flags *flag.FlagSet, opts ...FlagSetOption) FlagSet { @@ -50,13 +61,25 @@ func (fs *flagSet) Set(name, value string) error { if fs.provided[name] { return nil } - defined := fs.dest.Lookup(name) != nil - if !defined && name != "help" && name != "h" && fs.ignoreUndefined { + f := fs.dest.Lookup(name) + if f != nil && fs.ignored(f) { + f = nil + } + defined := f != nil + if !defined && fs.ignoreUndefined { return nil } + if !defined { + return fmt.Errorf("no such flag %q", name) + } return fs.dest.Set(name, value) } +func (fs *flagSet) ignored(f *flag.Flag) bool { + ignore := fs.ignore + return ignore != nil && ignore(f) +} + func (fs *flagSet) update() { fs.dest.Visit(func(f *flag.Flag) { fs.provided[f.Name] = true @@ -65,16 +88,18 @@ func (fs *flagSet) update() { func (fs *flagSet) VisitAll(fn func(*flag.Flag)) { fs.dest.VisitAll(func(f *flag.Flag) { - fn(fs.clone(f)) + if !fs.ignored(f) { + fn(fs.clone(f)) + } }) } func (fs *flagSet) Lookup(name string) *flag.Flag { f := fs.dest.Lookup(name) - if f != nil { - f = fs.clone(f) + if f == nil || fs.ignored(f) { + return nil } - return f + return fs.clone(f) } func (fs *flagSet) clone(f *flag.Flag) *flag.Flag { diff --git a/parse/pargs/posix.go b/parse/pargs/posix.go index 70686a4..cf34e3d 100644 --- a/parse/pargs/posix.go +++ b/parse/pargs/posix.go @@ -5,18 +5,29 @@ import ( "fmt" "strings" + "github.com/gobwas/flagutil" "github.com/gobwas/flagutil/parse" ) type Parser struct { Args []string + // Shorthand specifies whether parser should try to provide shorthand + // version (e.g. just first letter of name) of each top level flag. + Shorthand bool + + // ShorthandFunc allows user to define custom way of picking shorthand + // version of flag with given name. + // Shorthand field must be true when setting ShorthandFunc. + ShorthandFunc func(string) string + pos int err error mult bool name string value string fs parse.FlagSet + alias map[string]string } func (p *Parser) Parse(fs parse.FlagSet) (err error) { @@ -24,6 +35,8 @@ func (p *Parser) Parse(fs parse.FlagSet) (err error) { for p.next() { p.pairs(func(name, value string) bool { + name = p.resolve(name) + // Special case for help request. if fs.Lookup(name) == nil && (name == "help" || name == "h") { err = flag.ErrHelp return false @@ -39,6 +52,43 @@ func (p *Parser) Parse(fs parse.FlagSet) (err error) { return p.err } +func (p *Parser) resolve(name string) string { + if s, has := p.alias[name]; has { + name = s + } + return name +} + +func (p *Parser) Name(fs parse.FlagSet) func(*flag.Flag, func(string)) { + short := p.shorthands(fs) + return func(f *flag.Flag, it func(string)) { + if p.Shorthand { + s := p.shorthand(f) + if _, has := short[s]; has { + it("-" + s) + } + } + var prefix string + if len(f.Name) == 1 { + prefix = "-" + } else { + prefix = "--" + } + it(prefix + f.Name) + } +} + +func (p *Parser) shorthand(f *flag.Flag) string { + if fn := p.ShorthandFunc; fn != nil { + return fn(f.Name) + } + if !isTopSet(f) { + // Not a topmost flag set. + return "" + } + return string(f.Name[0]) +} + func (p *Parser) pairs(fn func(name, value string) bool) { if p.mult { for i := range p.name { @@ -59,9 +109,13 @@ func (p *Parser) reset(fs parse.FlagSet) { p.name = "" p.value = "" p.fs = fs + if p.Shorthand { + p.alias = p.shorthands(fs) + } } func (p *Parser) isBoolFlag(name string) bool { + name = p.resolve(name) f := p.fs.Lookup(name) if f == nil && name == "h" { // Special case for help message request. @@ -136,6 +190,31 @@ func (p *Parser) next() bool { return true } +func (p *Parser) shorthands(fs parse.FlagSet) map[string]string { + short := make(map[string]string) + fs.VisitAll(func(f *flag.Flag) { + s := p.shorthand(f) + if s == "" { + return + } + if _, has := short[s]; has { + // Mark this shorthand name as ambiguous. + short[s] = "" + } else { + short[s] = f.Name + } + }) + for s, n := range short { + if n == "" { + delete(short, s) + } + if fs.Lookup(s) != nil { + delete(short, s) + } + } + return short +} + func (p *Parser) fail(f string, args ...interface{}) { p.err = fmt.Errorf(f, args...) } @@ -185,3 +264,7 @@ func isBoolFlag(f *flag.Flag) bool { }) return ok && x.IsBoolFlag() } + +func isTopSet(f *flag.Flag) bool { + return strings.Index(f.Name, flagutil.SetSeparator) == -1 +} diff --git a/parse/pargs/posix_test.go b/parse/pargs/posix_test.go index aa65abb..078e1a0 100644 --- a/parse/pargs/posix_test.go +++ b/parse/pargs/posix_test.go @@ -3,9 +3,10 @@ package pargs import ( "testing" + "github.com/google/go-cmp/cmp" + "github.com/gobwas/flagutil/parse" "github.com/gobwas/flagutil/parse/testutil" - "github.com/google/go-cmp/cmp" ) func TestPosixParse(t *testing.T) { @@ -15,6 +16,8 @@ func TestPosixParse(t *testing.T) { exp [][2]string err bool flags map[string]bool + + shorthand bool }{ { name: "short basic", @@ -95,6 +98,61 @@ func TestPosixParse(t *testing.T) { }, err: true, }, + + { + name: "shorthand basic", + shorthand: true, + flags: map[string]bool{ + "shorthand": false, + }, + args: []string{ + "-s=foo", + }, + exp: [][2]string{ + {"shorthand", "foo"}, + }, + }, + { + name: "shorthand ambiguous", + shorthand: true, + flags: map[string]bool{ + "some-foo": false, + "some-bar": false, + }, + args: []string{ + "-s=foo", + }, + exp: [][2]string{ + {"s", "foo"}, + }, + }, + { + name: "shorthand collision", + shorthand: true, + flags: map[string]bool{ + "some-foo": false, + "s": false, + }, + args: []string{ + "-s=foo", + }, + exp: [][2]string{ + {"s", "foo"}, + }, + }, + { + name: "shorthand only top", + shorthand: true, + flags: map[string]bool{ + "some.foo": false, + }, + args: []string{ + "-s=foo", + }, + exp: [][2]string{ + {"s", "foo"}, + }, + }, } { t.Run(test.name, func(t *testing.T) { var fs testutil.StubFlagSet @@ -106,7 +164,8 @@ func TestPosixParse(t *testing.T) { } } p := Parser{ - Args: test.args, + Args: test.args, + Shorthand: test.shorthand, } err := p.Parse(&fs) if !test.err && err != nil { diff --git a/parse/testutil/testutil.go b/parse/testutil/testutil.go index da1a867..4210d4b 100644 --- a/parse/testutil/testutil.go +++ b/parse/testutil/testutil.go @@ -7,8 +7,9 @@ import ( "testing" "time" - "github.com/gobwas/flagutil/parse" "github.com/google/go-cmp/cmp" + + "github.com/gobwas/flagutil/parse" ) type stubValue struct {