Skip to content

Commit

Permalink
bytes: add support for base64 encoded flags (#177)
Browse files Browse the repository at this point in the history
Signed-off-by: Gorka Lerchundi Osa <[email protected]>
  • Loading branch information
glerchundi authored and eparis committed Aug 8, 2018
1 parent 3ebe029 commit 9a97c10
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 1 deletion.
104 changes: 104 additions & 0 deletions bytes.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pflag

import (
"encoding/base64"
"encoding/hex"
"fmt"
"strings"
Expand All @@ -9,10 +10,12 @@ import (
// BytesHex adapts []byte for use as a flag. Value of flag is HEX encoded
type bytesHexValue []byte

// String implements pflag.Value.String.
func (bytesHex bytesHexValue) String() string {
return fmt.Sprintf("%X", []byte(bytesHex))
}

// Set implements pflag.Value.Set.
func (bytesHex *bytesHexValue) Set(value string) error {
bin, err := hex.DecodeString(strings.TrimSpace(value))

Expand All @@ -25,6 +28,7 @@ func (bytesHex *bytesHexValue) Set(value string) error {
return nil
}

// Type implements pflag.Value.Type.
func (*bytesHexValue) Type() string {
return "bytesHex"
}
Expand Down Expand Up @@ -103,3 +107,103 @@ func BytesHex(name string, value []byte, usage string) *[]byte {
func BytesHexP(name, shorthand string, value []byte, usage string) *[]byte {
return CommandLine.BytesHexP(name, shorthand, value, usage)
}

// BytesBase64 adapts []byte for use as a flag. Value of flag is Base64 encoded
type bytesBase64Value []byte

// String implements pflag.Value.String.
func (bytesBase64 bytesBase64Value) String() string {
return base64.StdEncoding.EncodeToString([]byte(bytesBase64))
}

// Set implements pflag.Value.Set.
func (bytesBase64 *bytesBase64Value) Set(value string) error {
bin, err := base64.StdEncoding.DecodeString(strings.TrimSpace(value))

if err != nil {
return err
}

*bytesBase64 = bin

return nil
}

// Type implements pflag.Value.Type.
func (*bytesBase64Value) Type() string {
return "bytesBase64"
}

func newBytesBase64Value(val []byte, p *[]byte) *bytesBase64Value {
*p = val
return (*bytesBase64Value)(p)
}

func bytesBase64ValueConv(sval string) (interface{}, error) {

bin, err := base64.StdEncoding.DecodeString(sval)
if err == nil {
return bin, nil
}

return nil, fmt.Errorf("invalid string being converted to Bytes: %s %s", sval, err)
}

// GetBytesBase64 return the []byte value of a flag with the given name
func (f *FlagSet) GetBytesBase64(name string) ([]byte, error) {
val, err := f.getFlagType(name, "bytesBase64", bytesBase64ValueConv)

if err != nil {
return []byte{}, err
}

return val.([]byte), nil
}

// BytesBase64Var defines an []byte flag with specified name, default value, and usage string.
// The argument p points to an []byte variable in which to store the value of the flag.
func (f *FlagSet) BytesBase64Var(p *[]byte, name string, value []byte, usage string) {
f.VarP(newBytesBase64Value(value, p), name, "", usage)
}

// BytesBase64VarP is like BytesBase64Var, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) BytesBase64VarP(p *[]byte, name, shorthand string, value []byte, usage string) {
f.VarP(newBytesBase64Value(value, p), name, shorthand, usage)
}

// BytesBase64Var defines an []byte flag with specified name, default value, and usage string.
// The argument p points to an []byte variable in which to store the value of the flag.
func BytesBase64Var(p *[]byte, name string, value []byte, usage string) {
CommandLine.VarP(newBytesBase64Value(value, p), name, "", usage)
}

// BytesBase64VarP is like BytesBase64Var, but accepts a shorthand letter that can be used after a single dash.
func BytesBase64VarP(p *[]byte, name, shorthand string, value []byte, usage string) {
CommandLine.VarP(newBytesBase64Value(value, p), name, shorthand, usage)
}

// BytesBase64 defines an []byte flag with specified name, default value, and usage string.
// The return value is the address of an []byte variable that stores the value of the flag.
func (f *FlagSet) BytesBase64(name string, value []byte, usage string) *[]byte {
p := new([]byte)
f.BytesBase64VarP(p, name, "", value, usage)
return p
}

// BytesBase64P is like BytesBase64, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) BytesBase64P(name, shorthand string, value []byte, usage string) *[]byte {
p := new([]byte)
f.BytesBase64VarP(p, name, shorthand, value, usage)
return p
}

// BytesBase64 defines an []byte flag with specified name, default value, and usage string.
// The return value is the address of an []byte variable that stores the value of the flag.
func BytesBase64(name string, value []byte, usage string) *[]byte {
return CommandLine.BytesBase64P(name, "", value, usage)
}

// BytesBase64P is like BytesBase64, but accepts a shorthand letter that can be used after a single dash.
func BytesBase64P(name, shorthand string, value []byte, usage string) *[]byte {
return CommandLine.BytesBase64P(name, shorthand, value, usage)
}
64 changes: 63 additions & 1 deletion bytes_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pflag

import (
"encoding/base64"
"fmt"
"os"
"testing"
Expand Down Expand Up @@ -61,7 +62,7 @@ func TestBytesHex(t *testing.T) {
} else if tc.success {
bytesHex, err := f.GetBytesHex("bytes")
if err != nil {
t.Errorf("Got error trying to fetch the IP flag: %v", err)
t.Errorf("Got error trying to fetch the 'bytes' flag: %v", err)
}
if fmt.Sprintf("%X", bytesHex) != tc.expected {
t.Errorf("expected %q, got '%X'", tc.expected, bytesHex)
Expand All @@ -70,3 +71,64 @@ func TestBytesHex(t *testing.T) {
}
}
}

func setUpBytesBase64(bytesBase64 *[]byte) *FlagSet {
f := NewFlagSet("test", ContinueOnError)
f.BytesBase64Var(bytesBase64, "bytes", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in Base64")
f.BytesBase64VarP(bytesBase64, "bytes2", "B", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in Base64")
return f
}

func TestBytesBase64(t *testing.T) {
testCases := []struct {
input string
success bool
expected string
}{
/// Positive cases
{"", true, ""}, // Is empty string OK ?
{"AQ==", true, "AQ=="},

// Negative cases
{"AQ", false, ""}, // Padding removed
{"ï", false, ""}, // non-base64 characters
}

devnull, _ := os.Open(os.DevNull)
os.Stderr = devnull

for i := range testCases {
var bytesBase64 []byte
f := setUpBytesBase64(&bytesBase64)

tc := &testCases[i]

// --bytes
args := []string{
fmt.Sprintf("--bytes=%s", tc.input),
fmt.Sprintf("-B %s", tc.input),
fmt.Sprintf("--bytes2=%s", tc.input),
}

for _, arg := range args {
err := f.Parse([]string{arg})

if err != nil && tc.success == true {
t.Errorf("expected success, got %q", err)
continue
} else if err == nil && tc.success == false {
// bytesBase64, err := f.GetBytesBase64("bytes")
t.Errorf("expected failure while processing %q", tc.input)
continue
} else if tc.success {
bytesBase64, err := f.GetBytesBase64("bytes")
if err != nil {
t.Errorf("Got error trying to fetch the 'bytes' flag: %v", err)
}
if base64.StdEncoding.EncodeToString(bytesBase64) != tc.expected {
t.Errorf("expected %q, got '%X'", tc.expected, bytesBase64)
}
}
}
}
}

0 comments on commit 9a97c10

Please sign in to comment.