Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: byo transport option #176

Merged
merged 1 commit into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ func (c *Channel) Open() (reterr error) {
if err != nil {
return err
}
case transport.InChannelAuthUnsupported:
}

if len(b) > 0 {
Expand Down
16 changes: 16 additions & 0 deletions driver/options/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,22 @@ import (
"github.com/scrapli/scrapligo/util"
)

// WithCustomTransport sets a custom, user provided, transport instead of one of the "core"
// transports. This custom transport must satisfy the transport.Transport interface.
func WithCustomTransport(i transport.Implementation) util.Option {
return func(o interface{}) error {
a, ok := o.(*transport.Args)

if !ok {
return util.ErrIgnoredOption
}

a.UserImplementation = i

return nil
}
}

// WithTransportReadSize sets the number of bytes each transport read operation should try to read.
// The default value is 65535.
func WithTransportReadSize(i int) util.Option {
Expand Down
52 changes: 28 additions & 24 deletions transport/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func NewTransport(
host, transportType string,
options ...util.Option,
) (*Transport, error) {
var i transportImpl
var i Implementation

var err error

Expand All @@ -37,36 +37,40 @@ func NewTransport(
return nil, err
}

switch transportType {
case SystemTransport, StandardTransport:
var sshArgs *SSHArgs
if args.UserImplementation != nil {
i = args.UserImplementation
} else {
switch transportType {
case SystemTransport, StandardTransport:
var sshArgs *SSHArgs

sshArgs, err = NewSSHArgs(options...)
if err != nil {
return nil, err
}
sshArgs, err = NewSSHArgs(options...)
if err != nil {
return nil, err
}

switch transportType {
case SystemTransport:
i, err = NewSystemTransport(sshArgs)
case StandardTransport:
i, err = NewStandardTransport(sshArgs)
switch transportType {
case SystemTransport:
i, err = NewSystemTransport(sshArgs)
case StandardTransport:
i, err = NewStandardTransport(sshArgs)
}
case TelnetTransport:
var telnetArgs *TelnetArgs

telnetArgs, err = NewTelnetArgs(options...)
if err != nil {
return nil, err
}

i, err = NewTelnetTransport(telnetArgs)
case FileTransport:
i, err = NewFileTransport()
}
case TelnetTransport:
var telnetArgs *TelnetArgs

telnetArgs, err = NewTelnetArgs(options...)
if err != nil {
return nil, err
}

i, err = NewTelnetTransport(telnetArgs)
case FileTransport:
i, err = NewFileTransport()
}

if err != nil {
return nil, err
}

for _, option := range options {
Expand Down
6 changes: 4 additions & 2 deletions transport/system.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,12 @@ func (t *System) Write(b []byte) error {
return err
}

func (t *System) inChannelAuthType() string {
// GetInChannelAuthType returns the in channel auth flavor for the system transport.
func (t *System) GetInChannelAuthType() InChannelAuthType {
return InChannelAuthSSH
}

func (t *System) getSSHArgs() *SSHArgs {
// GetSSHArgs returns the ssh args for the system transport.
func (t *System) GetSSHArgs() *SSHArgs {
return t.SSHArgs
}
3 changes: 2 additions & 1 deletion transport/telnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ func (t *Telnet) Write(b []byte) error {
return err
}

func (t *Telnet) inChannelAuthType() string {
// GetInChannelAuthType returns the in channel auth flavor for the telnet transport.
func (t *Telnet) GetInChannelAuthType() InChannelAuthType {
return InChannelAuthTelnet
}
70 changes: 42 additions & 28 deletions transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,22 @@ import (
"github.com/scrapli/scrapligo/util"
)

const (
// DefaultTransport is the default transport constant for scrapligo, this defaults to the
// "system" transport.
DefaultTransport = "system"
// InChannelAuthType is an enum-ish string that represents valid in channel auth flavors.
type InChannelAuthType string

const (
// InChannelAuthUnsupported indicates that the transport does *not* support in channel auth.
InChannelAuthUnsupported = "in-channel-auth-unsupported"
InChannelAuthUnsupported InChannelAuthType = "unsupported"
// InChannelAuthSSH indicates that the transport supports in channel ssh auth.
InChannelAuthSSH = "in-channel-auth-ssh"
InChannelAuthSSH InChannelAuthType = "ssh"
// InChannelAuthTelnet indicates that the transport supports in channel telnet auth.
InChannelAuthTelnet = "in-channel-auth-telnet"
InChannelAuthTelnet InChannelAuthType = "telnet"
)

const (
// DefaultTransport is the default transport constant for scrapligo, this defaults to the
// "system" transport.
DefaultTransport = "system"

defaultPort = 22
defaultTimeoutSocketSeconds = 30
Expand All @@ -33,7 +38,7 @@ const (
// InChannelAuthData is a struct containing all necessary information for the Channel to handle
// "in-channel" auth if necessary.
type InChannelAuthData struct {
Type string
Type InChannelAuthType
User string
Password string
PrivateKeyPassPhrase string
Expand Down Expand Up @@ -67,15 +72,16 @@ func NewArgs(l *logging.Instance, host string, options ...util.Option) (*Args, e

// Args is a struct representing common transport arguments.
type Args struct {
l *logging.Instance
Host string
Port int
User string
Password string
TimeoutSocket time.Duration
ReadSize int
TermHeight int
TermWidth int
l *logging.Instance
UserImplementation Implementation
Host string
Port int
User string
Password string
TimeoutSocket time.Duration
ReadSize int
TermHeight int
TermWidth int
}

// NewSSHArgs returns an instance of SSH arguments with provided options set. Just like NewArgs,
Expand Down Expand Up @@ -127,29 +133,37 @@ func NewTelnetArgs(options ...util.Option) (*TelnetArgs, error) {
// TelnetArgs is a struct representing common transport Telnet specific arguments.
type TelnetArgs struct{}

type transportImpl interface {
// Implementation defines a valid base scrapligo transport -- for SSH-specific transports users
// should satisfy SSHImplementation -- and for transports that require authentication "in channel"
// users should satisfy InChannelAuthImplementation (this could be in addition to the SSH one!).
type Implementation interface {
Open(a *Args) error
Close() error
IsAlive() bool
Read(n int) ([]byte, error)
Write(b []byte) error
}

// transportImplSSH is an interface that SSH transports *may* implement, this is currently only
// SSHImplementation is an interface that SSH transports *may* implement, this is currently only
// required if the SSH transport also requires (or just supports) "in-channel" ssh authentication.
type transportImplSSH interface {
getSSHArgs() *SSHArgs
type SSHImplementation interface {
Implementation
GetSSHArgs() *SSHArgs
}

type transportImplInChannelAuth interface {
inChannelAuthType() string
// InChannelAuthImplementation is an interface that when satisfied tells us that the transport
// wants to do "in channel" authentication -- meaning actually look for user/password prompt and
// send those values in the connection rather than in the protocol/out of band.
type InChannelAuthImplementation interface {
Implementation
GetInChannelAuthType() InChannelAuthType
}

// Transport is a struct which wraps a transportImpl object and provides a unified interface to any
// type of transport selected by the user.
type Transport struct {
Args *Args
Impl transportImpl
Impl Implementation
implLock *sync.Mutex
timeoutLock *sync.Mutex
}
Expand Down Expand Up @@ -211,15 +225,15 @@ func (t *Transport) GetPort() int {
// InChannelAuthData returns an instance of InChannelAuthData indicating if in-channel auth is
// supported, and if so, the necessary fields to accomplish that.
func (t *Transport) InChannelAuthData() *InChannelAuthData {
ti, ok := t.Impl.(transportImplInChannelAuth)
ti, ok := t.Impl.(InChannelAuthImplementation)
if !ok {
return &InChannelAuthData{
Type: InChannelAuthUnsupported,
}
}

d := &InChannelAuthData{
Type: ti.inChannelAuthType(),
Type: ti.GetInChannelAuthType(),
User: t.Args.User,
Password: t.Args.Password,
PrivateKeyPassPhrase: "",
Expand All @@ -229,15 +243,15 @@ func (t *Transport) InChannelAuthData() *InChannelAuthData {
return d
}

sshTransport, ok := ti.(transportImplSSH)
sshTransport, ok := ti.(SSHImplementation)
if !ok {
panic(
"in channel auth requested on a non telnet transport," +
" and transport does not implement transportImplSSH",
)
}

d.PrivateKeyPassPhrase = sshTransport.getSSHArgs().PrivateKeyPassPhrase
d.PrivateKeyPassPhrase = sshTransport.GetSSHArgs().PrivateKeyPassPhrase

return d
}
Loading