From 66ae38088fbe45a8d146e3c97934c8e2fd0d1fcd Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Mon, 17 Jun 2024 22:45:39 +0200 Subject: [PATCH 01/12] plugin: add plugin framework (recipient) Updates #485 --- plugin/plugin.go | 246 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 246 insertions(+) create mode 100644 plugin/plugin.go diff --git a/plugin/plugin.go b/plugin/plugin.go new file mode 100644 index 0000000..4353b26 --- /dev/null +++ b/plugin/plugin.go @@ -0,0 +1,246 @@ +package plugin + +import ( + "bufio" + "flag" + "fmt" + "os" + + "filippo.io/age" + "filippo.io/age/internal/format" +) + +type Plugin struct { + name string + fs *flag.FlagSet + sm *string + + recipient func([]byte) (age.Recipient, error) + idAsRecipient func([]byte) (age.Recipient, error) + identity func([]byte) (age.Identity, error) +} + +func New(name string) (*Plugin, error) { + return &Plugin{name: name}, nil +} + +func (p *Plugin) Name() string { + return p.name +} + +func (p *Plugin) RegisterFlags(fs *flag.FlagSet) { + if fs == nil { + fs = flag.CommandLine + } + p.fs = fs + p.sm = fs.String("age-plugin", "", "age-plugin state machine") +} + +func (p *Plugin) HandleRecipient(f func(data []byte) (age.Recipient, error)) { + if p.recipient != nil { + panic("HandleRecipient called twice") + } + p.recipient = f +} + +func (p *Plugin) HandleIdentityAsRecipient(f func(data []byte) (age.Recipient, error)) { + if p.idAsRecipient != nil { + panic("HandleIdentityAsRecipient called twice") + } + p.idAsRecipient = f +} + +func (p *Plugin) HandleIdentity(f func(data []byte) (age.Identity, error)) { + if p.identity != nil { + panic("HandleIdentity called twice") + } + p.identity = f +} + +func (p *Plugin) Main() { + if p.fs == nil { + p.RegisterFlags(nil) + } + if !p.fs.Parsed() { + p.fs.Parse(os.Args[1:]) + } + if *p.sm == "recipient-v1" { + p.RecipientV1() + } + if *p.sm == "identity-v1" { + p.IdentityV1() + } +} + +func (p *Plugin) RecipientV1() { + if p.recipient == nil { + fatalf("recipient-v1 not supported") + } + + var recipientStrings, identityStrings []string + var fileKeys [][]byte + var supportsLabels bool + + sr := format.NewStanzaReader(bufio.NewReader(os.Stdin)) +ReadLoop: + for { + s, err := sr.ReadStanza() + if err != nil { + fatalf("failed to read stanza: %v", err) + } + + switch s.Type { + case "add-recipient": + expectStanzaWithNoBody(s, 1) + recipientStrings = append(recipientStrings, s.Args[0]) + case "add-identity": + expectStanzaWithNoBody(s, 1) + identityStrings = append(identityStrings, s.Args[0]) + case "extension-labels": + expectStanzaWithNoBody(s, 0) + supportsLabels = true + case "wrap-file-key": + expectStanzaWithBody(s, 0) + fileKeys = append(fileKeys, s.Body) + case "done": + expectStanzaWithNoBody(s, 0) + break ReadLoop + default: + // Unsupported stanzas in uni-directional phases are ignored. + } + } + + if len(recipientStrings)+len(identityStrings) == 0 { + fatalf("no recipients or identities provided") + } + if len(fileKeys) == 0 { + fatalf("no file keys provided") + } + + var recipients, identities []age.Recipient + for i, s := range recipientStrings { + name, data, err := ParseRecipient(s) + if err != nil { + recipientError(sr, i, err) + } + if name != p.name { + recipientError(sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + } + r, err := p.recipient(data) + if err != nil { + recipientError(sr, i, err) + } + recipients = append(recipients, r) + } + for i, s := range identityStrings { + name, data, err := ParseIdentity(s) + if err != nil { + identityError(sr, i, err) + } + if name != p.name { + identityError(sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + } + r, err := p.idAsRecipient(data) + if err != nil { + identityError(sr, i, err) + } + identities = append(identities, r) + } + + stanzas := make([][]*age.Stanza, len(fileKeys)) + for i, fk := range fileKeys { + for j, r := range recipients { + ss, err := r.Wrap(fk) + if err != nil { + recipientError(sr, j, err) + } + stanzas[i] = append(stanzas[i], ss...) + } + for j, r := range identities { + ss, err := r.Wrap(fk) + if err != nil { + identityError(sr, j, err) + } + stanzas[i] = append(stanzas[i], ss...) + } + } + _ = supportsLabels // TODO + + for i, ss := range stanzas { + for _, s := range ss { + if err := (&format.Stanza{Type: "recipient-stanza", + Args: append([]string{fmt.Sprint(i), s.Type}, s.Args...), + Body: s.Body}).Marshal(os.Stdout); err != nil { + fatalf("failed to write recipient-stanza: %v", err) + } + expectOk(sr) + } + } + + if err := writeStanza(os.Stdout, "done"); err != nil { + fatalf("failed to write done stanza: %v", err) + } +} + +func (p *Plugin) IdentityV1() { + if p.identity == nil { + fatalf("identity-v1 not supported") + } + panic("not implemented") +} + +func expectStanzaWithNoBody(s *format.Stanza, wantArgs int) { + if len(s.Args) != wantArgs { + fatalf("%s stanza has %d arguments, want %d", s.Type, len(s.Args), wantArgs) + } + if len(s.Body) != 0 { + fatalf("%s stanza has %d bytes of body, want 0", s.Type, len(s.Body)) + } +} + +func expectStanzaWithBody(s *format.Stanza, wantArgs int) { + if len(s.Args) != wantArgs { + fatalf("%s stanza has %d arguments, want %d", s.Type, len(s.Args), wantArgs) + } + if len(s.Body) == 0 { + fatalf("%s stanza has 0 bytes of body, want >0", s.Type) + } +} + +func recipientError(sr *format.StanzaReader, idx int, err error) { + protocolError(sr, []string{"recipient", fmt.Sprint(idx)}, err) +} + +func identityError(sr *format.StanzaReader, idx int, err error) { + protocolError(sr, []string{"identity", fmt.Sprint(idx)}, err) +} + +func internalError(sr *format.StanzaReader, err error) { + protocolError(sr, []string{"internal"}, err) +} + +func protocolError(sr *format.StanzaReader, args []string, err error) { + s := &format.Stanza{Type: "error", Args: args} + s.Body = []byte(err.Error()) + if err := s.Marshal(os.Stdout); err != nil { + fatalf("failed to write error stanza: %v", err) + } + expectOk(sr) + os.Exit(3) +} + +func expectOk(sr *format.StanzaReader) { + ok, err := sr.ReadStanza() + if err != nil { + fatalf("failed to read OK stanza: %v", err) + } + if ok.Type != "ok" { + fatalf("expected OK stanza, got %q", ok.Type) + } + expectStanzaWithNoBody(ok, 0) +} + +func fatalf(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, format, args...) + os.Exit(1) +} From 30feadaaeff255ad7d9fd2bba7362ca4324c7110 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Mon, 17 Jun 2024 22:58:05 +0200 Subject: [PATCH 02/12] plugin: implement labels (recipient framework) Updates #485 --- plugin/plugin.go | 46 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/plugin/plugin.go b/plugin/plugin.go index 4353b26..6dde450 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -147,24 +147,44 @@ ReadLoop: identities = append(identities, r) } + // Technically labels should be per-file key, but the client-side protocol + // extension shipped like this, and it doesn't feel worth making a v2. + var labels []string + stanzas := make([][]*age.Stanza, len(fileKeys)) for i, fk := range fileKeys { for j, r := range recipients { - ss, err := r.Wrap(fk) + ss, ll, err := wrapWithLabels(r, fk) if err != nil { recipientError(sr, j, err) } + if i == 0 && j == 0 { + labels = ll + } else if !slicesEqual(labels, ll) { + recipientError(sr, j, fmt.Errorf("labels %q do not match previous recipients %q", ll, labels)) + } stanzas[i] = append(stanzas[i], ss...) } for j, r := range identities { - ss, err := r.Wrap(fk) + ss, ll, err := wrapWithLabels(r, fk) if err != nil { identityError(sr, j, err) } + if i == 0 && j == 0 && len(recipients) == 0 { + labels = ll + } else if !slicesEqual(labels, ll) { + identityError(sr, j, fmt.Errorf("labels %q do not match previous recipients %q", ll, labels)) + } stanzas[i] = append(stanzas[i], ss...) } } - _ = supportsLabels // TODO + + if supportsLabels { + if err := writeStanza(os.Stdout, "labels", labels...); err != nil { + fatalf("failed to write labels stanza: %v", err) + } + expectOk(sr) + } for i, ss := range stanzas { for _, s := range ss { @@ -182,6 +202,14 @@ ReadLoop: } } +func wrapWithLabels(r age.Recipient, fileKey []byte) ([]*age.Stanza, []string, error) { + if r, ok := r.(age.RecipientWithLabels); ok { + return r.WrapWithLabels(fileKey) + } + s, err := r.Wrap(fileKey) + return s, nil, err +} + func (p *Plugin) IdentityV1() { if p.identity == nil { fatalf("identity-v1 not supported") @@ -244,3 +272,15 @@ func fatalf(format string, args ...interface{}) { fmt.Fprintf(os.Stderr, format, args...) os.Exit(1) } + +func slicesEqual(s1, s2 []string) bool { + if len(s1) != len(s2) { + return false + } + for i := range s1 { + if s1[i] != s2[i] { + return false + } + } + return true +} From 17332799cc76caeec8408eeb872cd43ae232f0a3 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Mon, 17 Jun 2024 23:10:57 +0200 Subject: [PATCH 03/12] plugin: use framework in tests --- plugin/client_test.go | 78 ++++++++++++++++--------------------------- 1 file changed, 28 insertions(+), 50 deletions(-) diff --git a/plugin/client_test.go b/plugin/client_test.go index fc28789..03ddab0 100644 --- a/plugin/client_test.go +++ b/plugin/client_test.go @@ -7,7 +7,6 @@ package plugin import ( - "bufio" "io" "os" "path/filepath" @@ -22,61 +21,40 @@ func TestMain(m *testing.M) { switch filepath.Base(os.Args[0]) { // TODO: deduplicate from cmd/age TestMain. case "age-plugin-test": - switch os.Args[1] { - case "--age-plugin=recipient-v1": - scanner := bufio.NewScanner(os.Stdin) - scanner.Scan() // add-recipient - scanner.Scan() // body - scanner.Scan() // grease - scanner.Scan() // body - scanner.Scan() // wrap-file-key - scanner.Scan() // body - fileKey := scanner.Text() - scanner.Scan() // extension-labels - scanner.Scan() // body - scanner.Scan() // done - scanner.Scan() // body - os.Stdout.WriteString("-> recipient-stanza 0 test\n") - os.Stdout.WriteString(fileKey + "\n") - scanner.Scan() // ok - scanner.Scan() // body - os.Stdout.WriteString("-> done\n\n") - os.Exit(0) - default: - panic(os.Args[1]) - } + p, _ := New("test") + p.HandleRecipient(func(data []byte) (age.Recipient, error) { + return testRecipient{}, nil + }) + p.Main() case "age-plugin-testpqc": - switch os.Args[1] { - case "--age-plugin=recipient-v1": - scanner := bufio.NewScanner(os.Stdin) - scanner.Scan() // add-recipient - scanner.Scan() // body - scanner.Scan() // grease - scanner.Scan() // body - scanner.Scan() // wrap-file-key - scanner.Scan() // body - fileKey := scanner.Text() - scanner.Scan() // extension-labels - scanner.Scan() // body - scanner.Scan() // done - scanner.Scan() // body - os.Stdout.WriteString("-> recipient-stanza 0 test\n") - os.Stdout.WriteString(fileKey + "\n") - scanner.Scan() // ok - scanner.Scan() // body - os.Stdout.WriteString("-> labels postquantum\n\n") - scanner.Scan() // ok - scanner.Scan() // body - os.Stdout.WriteString("-> done\n\n") - os.Exit(0) - default: - panic(os.Args[1]) - } + p, _ := New("testpqc") + p.HandleRecipient(func(data []byte) (age.Recipient, error) { + return testPQCRecipient{}, nil + }) + p.Main() default: os.Exit(m.Run()) } } +type testRecipient struct{} + +func (testRecipient) Wrap(fileKey []byte) ([]*age.Stanza, error) { + return []*age.Stanza{{Type: "test", Body: fileKey}}, nil +} + +type testPQCRecipient struct{} + +var _ age.RecipientWithLabels = testPQCRecipient{} + +func (testPQCRecipient) Wrap(fileKey []byte) ([]*age.Stanza, error) { + return []*age.Stanza{{Type: "test", Body: fileKey}}, nil +} + +func (testPQCRecipient) WrapWithLabels(fileKey []byte) ([]*age.Stanza, []string, error) { + return []*age.Stanza{{Type: "test", Body: fileKey}}, []string{"postquantum"}, nil +} + func TestLabels(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Windows support is TODO") From 35b060b0dd18b6ae83559c36868d4f2b43515ee7 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Tue, 18 Jun 2024 11:37:17 +0200 Subject: [PATCH 04/12] plugin: implement identities in framework Updates #485 --- cmd/age/age_test.go | 64 ++++------- plugin/client_test.go | 5 +- plugin/plugin.go | 262 ++++++++++++++++++++++++++++++++---------- 3 files changed, 224 insertions(+), 107 deletions(-) diff --git a/cmd/age/age_test.go b/cmd/age/age_test.go index 9291829..3541232 100644 --- a/cmd/age/age_test.go +++ b/cmd/age/age_test.go @@ -5,11 +5,11 @@ package main import ( - "bufio" "os" "testing" "filippo.io/age" + "filippo.io/age/plugin" "github.com/rogpeppe/go-internal/testscript" ) @@ -30,51 +30,31 @@ func TestMain(m *testing.M) { return 0 }, "age-plugin-test": func() (exitCode int) { - // TODO: use plugin server package once it's available. - switch os.Args[1] { - case "--age-plugin=recipient-v1": - scanner := bufio.NewScanner(os.Stdin) - scanner.Scan() // add-recipient - scanner.Scan() // body - scanner.Scan() // grease - scanner.Scan() // body - scanner.Scan() // wrap-file-key - scanner.Scan() // body - fileKey := scanner.Text() - scanner.Scan() // extension-labels - scanner.Scan() // body - scanner.Scan() // done - scanner.Scan() // body - os.Stdout.WriteString("-> recipient-stanza 0 test\n") - os.Stdout.WriteString(fileKey + "\n") - scanner.Scan() // ok - scanner.Scan() // body - os.Stdout.WriteString("-> done\n\n") - return 0 - case "--age-plugin=identity-v1": - scanner := bufio.NewScanner(os.Stdin) - scanner.Scan() // add-identity - scanner.Scan() // body - scanner.Scan() // grease - scanner.Scan() // body - scanner.Scan() // recipient-stanza - scanner.Scan() // body - fileKey := scanner.Text() - scanner.Scan() // done - scanner.Scan() // body - os.Stdout.WriteString("-> file-key 0\n") - os.Stdout.WriteString(fileKey + "\n") - scanner.Scan() // ok - scanner.Scan() // body - os.Stdout.WriteString("-> done\n\n") - return 0 - default: - return 1 - } + p, _ := plugin.New("test") + p.HandleRecipient(func(data []byte) (age.Recipient, error) { + return testPlugin{}, nil + }) + p.HandleIdentity(func(data []byte) (age.Identity, error) { + return testPlugin{}, nil + }) + return p.Main() }, })) } +type testPlugin struct{} + +func (testPlugin) Wrap(fileKey []byte) ([]*age.Stanza, error) { + return []*age.Stanza{{Type: "test", Body: fileKey}}, nil +} + +func (testPlugin) Unwrap(ss []*age.Stanza) ([]byte, error) { + if len(ss) == 1 && ss[0].Type == "test" { + return ss[0].Body, nil + } + return nil, age.ErrIncorrectIdentity +} + func TestScript(t *testing.T) { testscript.Run(t, testscript.Params{ Dir: "testdata", diff --git a/plugin/client_test.go b/plugin/client_test.go index 03ddab0..ef08782 100644 --- a/plugin/client_test.go +++ b/plugin/client_test.go @@ -19,19 +19,18 @@ import ( func TestMain(m *testing.M) { switch filepath.Base(os.Args[0]) { - // TODO: deduplicate from cmd/age TestMain. case "age-plugin-test": p, _ := New("test") p.HandleRecipient(func(data []byte) (age.Recipient, error) { return testRecipient{}, nil }) - p.Main() + os.Exit(p.Main()) case "age-plugin-testpqc": p, _ := New("testpqc") p.HandleRecipient(func(data []byte) (age.Recipient, error) { return testPQCRecipient{}, nil }) - p.Main() + os.Exit(p.Main()) default: os.Exit(m.Run()) } diff --git a/plugin/plugin.go b/plugin/plugin.go index 6dde450..e246f3f 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -2,9 +2,11 @@ package plugin import ( "bufio" + "errors" "flag" "fmt" "os" + "strconv" "filippo.io/age" "filippo.io/age/internal/format" @@ -57,7 +59,7 @@ func (p *Plugin) HandleIdentity(f func(data []byte) (age.Identity, error)) { p.identity = f } -func (p *Plugin) Main() { +func (p *Plugin) Main() int { if p.fs == nil { p.RegisterFlags(nil) } @@ -65,16 +67,17 @@ func (p *Plugin) Main() { p.fs.Parse(os.Args[1:]) } if *p.sm == "recipient-v1" { - p.RecipientV1() + return p.RecipientV1() } if *p.sm == "identity-v1" { - p.IdentityV1() + return p.IdentityV1() } + return fatalf("unknown state machine %q", *p.sm) } -func (p *Plugin) RecipientV1() { - if p.recipient == nil { - fatalf("recipient-v1 not supported") +func (p *Plugin) RecipientV1() int { + if p.recipient == nil && p.idAsRecipient == nil { + return fatalf("recipient-v1 not supported") } var recipientStrings, identityStrings []string @@ -86,24 +89,34 @@ ReadLoop: for { s, err := sr.ReadStanza() if err != nil { - fatalf("failed to read stanza: %v", err) + return fatalf("failed to read stanza: %v", err) } switch s.Type { case "add-recipient": - expectStanzaWithNoBody(s, 1) + if err := expectStanzaWithNoBody(s, 1); err != nil { + return fatalf("%v", err) + } recipientStrings = append(recipientStrings, s.Args[0]) case "add-identity": - expectStanzaWithNoBody(s, 1) + if err := expectStanzaWithNoBody(s, 1); err != nil { + return fatalf("%v", err) + } identityStrings = append(identityStrings, s.Args[0]) case "extension-labels": - expectStanzaWithNoBody(s, 0) + if err := expectStanzaWithNoBody(s, 0); err != nil { + return fatalf("%v", err) + } supportsLabels = true case "wrap-file-key": - expectStanzaWithBody(s, 0) + if err := expectStanzaWithBody(s, 0); err != nil { + return fatalf("%v", err) + } fileKeys = append(fileKeys, s.Body) case "done": - expectStanzaWithNoBody(s, 0) + if err := expectStanzaWithNoBody(s, 0); err != nil { + return fatalf("%v", err) + } break ReadLoop default: // Unsupported stanzas in uni-directional phases are ignored. @@ -111,38 +124,44 @@ ReadLoop: } if len(recipientStrings)+len(identityStrings) == 0 { - fatalf("no recipients or identities provided") + return fatalf("no recipients or identities provided") } if len(fileKeys) == 0 { - fatalf("no file keys provided") + return fatalf("no file keys provided") } var recipients, identities []age.Recipient for i, s := range recipientStrings { name, data, err := ParseRecipient(s) if err != nil { - recipientError(sr, i, err) + return recipientError(sr, i, err) } if name != p.name { - recipientError(sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + return recipientError(sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + } + if p.recipient == nil { + return recipientError(sr, i, fmt.Errorf("recipient encodings not supported")) } r, err := p.recipient(data) if err != nil { - recipientError(sr, i, err) + return recipientError(sr, i, err) } recipients = append(recipients, r) } for i, s := range identityStrings { name, data, err := ParseIdentity(s) if err != nil { - identityError(sr, i, err) + return identityError(sr, i, err) } if name != p.name { - identityError(sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + return identityError(sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + } + if p.idAsRecipient == nil { + return identityError(sr, i, fmt.Errorf("identity encodings not supported")) } r, err := p.idAsRecipient(data) if err != nil { - identityError(sr, i, err) + return identityError(sr, i, err) } identities = append(identities, r) } @@ -156,24 +175,24 @@ ReadLoop: for j, r := range recipients { ss, ll, err := wrapWithLabels(r, fk) if err != nil { - recipientError(sr, j, err) + return recipientError(sr, j, err) } if i == 0 && j == 0 { labels = ll - } else if !slicesEqual(labels, ll) { - recipientError(sr, j, fmt.Errorf("labels %q do not match previous recipients %q", ll, labels)) + } else if err := checkLabels(ll, labels); err != nil { + return recipientError(sr, j, err) } stanzas[i] = append(stanzas[i], ss...) } for j, r := range identities { ss, ll, err := wrapWithLabels(r, fk) if err != nil { - identityError(sr, j, err) + return identityError(sr, j, err) } if i == 0 && j == 0 && len(recipients) == 0 { labels = ll - } else if !slicesEqual(labels, ll) { - identityError(sr, j, fmt.Errorf("labels %q do not match previous recipients %q", ll, labels)) + } else if err := checkLabels(ll, labels); err != nil { + return identityError(sr, j, err) } stanzas[i] = append(stanzas[i], ss...) } @@ -181,9 +200,11 @@ ReadLoop: if supportsLabels { if err := writeStanza(os.Stdout, "labels", labels...); err != nil { - fatalf("failed to write labels stanza: %v", err) + return fatalf("failed to write labels stanza: %v", err) + } + if err := expectOk(sr); err != nil { + return fatalf("%v", err) } - expectOk(sr) } for i, ss := range stanzas { @@ -191,15 +212,18 @@ ReadLoop: if err := (&format.Stanza{Type: "recipient-stanza", Args: append([]string{fmt.Sprint(i), s.Type}, s.Args...), Body: s.Body}).Marshal(os.Stdout); err != nil { - fatalf("failed to write recipient-stanza: %v", err) + return fatalf("failed to write recipient-stanza: %v", err) + } + if err := expectOk(sr); err != nil { + return fatalf("%v", err) } - expectOk(sr) } } if err := writeStanza(os.Stdout, "done"); err != nil { - fatalf("failed to write done stanza: %v", err) + return fatalf("failed to write done stanza: %v", err) } + return 0 } func wrapWithLabels(r age.Recipient, fileKey []byte) ([]*age.Stanza, []string, error) { @@ -210,67 +234,181 @@ func wrapWithLabels(r age.Recipient, fileKey []byte) ([]*age.Stanza, []string, e return s, nil, err } -func (p *Plugin) IdentityV1() { +func checkLabels(ll, labels []string) error { + if !slicesEqual(ll, labels) { + return fmt.Errorf("labels %q do not match previous recipients %q", ll, labels) + } + return nil +} + +func (p *Plugin) IdentityV1() int { if p.identity == nil { - fatalf("identity-v1 not supported") + return fatalf("identity-v1 not supported") + } + + var files [][]*age.Stanza + var identityStrings []string + + sr := format.NewStanzaReader(bufio.NewReader(os.Stdin)) +ReadLoop: + for { + s, err := sr.ReadStanza() + if err != nil { + return fatalf("failed to read stanza: %v", err) + } + + switch s.Type { + case "add-identity": + if err := expectStanzaWithNoBody(s, 1); err != nil { + return fatalf("%v", err) + } + identityStrings = append(identityStrings, s.Args[0]) + case "recipient-stanza": + if len(s.Args) < 2 { + return fatalf("recipient-stanza stanza has %d arguments, want >=2", len(s.Args)) + } + i, err := strconv.Atoi(s.Args[0]) + if err != nil { + return fatalf("failed to parse recipient-stanza stanza argument: %v", err) + } + ss := &age.Stanza{Type: s.Args[1], Args: s.Args[2:], Body: s.Body} + switch i { + case len(files): + files = append(files, []*age.Stanza{ss}) + case len(files) - 1: + files[len(files)-1] = append(files[len(files)-1], ss) + default: + return fatalf("unexpected file index %d, previous was %d", i, len(files)-1) + } + case "done": + if err := expectStanzaWithNoBody(s, 0); err != nil { + return fatalf("%v", err) + } + break ReadLoop + default: + // Unsupported stanzas in uni-directional phases are ignored. + } + } + + if len(identityStrings) == 0 { + return fatalf("no identities provided") + } + if len(files) == 0 { + return fatalf("no stanzas provided") } - panic("not implemented") + + var identities []age.Identity + for i, s := range identityStrings { + name, data, err := ParseIdentity(s) + if err != nil { + return identityError(sr, i, err) + } + if name != p.name { + return identityError(sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + } + if p.identity == nil { + return identityError(sr, i, fmt.Errorf("identity encodings not supported")) + } + r, err := p.identity(data) + if err != nil { + return identityError(sr, i, err) + } + identities = append(identities, r) + } + + for i, ss := range files { + // TODO: there should be a mechanism to let the plugin decide the order + // in which identities are tried. + for _, id := range identities { + fk, err := id.Unwrap(ss) + if errors.Is(err, age.ErrIncorrectIdentity) { + continue + } else if err != nil { + if err := writeError(sr, []string{"stanza", fmt.Sprint(i), "0"}, err); err != nil { + return fatalf("%v", err) + } + // Note that we don't exit here, as the protocol allows + // continuing with other files. + break + } + + s := &format.Stanza{Type: "file-key", Args: []string{fmt.Sprint(i)}, Body: fk} + if err := s.Marshal(os.Stdout); err != nil { + return fatalf("failed to write file-key: %v", err) + } + if err := expectOk(sr); err != nil { + return fatalf("%v", err) + } + break + } + } + + if err := writeStanza(os.Stdout, "done"); err != nil { + return fatalf("failed to write done stanza: %v", err) + } + return 0 } -func expectStanzaWithNoBody(s *format.Stanza, wantArgs int) { +func expectStanzaWithNoBody(s *format.Stanza, wantArgs int) error { if len(s.Args) != wantArgs { - fatalf("%s stanza has %d arguments, want %d", s.Type, len(s.Args), wantArgs) + return fmt.Errorf("%s stanza has %d arguments, want %d", s.Type, len(s.Args), wantArgs) } if len(s.Body) != 0 { - fatalf("%s stanza has %d bytes of body, want 0", s.Type, len(s.Body)) + return fmt.Errorf("%s stanza has %d bytes of body, want 0", s.Type, len(s.Body)) } + return nil } -func expectStanzaWithBody(s *format.Stanza, wantArgs int) { +func expectStanzaWithBody(s *format.Stanza, wantArgs int) error { if len(s.Args) != wantArgs { - fatalf("%s stanza has %d arguments, want %d", s.Type, len(s.Args), wantArgs) + return fmt.Errorf("%s stanza has %d arguments, want %d", s.Type, len(s.Args), wantArgs) } if len(s.Body) == 0 { - fatalf("%s stanza has 0 bytes of body, want >0", s.Type) + return fmt.Errorf("%s stanza has 0 bytes of body, want >0", s.Type) } + return nil } -func recipientError(sr *format.StanzaReader, idx int, err error) { - protocolError(sr, []string{"recipient", fmt.Sprint(idx)}, err) +func recipientError(sr *format.StanzaReader, idx int, err error) int { + if err := writeError(sr, []string{"recipient", fmt.Sprint(idx)}, err); err != nil { + return fatalf("%v", err) + } + return 3 } -func identityError(sr *format.StanzaReader, idx int, err error) { - protocolError(sr, []string{"identity", fmt.Sprint(idx)}, err) +func identityError(sr *format.StanzaReader, idx int, err error) int { + if err := writeError(sr, []string{"identity", fmt.Sprint(idx)}, err); err != nil { + return fatalf("%v", err) + } + return 3 } -func internalError(sr *format.StanzaReader, err error) { - protocolError(sr, []string{"internal"}, err) +func expectOk(sr *format.StanzaReader) error { + ok, err := sr.ReadStanza() + if err != nil { + return fmt.Errorf("failed to read OK stanza: %v", err) + } + if ok.Type != "ok" { + return fmt.Errorf("expected OK stanza, got %q", ok.Type) + } + return expectStanzaWithNoBody(ok, 0) } -func protocolError(sr *format.StanzaReader, args []string, err error) { +func writeError(sr *format.StanzaReader, args []string, err error) error { s := &format.Stanza{Type: "error", Args: args} s.Body = []byte(err.Error()) if err := s.Marshal(os.Stdout); err != nil { - fatalf("failed to write error stanza: %v", err) + return fmt.Errorf("failed to write error stanza: %v", err) } - expectOk(sr) - os.Exit(3) -} - -func expectOk(sr *format.StanzaReader) { - ok, err := sr.ReadStanza() - if err != nil { - fatalf("failed to read OK stanza: %v", err) - } - if ok.Type != "ok" { - fatalf("expected OK stanza, got %q", ok.Type) + if err := expectOk(sr); err != nil { + return fmt.Errorf("%v", err) } - expectStanzaWithNoBody(ok, 0) + return nil } -func fatalf(format string, args ...interface{}) { +func fatalf(format string, args ...interface{}) int { fmt.Fprintf(os.Stderr, format, args...) - os.Exit(1) + return 1 } func slicesEqual(s1, s2 []string) bool { From f80cf4d627672d95d72dcb9f577821e4335302e6 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Tue, 18 Jun 2024 12:01:11 +0200 Subject: [PATCH 05/12] plugin: document framework --- plugin/plugin.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/plugin/plugin.go b/plugin/plugin.go index e246f3f..5a265d8 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -12,6 +12,18 @@ import ( "filippo.io/age/internal/format" ) +// TODO: implement interaction methods. +// +// // Can only be used during a Wrap or Unwrap invoked by Plugin. +// func (*Plugin) DisplayMessage(message string) error +// func (*Plugin) RequestValue(prompt string, secret bool) (string, error) +// func (*Plugin) Confirm(prompt, yes, no string) (choseYes bool, err error) + +// TODO: add examples. + +// Plugin is a framework for writing age plugins. It allows exposing regular +// [age.Recipient] and [age.Identity] implementations as plugins, and handles +// all the protocol details. type Plugin struct { name string fs *flag.FlagSet @@ -22,14 +34,24 @@ type Plugin struct { identity func([]byte) (age.Identity, error) } +// New creates a new Plugin with the given name. +// +// For example, a plugin named "frood" would be invoked as "age-plugin-frood". func New(name string) (*Plugin, error) { return &Plugin{name: name}, nil } +// Name returns the name of the plugin. func (p *Plugin) Name() string { return p.name } +// RegisterFlags registers the plugin's flags with the given [flag.FlagSet], or +// with the default [flag.CommandLine] if fs is nil. It must be called before +// [flag.Parse] and [Plugin.Main]. +// +// This allows the plugin to expose additional flags when invoked manually, for +// example to implement a keygen mode. func (p *Plugin) RegisterFlags(fs *flag.FlagSet) { if fs == nil { fs = flag.CommandLine @@ -38,6 +60,14 @@ func (p *Plugin) RegisterFlags(fs *flag.FlagSet) { p.sm = fs.String("age-plugin", "", "age-plugin state machine") } +// HandleRecipient registers a function to parse recipients of the form +// age1name1... into [age.Recipient] values. data is the decoded Bech32 payload. +// +// If the returned Recipient implements [age.RecipientWithLabels], Plugin will +// use it and enforce consistency across every returned stanza in an execution. +// If the client supports labels, they will be passed through the protocol. +// +// It must be called before [Plugin.Main], and can be called at most once. func (p *Plugin) HandleRecipient(f func(data []byte) (age.Recipient, error)) { if p.recipient != nil { panic("HandleRecipient called twice") @@ -45,6 +75,15 @@ func (p *Plugin) HandleRecipient(f func(data []byte) (age.Recipient, error)) { p.recipient = f } +// HandleIdentityAsRecipient registers a function to parse identities of the +// form AGE-PLUGIN-NAME-1... into [age.Recipient] values, for when identities +// are used as recipients. data is the decoded Bech32 payload. +// +// If the returned Recipient implements [age.RecipientWithLabels], Plugin will +// use it and enforce consistency across every returned stanza in an execution. +// If the client supports labels, they will be passed through the protocol. +// +// It must be called before [Plugin.Main], and can be called at most once. func (p *Plugin) HandleIdentityAsRecipient(f func(data []byte) (age.Recipient, error)) { if p.idAsRecipient != nil { panic("HandleIdentityAsRecipient called twice") @@ -52,6 +91,11 @@ func (p *Plugin) HandleIdentityAsRecipient(f func(data []byte) (age.Recipient, e p.idAsRecipient = f } +// HandleIdentity registers a function to parse identities of the form +// AGE-PLUGIN-NAME-1... into [age.Identity] values. data is the decoded Bech32 +// payload. +// +// It must be called before [Plugin.Main], and can be called at most once. func (p *Plugin) HandleIdentity(f func(data []byte) (age.Identity, error)) { if p.identity != nil { panic("HandleIdentity called twice") @@ -59,6 +103,11 @@ func (p *Plugin) HandleIdentity(f func(data []byte) (age.Identity, error)) { p.identity = f } +// Main runs the plugin protocol over stdin/stdout, and writes errors to stderr. +// It returns an exit code to pass to os.Exit. +// +// It automatically calls [Plugin.RegisterFlags] and [flag.Parse] if they were +// not called before. func (p *Plugin) Main() int { if p.fs == nil { p.RegisterFlags(nil) @@ -75,6 +124,10 @@ func (p *Plugin) Main() int { return fatalf("unknown state machine %q", *p.sm) } +// RecipientV1 implements the recipient-v1 state machine over stdin/stdout, and +// writes errors to stderr. It returns an exit code to pass to os.Exit. +// +// Most plugins should call [Plugin.Main] instead of this method. func (p *Plugin) RecipientV1() int { if p.recipient == nil && p.idAsRecipient == nil { return fatalf("recipient-v1 not supported") @@ -241,6 +294,10 @@ func checkLabels(ll, labels []string) error { return nil } +// IdentityV1 implements the identity-v1 state machine over stdin/stdout, and +// writes errors to stderr. It returns an exit code to pass to os.Exit. +// +// Most plugins should call [Plugin.Main] instead of this method. func (p *Plugin) IdentityV1() int { if p.identity == nil { return fatalf("identity-v1 not supported") From 0fbe2ac987642140c6a602b167e01c9ce42a0732 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Tue, 18 Jun 2024 12:58:37 +0200 Subject: [PATCH 06/12] plugin: expand grease --- plugin/client.go | 19 +++++++++++++++++-- plugin/plugin.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/plugin/client.go b/plugin/client.go index dca1a52..5e96006 100644 --- a/plugin/client.go +++ b/plugin/client.go @@ -79,7 +79,7 @@ func (r *Recipient) WrapWithLabels(fileKey []byte) (stanzas []*age.Stanza, label if err := writeStanza(conn, addType, r.encoding); err != nil { return nil, nil, err } - if err := writeStanza(conn, fmt.Sprintf("grease-%x", rand.Int())); err != nil { + if _, err := writeGrease(conn); err != nil { return nil, nil, err } if err := writeStanzaWithBody(conn, "wrap-file-key", fileKey); err != nil { @@ -220,7 +220,7 @@ func (i *Identity) Unwrap(stanzas []*age.Stanza) (fileKey []byte, err error) { if err := writeStanza(conn, "add-identity", i.encoding); err != nil { return nil, err } - if err := writeStanza(conn, fmt.Sprintf("grease-%x", rand.Int())); err != nil { + if _, err := writeGrease(conn); err != nil { return nil, err } for _, rs := range stanzas { @@ -449,3 +449,18 @@ func writeStanzaWithBody(conn io.Writer, t string, body []byte) error { s := &format.Stanza{Type: t, Body: body} return s.Marshal(conn) } + +func writeGrease(conn io.Writer) (sent bool, err error) { + if rand.Intn(3) == 0 { + return false, nil + } + s := &format.Stanza{Type: fmt.Sprintf("grease-%x", rand.Int())} + for i := 0; i < rand.Intn(3); i++ { + s.Args = append(s.Args, fmt.Sprintf("%d", rand.Intn(100))) + } + if rand.Intn(2) == 0 { + s.Body = make([]byte, rand.Intn(100)) + rand.Read(s.Body) + } + return true, s.Marshal(conn) +} diff --git a/plugin/plugin.go b/plugin/plugin.go index 5a265d8..55c697b 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -251,6 +251,14 @@ ReadLoop: } } + if sent, err := writeGrease(os.Stdout); err != nil { + return fatalf("failed to write grease: %v", err) + } else if sent { + if err := expectUnsupported(sr); err != nil { + return fatalf("%v", err) + } + } + if supportsLabels { if err := writeStanza(os.Stdout, "labels", labels...); err != nil { return fatalf("failed to write labels stanza: %v", err) @@ -271,6 +279,13 @@ ReadLoop: return fatalf("%v", err) } } + if sent, err := writeGrease(os.Stdout); err != nil { + return fatalf("failed to write grease: %v", err) + } else if sent { + if err := expectUnsupported(sr); err != nil { + return fatalf("%v", err) + } + } } if err := writeStanza(os.Stdout, "done"); err != nil { @@ -374,6 +389,14 @@ ReadLoop: } for i, ss := range files { + if sent, err := writeGrease(os.Stdout); err != nil { + return fatalf("failed to write grease: %v", err) + } else if sent { + if err := expectUnsupported(sr); err != nil { + return fatalf("%v", err) + } + } + // TODO: there should be a mechanism to let the plugin decide the order // in which identities are tried. for _, id := range identities { @@ -451,6 +474,17 @@ func expectOk(sr *format.StanzaReader) error { return expectStanzaWithNoBody(ok, 0) } +func expectUnsupported(sr *format.StanzaReader) error { + unsupported, err := sr.ReadStanza() + if err != nil { + return fmt.Errorf("failed to read unsupported stanza: %v", err) + } + if unsupported.Type != "unsupported" { + return fmt.Errorf("expected unsupported stanza, got %q", unsupported.Type) + } + return expectStanzaWithNoBody(unsupported, 0) +} + func writeError(sr *format.StanzaReader, args []string, err error) error { s := &format.Stanza{Type: "error", Args: args} s.Body = []byte(err.Error()) From 7eedd929a6cfdcbfec1c9e9cfb2a539c9ea59aba Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Tue, 18 Jun 2024 15:18:52 +0200 Subject: [PATCH 07/12] plugin: add interactivity methods to framework --- plugin/plugin.go | 180 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 143 insertions(+), 37 deletions(-) diff --git a/plugin/plugin.go b/plugin/plugin.go index 55c697b..e042003 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -2,6 +2,7 @@ package plugin import ( "bufio" + "encoding/base64" "errors" "flag" "fmt" @@ -12,13 +13,6 @@ import ( "filippo.io/age/internal/format" ) -// TODO: implement interaction methods. -// -// // Can only be used during a Wrap or Unwrap invoked by Plugin. -// func (*Plugin) DisplayMessage(message string) error -// func (*Plugin) RequestValue(prompt string, secret bool) (string, error) -// func (*Plugin) Confirm(prompt, yes, no string) (choseYes bool, err error) - // TODO: add examples. // Plugin is a framework for writing age plugins. It allows exposing regular @@ -32,6 +26,11 @@ type Plugin struct { recipient func([]byte) (age.Recipient, error) idAsRecipient func([]byte) (age.Recipient, error) identity func([]byte) (age.Identity, error) + + sr *format.StanzaReader + // broken is set if the protocol broke down during an interaction function + // called by a Recipient or Identity. + broken bool } // New creates a new Plugin with the given name. @@ -137,10 +136,10 @@ func (p *Plugin) RecipientV1() int { var fileKeys [][]byte var supportsLabels bool - sr := format.NewStanzaReader(bufio.NewReader(os.Stdin)) + p.sr = format.NewStanzaReader(bufio.NewReader(os.Stdin)) ReadLoop: for { - s, err := sr.ReadStanza() + s, err := p.sr.ReadStanza() if err != nil { return fatalf("failed to read stanza: %v", err) } @@ -187,34 +186,34 @@ ReadLoop: for i, s := range recipientStrings { name, data, err := ParseRecipient(s) if err != nil { - return recipientError(sr, i, err) + return recipientError(p.sr, i, err) } if name != p.name { - return recipientError(sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + return recipientError(p.sr, i, fmt.Errorf("unsupported plugin name: %q", name)) } if p.recipient == nil { - return recipientError(sr, i, fmt.Errorf("recipient encodings not supported")) + return recipientError(p.sr, i, fmt.Errorf("recipient encodings not supported")) } r, err := p.recipient(data) if err != nil { - return recipientError(sr, i, err) + return recipientError(p.sr, i, err) } recipients = append(recipients, r) } for i, s := range identityStrings { name, data, err := ParseIdentity(s) if err != nil { - return identityError(sr, i, err) + return identityError(p.sr, i, err) } if name != p.name { - return identityError(sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + return identityError(p.sr, i, fmt.Errorf("unsupported plugin name: %q", name)) } if p.idAsRecipient == nil { - return identityError(sr, i, fmt.Errorf("identity encodings not supported")) + return identityError(p.sr, i, fmt.Errorf("identity encodings not supported")) } r, err := p.idAsRecipient(data) if err != nil { - return identityError(sr, i, err) + return identityError(p.sr, i, err) } identities = append(identities, r) } @@ -227,25 +226,29 @@ ReadLoop: for i, fk := range fileKeys { for j, r := range recipients { ss, ll, err := wrapWithLabels(r, fk) - if err != nil { - return recipientError(sr, j, err) + if p.broken { + return 2 + } else if err != nil { + return recipientError(p.sr, j, err) } if i == 0 && j == 0 { labels = ll } else if err := checkLabels(ll, labels); err != nil { - return recipientError(sr, j, err) + return recipientError(p.sr, j, err) } stanzas[i] = append(stanzas[i], ss...) } for j, r := range identities { ss, ll, err := wrapWithLabels(r, fk) - if err != nil { - return identityError(sr, j, err) + if p.broken { + return 2 + } else if err != nil { + return identityError(p.sr, j, err) } if i == 0 && j == 0 && len(recipients) == 0 { labels = ll } else if err := checkLabels(ll, labels); err != nil { - return identityError(sr, j, err) + return identityError(p.sr, j, err) } stanzas[i] = append(stanzas[i], ss...) } @@ -254,7 +257,7 @@ ReadLoop: if sent, err := writeGrease(os.Stdout); err != nil { return fatalf("failed to write grease: %v", err) } else if sent { - if err := expectUnsupported(sr); err != nil { + if err := expectUnsupported(p.sr); err != nil { return fatalf("%v", err) } } @@ -263,7 +266,7 @@ ReadLoop: if err := writeStanza(os.Stdout, "labels", labels...); err != nil { return fatalf("failed to write labels stanza: %v", err) } - if err := expectOk(sr); err != nil { + if err := expectOk(p.sr); err != nil { return fatalf("%v", err) } } @@ -275,14 +278,14 @@ ReadLoop: Body: s.Body}).Marshal(os.Stdout); err != nil { return fatalf("failed to write recipient-stanza: %v", err) } - if err := expectOk(sr); err != nil { + if err := expectOk(p.sr); err != nil { return fatalf("%v", err) } } if sent, err := writeGrease(os.Stdout); err != nil { return fatalf("failed to write grease: %v", err) } else if sent { - if err := expectUnsupported(sr); err != nil { + if err := expectUnsupported(p.sr); err != nil { return fatalf("%v", err) } } @@ -321,10 +324,10 @@ func (p *Plugin) IdentityV1() int { var files [][]*age.Stanza var identityStrings []string - sr := format.NewStanzaReader(bufio.NewReader(os.Stdin)) + p.sr = format.NewStanzaReader(bufio.NewReader(os.Stdin)) ReadLoop: for { - s, err := sr.ReadStanza() + s, err := p.sr.ReadStanza() if err != nil { return fatalf("failed to read stanza: %v", err) } @@ -373,17 +376,17 @@ ReadLoop: for i, s := range identityStrings { name, data, err := ParseIdentity(s) if err != nil { - return identityError(sr, i, err) + return identityError(p.sr, i, err) } if name != p.name { - return identityError(sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + return identityError(p.sr, i, fmt.Errorf("unsupported plugin name: %q", name)) } if p.identity == nil { - return identityError(sr, i, fmt.Errorf("identity encodings not supported")) + return identityError(p.sr, i, fmt.Errorf("identity encodings not supported")) } r, err := p.identity(data) if err != nil { - return identityError(sr, i, err) + return identityError(p.sr, i, err) } identities = append(identities, r) } @@ -392,7 +395,7 @@ ReadLoop: if sent, err := writeGrease(os.Stdout); err != nil { return fatalf("failed to write grease: %v", err) } else if sent { - if err := expectUnsupported(sr); err != nil { + if err := expectUnsupported(p.sr); err != nil { return fatalf("%v", err) } } @@ -401,10 +404,12 @@ ReadLoop: // in which identities are tried. for _, id := range identities { fk, err := id.Unwrap(ss) - if errors.Is(err, age.ErrIncorrectIdentity) { + if p.broken { + return 2 + } else if errors.Is(err, age.ErrIncorrectIdentity) { continue } else if err != nil { - if err := writeError(sr, []string{"stanza", fmt.Sprint(i), "0"}, err); err != nil { + if err := writeError(p.sr, []string{"stanza", fmt.Sprint(i), "0"}, err); err != nil { return fatalf("%v", err) } // Note that we don't exit here, as the protocol allows @@ -416,7 +421,7 @@ ReadLoop: if err := s.Marshal(os.Stdout); err != nil { return fatalf("failed to write file-key: %v", err) } - if err := expectOk(sr); err != nil { + if err := expectOk(p.sr); err != nil { return fatalf("%v", err) } break @@ -429,6 +434,86 @@ ReadLoop: return 0 } +// DisplayMessage requests that the client display a message to the user. The +// message should start with a lowercase letter and have no final period. +// DisplayMessage returns an error if the client can't display the message, and +// may return before the message has been displayed to the user. +// +// It must only be called by a Wrap or Unwrap method invoked by [Plugin.Main]. +func (p *Plugin) DisplayMessage(message string) error { + if err := writeStanzaWithBody(os.Stdout, "msg", []byte(message)); err != nil { + return p.fatalInteractf("failed to write msg stanza: %v", err) + } + s, err := readOkOrFail(p.sr) + if err != nil { + return p.fatalInteractf("%v", err) + } + if s.Type == "fail" { + return fmt.Errorf("client failed to display message") + } + return nil +} + +// RequestValue requests a secret or public input from the user through the +// client, with the provided prompt. It returns an error if the client can't +// request the input or if the user dismisses the prompt. +// +// It must only be called by a Wrap or Unwrap method invoked by [Plugin.Main]. +func (p *Plugin) RequestValue(prompt string, secret bool) (string, error) { + t := "request-public" + if secret { + t = "request-secret" + } + if err := writeStanzaWithBody(os.Stdout, t, []byte(prompt)); err != nil { + return "", p.fatalInteractf("failed to write stanza: %v", err) + } + s, err := readOkOrFail(p.sr) + if err != nil { + return "", p.fatalInteractf("%v", err) + } + if s.Type == "fail" { + return "", fmt.Errorf("client failed to request value") + } + return string(s.Body), nil +} + +// Confirm requests a confirmation from the user through the client, with the +// provided prompt. The yes and no value are the choices provided to the user. +// no may be empty. The return value choseYes indicates whether the user +// selected the yes or no option. Confirm returns an error if the client can't +// request the confirmation. +// +// It must only be called by a Wrap or Unwrap method invoked by [Plugin.Main]. +func (p *Plugin) Confirm(prompt, yes, no string) (choseYes bool, err error) { + args := []string{base64.StdEncoding.EncodeToString([]byte(yes))} + if no != "" { + args = append(args, base64.StdEncoding.EncodeToString([]byte(no))) + } + s := &format.Stanza{Type: "confirm", Args: args, Body: []byte(prompt)} + if err := s.Marshal(os.Stdout); err != nil { + return false, p.fatalInteractf("failed to write confirm stanza: %v", err) + } + s, err = readOkOrFail(p.sr) + if err != nil { + return false, p.fatalInteractf("%v", err) + } + if s.Type == "fail" { + return false, fmt.Errorf("client failed to request confirmation") + } + if err := expectStanzaWithNoBody(s, 1); err != nil { + return false, p.fatalInteractf("%v", err) + } + return s.Args[0] == "yes", nil +} + +// fatalInteractf prints the error to stderr and sets the broken flag, so the +// Wrap/Unwrap caller can exit with an error. +func (p *Plugin) fatalInteractf(format string, args ...interface{}) error { + p.broken = true + fmt.Fprintf(os.Stderr, format, args...) + return fmt.Errorf(format, args...) +} + func expectStanzaWithNoBody(s *format.Stanza, wantArgs int) error { if len(s.Args) != wantArgs { return fmt.Errorf("%s stanza has %d arguments, want %d", s.Type, len(s.Args), wantArgs) @@ -474,6 +559,27 @@ func expectOk(sr *format.StanzaReader) error { return expectStanzaWithNoBody(ok, 0) } +func readOkOrFail(sr *format.StanzaReader) (*format.Stanza, error) { + s, err := sr.ReadStanza() + if err != nil { + return nil, fmt.Errorf("failed to read response stanza: %v", err) + } + switch s.Type { + case "fail": + if err := expectStanzaWithNoBody(s, 0); err != nil { + return nil, fmt.Errorf("%v", err) + } + return s, nil + case "ok": + if s.Body != nil { + return nil, fmt.Errorf("ok stanza has %d bytes of body, want 0", len(s.Body)) + } + return s, nil + default: + return nil, fmt.Errorf("expected ok or fail stanza, got %q", s.Type) + } +} + func expectUnsupported(sr *format.StanzaReader) error { unsupported, err := sr.ReadStanza() if err != nil { From 6fa7022d8bcf028feb7a8ef41fc6dc35d735abfb Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Fri, 21 Jun 2024 14:45:50 +0200 Subject: [PATCH 08/12] plugin: fix Plugin.Confirm --- plugin/plugin.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugin/plugin.go b/plugin/plugin.go index e042003..5a83c09 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -2,7 +2,6 @@ package plugin import ( "bufio" - "encoding/base64" "errors" "flag" "fmt" @@ -14,6 +13,7 @@ import ( ) // TODO: add examples. +// TODO: add plugin test framework. // Plugin is a framework for writing age plugins. It allows exposing regular // [age.Recipient] and [age.Identity] implementations as plugins, and handles @@ -485,9 +485,9 @@ func (p *Plugin) RequestValue(prompt string, secret bool) (string, error) { // // It must only be called by a Wrap or Unwrap method invoked by [Plugin.Main]. func (p *Plugin) Confirm(prompt, yes, no string) (choseYes bool, err error) { - args := []string{base64.StdEncoding.EncodeToString([]byte(yes))} + args := []string{format.EncodeToString([]byte(yes))} if no != "" { - args = append(args, base64.StdEncoding.EncodeToString([]byte(no))) + args = append(args, format.EncodeToString([]byte(no))) } s := &format.Stanza{Type: "confirm", Args: args, Body: []byte(prompt)} if err := s.Marshal(os.Stdout); err != nil { From 8734a853bcdf694a9e9444515d45dbca26b3c374 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Fri, 21 Jun 2024 14:52:16 +0200 Subject: [PATCH 09/12] plugin: fix Plugin.Confirm again --- plugin/plugin.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/plugin/plugin.go b/plugin/plugin.go index 5a83c09..dafd4c7 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -451,6 +451,9 @@ func (p *Plugin) DisplayMessage(message string) error { if s.Type == "fail" { return fmt.Errorf("client failed to display message") } + if err := expectStanzaWithNoBody(s, 0); err != nil { + return p.fatalInteractf("%v", err) + } return nil } @@ -474,6 +477,9 @@ func (p *Plugin) RequestValue(prompt string, secret bool) (string, error) { if s.Type == "fail" { return "", fmt.Errorf("client failed to request value") } + if err := expectStanzaWithBody(s, 0); err != nil { + return "", p.fatalInteractf("%v", err) + } return string(s.Body), nil } @@ -571,9 +577,6 @@ func readOkOrFail(sr *format.StanzaReader) (*format.Stanza, error) { } return s, nil case "ok": - if s.Body != nil { - return nil, fmt.Errorf("ok stanza has %d bytes of body, want 0", len(s.Body)) - } return s, nil default: return nil, fmt.Errorf("expected ok or fail stanza, got %q", s.Type) From 0b2a3d0800538a34e7bc8dc5abc4899ca5b0669c Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Fri, 21 Jun 2024 18:40:58 +0200 Subject: [PATCH 10/12] plugin: add Plugin example --- plugin/example_test.go | 43 ++++++++++++++++++++++++++++++++++++++++++ plugin/plugin.go | 1 - 2 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 plugin/example_test.go diff --git a/plugin/example_test.go b/plugin/example_test.go new file mode 100644 index 0000000..80b4758 --- /dev/null +++ b/plugin/example_test.go @@ -0,0 +1,43 @@ +package plugin_test + +import ( + "log" + "os" + + "filippo.io/age" + "filippo.io/age/plugin" +) + +type Recipient struct{} + +func (r *Recipient) Wrap(fileKey []byte) ([]*age.Stanza, error) { + panic("unimplemented") +} + +func NewRecipient(data []byte) (*Recipient, error) { + return &Recipient{}, nil +} + +type Identity struct{} + +func (i *Identity) Unwrap(s []*age.Stanza) ([]byte, error) { + panic("unimplemented") +} + +func NewIdentity(data []byte) (*Identity, error) { + return &Identity{}, nil +} + +func ExamplePlugin_main() { + p, err := plugin.New("example") + if err != nil { + log.Fatal(err) + } + p.HandleRecipient(func(data []byte) (age.Recipient, error) { + return NewRecipient(data) + }) + p.HandleIdentity(func(data []byte) (age.Identity, error) { + return NewIdentity(data) + }) + os.Exit(p.Main()) +} diff --git a/plugin/plugin.go b/plugin/plugin.go index dafd4c7..5259e10 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -12,7 +12,6 @@ import ( "filippo.io/age/internal/format" ) -// TODO: add examples. // TODO: add plugin test framework. // Plugin is a framework for writing age plugins. It allows exposing regular From 8587526a13b85704b5cfa9d08dba719cd64daedc Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Thu, 26 Sep 2024 12:40:18 +0200 Subject: [PATCH 11/12] plugin: add Plugin.MainWithIO --- plugin/plugin.go | 172 +++++++++++++++++++++++++---------------------- 1 file changed, 93 insertions(+), 79 deletions(-) diff --git a/plugin/plugin.go b/plugin/plugin.go index 5259e10..689b24b 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -5,6 +5,7 @@ import ( "errors" "flag" "fmt" + "io" "os" "strconv" @@ -26,6 +27,9 @@ type Plugin struct { idAsRecipient func([]byte) (age.Recipient, error) identity func([]byte) (age.Identity, error) + stdin io.Reader + stdout, stderr io.Writer + sr *format.StanzaReader // broken is set if the protocol broke down during an interaction function // called by a Recipient or Identity. @@ -107,6 +111,15 @@ func (p *Plugin) HandleIdentity(f func(data []byte) (age.Identity, error)) { // It automatically calls [Plugin.RegisterFlags] and [flag.Parse] if they were // not called before. func (p *Plugin) Main() int { + return p.MainWithIO(os.Stdin, os.Stdout, os.Stderr) +} + +// MainWithIO works like [Plugin.Main] but runs the plugin protocol over the +// given io.Reader and io.Writers. +func (p *Plugin) MainWithIO(stdin io.Reader, stdout, stderr io.Writer) int { + p.stdin = stdin + p.stdout = stdout + p.stderr = stderr if p.fs == nil { p.RegisterFlags(nil) } @@ -119,7 +132,8 @@ func (p *Plugin) Main() int { if *p.sm == "identity-v1" { return p.IdentityV1() } - return fatalf("unknown state machine %q", *p.sm) + fmt.Fprintf(p.stderr, "unknown state machine %q", *p.sm) + return 4 } // RecipientV1 implements the recipient-v1 state machine over stdin/stdout, and @@ -128,45 +142,45 @@ func (p *Plugin) Main() int { // Most plugins should call [Plugin.Main] instead of this method. func (p *Plugin) RecipientV1() int { if p.recipient == nil && p.idAsRecipient == nil { - return fatalf("recipient-v1 not supported") + return p.fatalf("recipient-v1 not supported") } var recipientStrings, identityStrings []string var fileKeys [][]byte var supportsLabels bool - p.sr = format.NewStanzaReader(bufio.NewReader(os.Stdin)) + p.sr = format.NewStanzaReader(bufio.NewReader(p.stdin)) ReadLoop: for { s, err := p.sr.ReadStanza() if err != nil { - return fatalf("failed to read stanza: %v", err) + return p.fatalf("failed to read stanza: %v", err) } switch s.Type { case "add-recipient": if err := expectStanzaWithNoBody(s, 1); err != nil { - return fatalf("%v", err) + return p.fatalf("%v", err) } recipientStrings = append(recipientStrings, s.Args[0]) case "add-identity": if err := expectStanzaWithNoBody(s, 1); err != nil { - return fatalf("%v", err) + return p.fatalf("%v", err) } identityStrings = append(identityStrings, s.Args[0]) case "extension-labels": if err := expectStanzaWithNoBody(s, 0); err != nil { - return fatalf("%v", err) + return p.fatalf("%v", err) } supportsLabels = true case "wrap-file-key": if err := expectStanzaWithBody(s, 0); err != nil { - return fatalf("%v", err) + return p.fatalf("%v", err) } fileKeys = append(fileKeys, s.Body) case "done": if err := expectStanzaWithNoBody(s, 0); err != nil { - return fatalf("%v", err) + return p.fatalf("%v", err) } break ReadLoop default: @@ -175,44 +189,44 @@ ReadLoop: } if len(recipientStrings)+len(identityStrings) == 0 { - return fatalf("no recipients or identities provided") + return p.fatalf("no recipients or identities provided") } if len(fileKeys) == 0 { - return fatalf("no file keys provided") + return p.fatalf("no file keys provided") } var recipients, identities []age.Recipient for i, s := range recipientStrings { name, data, err := ParseRecipient(s) if err != nil { - return recipientError(p.sr, i, err) + return p.recipientError(i, err) } if name != p.name { - return recipientError(p.sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + return p.recipientError(i, fmt.Errorf("unsupported plugin name: %q", name)) } if p.recipient == nil { - return recipientError(p.sr, i, fmt.Errorf("recipient encodings not supported")) + return p.recipientError(i, fmt.Errorf("recipient encodings not supported")) } r, err := p.recipient(data) if err != nil { - return recipientError(p.sr, i, err) + return p.recipientError(i, err) } recipients = append(recipients, r) } for i, s := range identityStrings { name, data, err := ParseIdentity(s) if err != nil { - return identityError(p.sr, i, err) + return p.identityError(i, err) } if name != p.name { - return identityError(p.sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + return p.identityError(i, fmt.Errorf("unsupported plugin name: %q", name)) } if p.idAsRecipient == nil { - return identityError(p.sr, i, fmt.Errorf("identity encodings not supported")) + return p.identityError(i, fmt.Errorf("identity encodings not supported")) } r, err := p.idAsRecipient(data) if err != nil { - return identityError(p.sr, i, err) + return p.identityError(i, err) } identities = append(identities, r) } @@ -228,12 +242,12 @@ ReadLoop: if p.broken { return 2 } else if err != nil { - return recipientError(p.sr, j, err) + return p.recipientError(j, err) } if i == 0 && j == 0 { labels = ll } else if err := checkLabels(ll, labels); err != nil { - return recipientError(p.sr, j, err) + return p.recipientError(j, err) } stanzas[i] = append(stanzas[i], ss...) } @@ -242,31 +256,31 @@ ReadLoop: if p.broken { return 2 } else if err != nil { - return identityError(p.sr, j, err) + return p.identityError(j, err) } if i == 0 && j == 0 && len(recipients) == 0 { labels = ll } else if err := checkLabels(ll, labels); err != nil { - return identityError(p.sr, j, err) + return p.identityError(j, err) } stanzas[i] = append(stanzas[i], ss...) } } - if sent, err := writeGrease(os.Stdout); err != nil { - return fatalf("failed to write grease: %v", err) + if sent, err := writeGrease(p.stdout); err != nil { + return p.fatalf("failed to write grease: %v", err) } else if sent { if err := expectUnsupported(p.sr); err != nil { - return fatalf("%v", err) + return p.fatalf("%v", err) } } if supportsLabels { - if err := writeStanza(os.Stdout, "labels", labels...); err != nil { - return fatalf("failed to write labels stanza: %v", err) + if err := writeStanza(p.stdout, "labels", labels...); err != nil { + return p.fatalf("failed to write labels stanza: %v", err) } if err := expectOk(p.sr); err != nil { - return fatalf("%v", err) + return p.fatalf("%v", err) } } @@ -274,24 +288,24 @@ ReadLoop: for _, s := range ss { if err := (&format.Stanza{Type: "recipient-stanza", Args: append([]string{fmt.Sprint(i), s.Type}, s.Args...), - Body: s.Body}).Marshal(os.Stdout); err != nil { - return fatalf("failed to write recipient-stanza: %v", err) + Body: s.Body}).Marshal(p.stdout); err != nil { + return p.fatalf("failed to write recipient-stanza: %v", err) } if err := expectOk(p.sr); err != nil { - return fatalf("%v", err) + return p.fatalf("%v", err) } } - if sent, err := writeGrease(os.Stdout); err != nil { - return fatalf("failed to write grease: %v", err) + if sent, err := writeGrease(p.stdout); err != nil { + return p.fatalf("failed to write grease: %v", err) } else if sent { if err := expectUnsupported(p.sr); err != nil { - return fatalf("%v", err) + return p.fatalf("%v", err) } } } - if err := writeStanza(os.Stdout, "done"); err != nil { - return fatalf("failed to write done stanza: %v", err) + if err := writeStanza(p.stdout, "done"); err != nil { + return p.fatalf("failed to write done stanza: %v", err) } return 0 } @@ -317,33 +331,33 @@ func checkLabels(ll, labels []string) error { // Most plugins should call [Plugin.Main] instead of this method. func (p *Plugin) IdentityV1() int { if p.identity == nil { - return fatalf("identity-v1 not supported") + return p.fatalf("identity-v1 not supported") } var files [][]*age.Stanza var identityStrings []string - p.sr = format.NewStanzaReader(bufio.NewReader(os.Stdin)) + p.sr = format.NewStanzaReader(bufio.NewReader(p.stdin)) ReadLoop: for { s, err := p.sr.ReadStanza() if err != nil { - return fatalf("failed to read stanza: %v", err) + return p.fatalf("failed to read stanza: %v", err) } switch s.Type { case "add-identity": if err := expectStanzaWithNoBody(s, 1); err != nil { - return fatalf("%v", err) + return p.fatalf("%v", err) } identityStrings = append(identityStrings, s.Args[0]) case "recipient-stanza": if len(s.Args) < 2 { - return fatalf("recipient-stanza stanza has %d arguments, want >=2", len(s.Args)) + return p.fatalf("recipient-stanza stanza has %d arguments, want >=2", len(s.Args)) } i, err := strconv.Atoi(s.Args[0]) if err != nil { - return fatalf("failed to parse recipient-stanza stanza argument: %v", err) + return p.fatalf("failed to parse recipient-stanza stanza argument: %v", err) } ss := &age.Stanza{Type: s.Args[1], Args: s.Args[2:], Body: s.Body} switch i { @@ -352,11 +366,11 @@ ReadLoop: case len(files) - 1: files[len(files)-1] = append(files[len(files)-1], ss) default: - return fatalf("unexpected file index %d, previous was %d", i, len(files)-1) + return p.fatalf("unexpected file index %d, previous was %d", i, len(files)-1) } case "done": if err := expectStanzaWithNoBody(s, 0); err != nil { - return fatalf("%v", err) + return p.fatalf("%v", err) } break ReadLoop default: @@ -365,37 +379,37 @@ ReadLoop: } if len(identityStrings) == 0 { - return fatalf("no identities provided") + return p.fatalf("no identities provided") } if len(files) == 0 { - return fatalf("no stanzas provided") + return p.fatalf("no stanzas provided") } var identities []age.Identity for i, s := range identityStrings { name, data, err := ParseIdentity(s) if err != nil { - return identityError(p.sr, i, err) + return p.identityError(i, err) } if name != p.name { - return identityError(p.sr, i, fmt.Errorf("unsupported plugin name: %q", name)) + return p.identityError(i, fmt.Errorf("unsupported plugin name: %q", name)) } if p.identity == nil { - return identityError(p.sr, i, fmt.Errorf("identity encodings not supported")) + return p.identityError(i, fmt.Errorf("identity encodings not supported")) } r, err := p.identity(data) if err != nil { - return identityError(p.sr, i, err) + return p.identityError(i, err) } identities = append(identities, r) } for i, ss := range files { - if sent, err := writeGrease(os.Stdout); err != nil { - return fatalf("failed to write grease: %v", err) + if sent, err := writeGrease(p.stdout); err != nil { + return p.fatalf("failed to write grease: %v", err) } else if sent { if err := expectUnsupported(p.sr); err != nil { - return fatalf("%v", err) + return p.fatalf("%v", err) } } @@ -408,8 +422,8 @@ ReadLoop: } else if errors.Is(err, age.ErrIncorrectIdentity) { continue } else if err != nil { - if err := writeError(p.sr, []string{"stanza", fmt.Sprint(i), "0"}, err); err != nil { - return fatalf("%v", err) + if err := p.writeError([]string{"stanza", fmt.Sprint(i), "0"}, err); err != nil { + return p.fatalf("%v", err) } // Note that we don't exit here, as the protocol allows // continuing with other files. @@ -417,18 +431,18 @@ ReadLoop: } s := &format.Stanza{Type: "file-key", Args: []string{fmt.Sprint(i)}, Body: fk} - if err := s.Marshal(os.Stdout); err != nil { - return fatalf("failed to write file-key: %v", err) + if err := s.Marshal(p.stdout); err != nil { + return p.fatalf("failed to write file-key: %v", err) } if err := expectOk(p.sr); err != nil { - return fatalf("%v", err) + return p.fatalf("%v", err) } break } } - if err := writeStanza(os.Stdout, "done"); err != nil { - return fatalf("failed to write done stanza: %v", err) + if err := writeStanza(p.stdout, "done"); err != nil { + return p.fatalf("failed to write done stanza: %v", err) } return 0 } @@ -440,7 +454,7 @@ ReadLoop: // // It must only be called by a Wrap or Unwrap method invoked by [Plugin.Main]. func (p *Plugin) DisplayMessage(message string) error { - if err := writeStanzaWithBody(os.Stdout, "msg", []byte(message)); err != nil { + if err := writeStanzaWithBody(p.stdout, "msg", []byte(message)); err != nil { return p.fatalInteractf("failed to write msg stanza: %v", err) } s, err := readOkOrFail(p.sr) @@ -466,7 +480,7 @@ func (p *Plugin) RequestValue(prompt string, secret bool) (string, error) { if secret { t = "request-secret" } - if err := writeStanzaWithBody(os.Stdout, t, []byte(prompt)); err != nil { + if err := writeStanzaWithBody(p.stdout, t, []byte(prompt)); err != nil { return "", p.fatalInteractf("failed to write stanza: %v", err) } s, err := readOkOrFail(p.sr) @@ -495,7 +509,7 @@ func (p *Plugin) Confirm(prompt, yes, no string) (choseYes bool, err error) { args = append(args, format.EncodeToString([]byte(no))) } s := &format.Stanza{Type: "confirm", Args: args, Body: []byte(prompt)} - if err := s.Marshal(os.Stdout); err != nil { + if err := s.Marshal(p.stdout); err != nil { return false, p.fatalInteractf("failed to write confirm stanza: %v", err) } s, err = readOkOrFail(p.sr) @@ -515,10 +529,15 @@ func (p *Plugin) Confirm(prompt, yes, no string) (choseYes bool, err error) { // Wrap/Unwrap caller can exit with an error. func (p *Plugin) fatalInteractf(format string, args ...interface{}) error { p.broken = true - fmt.Fprintf(os.Stderr, format, args...) + fmt.Fprintf(p.stderr, format, args...) return fmt.Errorf(format, args...) } +func (p *Plugin) fatalf(format string, args ...interface{}) int { + fmt.Fprintf(p.stderr, format, args...) + return 1 +} + func expectStanzaWithNoBody(s *format.Stanza, wantArgs int) error { if len(s.Args) != wantArgs { return fmt.Errorf("%s stanza has %d arguments, want %d", s.Type, len(s.Args), wantArgs) @@ -539,16 +558,16 @@ func expectStanzaWithBody(s *format.Stanza, wantArgs int) error { return nil } -func recipientError(sr *format.StanzaReader, idx int, err error) int { - if err := writeError(sr, []string{"recipient", fmt.Sprint(idx)}, err); err != nil { - return fatalf("%v", err) +func (p *Plugin) recipientError(idx int, err error) int { + if err := p.writeError([]string{"recipient", fmt.Sprint(idx)}, err); err != nil { + return p.fatalf("%v", err) } return 3 } -func identityError(sr *format.StanzaReader, idx int, err error) int { - if err := writeError(sr, []string{"identity", fmt.Sprint(idx)}, err); err != nil { - return fatalf("%v", err) +func (p *Plugin) identityError(idx int, err error) int { + if err := p.writeError([]string{"identity", fmt.Sprint(idx)}, err); err != nil { + return p.fatalf("%v", err) } return 3 } @@ -593,23 +612,18 @@ func expectUnsupported(sr *format.StanzaReader) error { return expectStanzaWithNoBody(unsupported, 0) } -func writeError(sr *format.StanzaReader, args []string, err error) error { +func (p *Plugin) writeError(args []string, err error) error { s := &format.Stanza{Type: "error", Args: args} s.Body = []byte(err.Error()) - if err := s.Marshal(os.Stdout); err != nil { + if err := s.Marshal(p.stderr); err != nil { return fmt.Errorf("failed to write error stanza: %v", err) } - if err := expectOk(sr); err != nil { + if err := expectOk(p.sr); err != nil { return fmt.Errorf("%v", err) } return nil } -func fatalf(format string, args ...interface{}) int { - fmt.Fprintf(os.Stderr, format, args...) - return 1 -} - func slicesEqual(s1, s2 []string) bool { if len(s1) != len(s2) { return false From 2214a556f60400ad19f2ca43d3cbbb4a5a0fe5ab Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Thu, 26 Sep 2024 13:08:59 +0200 Subject: [PATCH 12/12] plugin: replace Plugin.MainWithIO with Plugin.SetIO --- plugin/plugin.go | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/plugin/plugin.go b/plugin/plugin.go index 689b24b..fa1d86e 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -40,7 +40,8 @@ type Plugin struct { // // For example, a plugin named "frood" would be invoked as "age-plugin-frood". func New(name string) (*Plugin, error) { - return &Plugin{name: name}, nil + return &Plugin{name: name, stdin: os.Stdin, + stdout: os.Stdout, stderr: os.Stderr}, nil } // Name returns the name of the plugin. @@ -105,21 +106,11 @@ func (p *Plugin) HandleIdentity(f func(data []byte) (age.Identity, error)) { p.identity = f } -// Main runs the plugin protocol over stdin/stdout, and writes errors to stderr. -// It returns an exit code to pass to os.Exit. +// Main runs the plugin protocol. It returns an exit code to pass to os.Exit. // // It automatically calls [Plugin.RegisterFlags] and [flag.Parse] if they were // not called before. func (p *Plugin) Main() int { - return p.MainWithIO(os.Stdin, os.Stdout, os.Stderr) -} - -// MainWithIO works like [Plugin.Main] but runs the plugin protocol over the -// given io.Reader and io.Writers. -func (p *Plugin) MainWithIO(stdin io.Reader, stdout, stderr io.Writer) int { - p.stdin = stdin - p.stdout = stdout - p.stderr = stderr if p.fs == nil { p.RegisterFlags(nil) } @@ -136,8 +127,18 @@ func (p *Plugin) MainWithIO(stdin io.Reader, stdout, stderr io.Writer) int { return 4 } -// RecipientV1 implements the recipient-v1 state machine over stdin/stdout, and -// writes errors to stderr. It returns an exit code to pass to os.Exit. +// SetIO sets the plugin's input and output streams, which default to +// stdin/stdout/stderr. +// +// It must be called before [Plugin.Main]. +func (p *Plugin) SetIO(stdin io.Reader, stdout, stderr io.Writer) { + p.stdin = stdin + p.stdout = stdout + p.stderr = stderr +} + +// RecipientV1 implements the recipient-v1 state machine. It returns an exit +// code to pass to os.Exit. // // Most plugins should call [Plugin.Main] instead of this method. func (p *Plugin) RecipientV1() int { @@ -325,8 +326,8 @@ func checkLabels(ll, labels []string) error { return nil } -// IdentityV1 implements the identity-v1 state machine over stdin/stdout, and -// writes errors to stderr. It returns an exit code to pass to os.Exit. +// IdentityV1 implements the identity-v1 state machine. It returns an exit code +// to pass to os.Exit. // // Most plugins should call [Plugin.Main] instead of this method. func (p *Plugin) IdentityV1() int {