From 2ede47b67dbf3924be1586348cebeca082c3e40e Mon Sep 17 00:00:00 2001 From: Carl Montanari Date: Wed, 3 Apr 2024 06:29:42 -0700 Subject: [PATCH] feat: byo transport option --- channel/channel.go | 1 + driver/options/transport.go | 16 +++++++++ transport/factory.go | 52 ++++++++++++++------------- transport/system.go | 6 ++-- transport/telnet.go | 3 +- transport/transport.go | 70 ++++++++++++++++++++++--------------- 6 files changed, 93 insertions(+), 55 deletions(-) diff --git a/channel/channel.go b/channel/channel.go index 4563890..6e0fa71 100644 --- a/channel/channel.go +++ b/channel/channel.go @@ -165,6 +165,7 @@ func (c *Channel) Open() (reterr error) { if err != nil { return err } + case transport.InChannelAuthUnsupported: } if len(b) > 0 { diff --git a/driver/options/transport.go b/driver/options/transport.go index 09ac23c..a1fb768 100644 --- a/driver/options/transport.go +++ b/driver/options/transport.go @@ -7,6 +7,22 @@ import ( "github.com/scrapli/scrapligo/util" ) +// WithCustomTransport sets a custom, user provided, transport instead of one of the "core" +// transports. This custom transport must satisfy the transport.Transport interface. +func WithCustomTransport(i transport.Implementation) util.Option { + return func(o interface{}) error { + a, ok := o.(*transport.Args) + + if !ok { + return util.ErrIgnoredOption + } + + a.UserImplementation = i + + return nil + } +} + // WithTransportReadSize sets the number of bytes each transport read operation should try to read. // The default value is 65535. func WithTransportReadSize(i int) util.Option { diff --git a/transport/factory.go b/transport/factory.go index 4403052..6f130ac 100644 --- a/transport/factory.go +++ b/transport/factory.go @@ -26,7 +26,7 @@ func NewTransport( host, transportType string, options ...util.Option, ) (*Transport, error) { - var i transportImpl + var i Implementation var err error @@ -37,36 +37,40 @@ func NewTransport( return nil, err } - switch transportType { - case SystemTransport, StandardTransport: - var sshArgs *SSHArgs + if args.UserImplementation != nil { + i = args.UserImplementation + } else { + switch transportType { + case SystemTransport, StandardTransport: + var sshArgs *SSHArgs - sshArgs, err = NewSSHArgs(options...) - if err != nil { - return nil, err - } + sshArgs, err = NewSSHArgs(options...) + if err != nil { + return nil, err + } - switch transportType { - case SystemTransport: - i, err = NewSystemTransport(sshArgs) - case StandardTransport: - i, err = NewStandardTransport(sshArgs) + switch transportType { + case SystemTransport: + i, err = NewSystemTransport(sshArgs) + case StandardTransport: + i, err = NewStandardTransport(sshArgs) + } + case TelnetTransport: + var telnetArgs *TelnetArgs + + telnetArgs, err = NewTelnetArgs(options...) + if err != nil { + return nil, err + } + + i, err = NewTelnetTransport(telnetArgs) + case FileTransport: + i, err = NewFileTransport() } - case TelnetTransport: - var telnetArgs *TelnetArgs - telnetArgs, err = NewTelnetArgs(options...) if err != nil { return nil, err } - - i, err = NewTelnetTransport(telnetArgs) - case FileTransport: - i, err = NewFileTransport() - } - - if err != nil { - return nil, err } for _, option := range options { diff --git a/transport/system.go b/transport/system.go index 46e5804..0f17e49 100644 --- a/transport/system.go +++ b/transport/system.go @@ -230,10 +230,12 @@ func (t *System) Write(b []byte) error { return err } -func (t *System) inChannelAuthType() string { +// GetInChannelAuthType returns the in channel auth flavor for the system transport. +func (t *System) GetInChannelAuthType() InChannelAuthType { return InChannelAuthSSH } -func (t *System) getSSHArgs() *SSHArgs { +// GetSSHArgs returns the ssh args for the system transport. +func (t *System) GetSSHArgs() *SSHArgs { return t.SSHArgs } diff --git a/transport/telnet.go b/transport/telnet.go index 3044c2b..5969730 100644 --- a/transport/telnet.go +++ b/transport/telnet.go @@ -166,6 +166,7 @@ func (t *Telnet) Write(b []byte) error { return err } -func (t *Telnet) inChannelAuthType() string { +// GetInChannelAuthType returns the in channel auth flavor for the telnet transport. +func (t *Telnet) GetInChannelAuthType() InChannelAuthType { return InChannelAuthTelnet } diff --git a/transport/transport.go b/transport/transport.go index 5a96df6..140fb28 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -9,17 +9,22 @@ import ( "github.com/scrapli/scrapligo/util" ) -const ( - // DefaultTransport is the default transport constant for scrapligo, this defaults to the - // "system" transport. - DefaultTransport = "system" +// InChannelAuthType is an enum-ish string that represents valid in channel auth flavors. +type InChannelAuthType string +const ( // InChannelAuthUnsupported indicates that the transport does *not* support in channel auth. - InChannelAuthUnsupported = "in-channel-auth-unsupported" + InChannelAuthUnsupported InChannelAuthType = "unsupported" // InChannelAuthSSH indicates that the transport supports in channel ssh auth. - InChannelAuthSSH = "in-channel-auth-ssh" + InChannelAuthSSH InChannelAuthType = "ssh" // InChannelAuthTelnet indicates that the transport supports in channel telnet auth. - InChannelAuthTelnet = "in-channel-auth-telnet" + InChannelAuthTelnet InChannelAuthType = "telnet" +) + +const ( + // DefaultTransport is the default transport constant for scrapligo, this defaults to the + // "system" transport. + DefaultTransport = "system" defaultPort = 22 defaultTimeoutSocketSeconds = 30 @@ -33,7 +38,7 @@ const ( // InChannelAuthData is a struct containing all necessary information for the Channel to handle // "in-channel" auth if necessary. type InChannelAuthData struct { - Type string + Type InChannelAuthType User string Password string PrivateKeyPassPhrase string @@ -67,15 +72,16 @@ func NewArgs(l *logging.Instance, host string, options ...util.Option) (*Args, e // Args is a struct representing common transport arguments. type Args struct { - l *logging.Instance - Host string - Port int - User string - Password string - TimeoutSocket time.Duration - ReadSize int - TermHeight int - TermWidth int + l *logging.Instance + UserImplementation Implementation + Host string + Port int + User string + Password string + TimeoutSocket time.Duration + ReadSize int + TermHeight int + TermWidth int } // NewSSHArgs returns an instance of SSH arguments with provided options set. Just like NewArgs, @@ -127,7 +133,10 @@ func NewTelnetArgs(options ...util.Option) (*TelnetArgs, error) { // TelnetArgs is a struct representing common transport Telnet specific arguments. type TelnetArgs struct{} -type transportImpl interface { +// Implementation defines a valid base scrapligo transport -- for SSH-specific transports users +// should satisfy SSHImplementation -- and for transports that require authentication "in channel" +// users should satisfy InChannelAuthImplementation (this could be in addition to the SSH one!). +type Implementation interface { Open(a *Args) error Close() error IsAlive() bool @@ -135,21 +144,26 @@ type transportImpl interface { Write(b []byte) error } -// transportImplSSH is an interface that SSH transports *may* implement, this is currently only +// SSHImplementation is an interface that SSH transports *may* implement, this is currently only // required if the SSH transport also requires (or just supports) "in-channel" ssh authentication. -type transportImplSSH interface { - getSSHArgs() *SSHArgs +type SSHImplementation interface { + Implementation + GetSSHArgs() *SSHArgs } -type transportImplInChannelAuth interface { - inChannelAuthType() string +// InChannelAuthImplementation is an interface that when satisfied tells us that the transport +// wants to do "in channel" authentication -- meaning actually look for user/password prompt and +// send those values in the connection rather than in the protocol/out of band. +type InChannelAuthImplementation interface { + Implementation + GetInChannelAuthType() InChannelAuthType } // Transport is a struct which wraps a transportImpl object and provides a unified interface to any // type of transport selected by the user. type Transport struct { Args *Args - Impl transportImpl + Impl Implementation implLock *sync.Mutex timeoutLock *sync.Mutex } @@ -211,7 +225,7 @@ func (t *Transport) GetPort() int { // InChannelAuthData returns an instance of InChannelAuthData indicating if in-channel auth is // supported, and if so, the necessary fields to accomplish that. func (t *Transport) InChannelAuthData() *InChannelAuthData { - ti, ok := t.Impl.(transportImplInChannelAuth) + ti, ok := t.Impl.(InChannelAuthImplementation) if !ok { return &InChannelAuthData{ Type: InChannelAuthUnsupported, @@ -219,7 +233,7 @@ func (t *Transport) InChannelAuthData() *InChannelAuthData { } d := &InChannelAuthData{ - Type: ti.inChannelAuthType(), + Type: ti.GetInChannelAuthType(), User: t.Args.User, Password: t.Args.Password, PrivateKeyPassPhrase: "", @@ -229,7 +243,7 @@ func (t *Transport) InChannelAuthData() *InChannelAuthData { return d } - sshTransport, ok := ti.(transportImplSSH) + sshTransport, ok := ti.(SSHImplementation) if !ok { panic( "in channel auth requested on a non telnet transport," + @@ -237,7 +251,7 @@ func (t *Transport) InChannelAuthData() *InChannelAuthData { ) } - d.PrivateKeyPassPhrase = sshTransport.getSSHArgs().PrivateKeyPassPhrase + d.PrivateKeyPassPhrase = sshTransport.GetSSHArgs().PrivateKeyPassPhrase return d }