diff --git a/cmd.go b/cmd.go index 89e5a68..7018ca6 100644 --- a/cmd.go +++ b/cmd.go @@ -24,35 +24,52 @@ func (c *Connect) Command(command string) (err error) { } defer func() { c.Session = nil }() - // setup options - err = c.setOption(c.Session) - if err != nil { - return - } - - // Set Stdin, Stdout, Stderr... - if c.Stdin != nil { + // Set Stdin + switch { + case c.Stdin != nil: w, _ := c.Session.StdinPipe() go io.Copy(w, c.Stdin) - } else { + + case c.PtyRelayTty != nil: + c.Session.Stdin = c.PtyRelayTty + + default: stdin := GetStdin() c.Session.Stdin = stdin } - if c.Stdout != nil { + // Set Stdout + switch { + case c.Stdout != nil: or, _ := c.Session.StdoutPipe() go io.Copy(c.Stdout, or) - } else { + + case c.PtyRelayTty != nil: + c.Session.Stdout = c.PtyRelayTty + + default: c.Session.Stdout = os.Stdout } - if c.Stderr != nil { + // Set Stderr + switch { + case c.Stderr != nil: er, _ := c.Session.StderrPipe() go io.Copy(c.Stderr, er) - } else { + + case c.PtyRelayTty != nil: + c.Session.Stderr = c.PtyRelayTty + + default: c.Session.Stderr = os.Stderr } + // setup options + err = c.setOption(c.Session) + if err != nil { + return + } + // Run Command c.Session.Run(command) diff --git a/connect.go b/connect.go index d4affed..d55b211 100644 --- a/connect.go +++ b/connect.go @@ -6,7 +6,6 @@ package sshlib import ( "context" - "errors" "io" "log" "net" @@ -51,6 +50,9 @@ type Connect struct { // Set it before CraeteClient. ForwardAgent bool + // Set the TTY to be used as the input and output for the Session/Cmd. + PtyRelayTty *os.File + // CheckKnownHosts if true, check knownhosts. // Ignored if HostKeyCallback is set. // Set it before CraeteClient. @@ -191,36 +193,37 @@ func (c *Connect) SendKeepAlive(session *ssh.Session) { interval = c.SendKeepAliveInterval } + max := 3 + if c.SendKeepAliveMax > 0 { + max = c.SendKeepAliveMax + } + t := time.NewTicker(time.Duration(c.ConnectTimeout) * time.Second) defer t.Stop() + count := 0 for { select { case <-t.C: if _, err := session.SendRequest("keepalive@openssh.com", true, nil); err != nil { - if !errors.Is(err, io.EOF) { - log.Println("Failed to send keepalive packet:", err) - session.Close() - c.Client.Close() - break - } else { - // sleep - time.Sleep(time.Duration(interval) * time.Second) - continue - } + log.Println("Failed to send keepalive packet:", err) + count += 1 } else { - // sleep + // err is nil. time.Sleep(time.Duration(interval) * time.Second) - continue } } + + if count > max { + return + } } } // CheckClientAlive check alive ssh.Client. func (c *Connect) CheckClientAlive() error { _, _, err := c.Client.SendRequest("keepalive", true, nil) - if err == nil || err.Error() == "request failed" { + if err == nil { return nil } return err diff --git a/shell.go b/shell.go index e9379a5..a185947 100644 --- a/shell.go +++ b/shell.go @@ -20,7 +20,13 @@ import ( // Shell connect login shell over ssh. func (c *Connect) Shell(session *ssh.Session) (err error) { // Input terminal Make raw - fd := int(os.Stdin.Fd()) + var fd int + if c.PtyRelayTty != nil { + fd = int(c.PtyRelayTty.Fd()) + } else { + fd = int(os.Stdin.Fd()) + } + state, err := terminal.MakeRaw(fd) if err != nil { return @@ -33,6 +39,13 @@ func (c *Connect) Shell(session *ssh.Session) (err error) { return } + // set tty + if c.PtyRelayTty != nil { + session.Stdin = c.PtyRelayTty + session.Stdout = c.PtyRelayTty + session.Stderr = c.PtyRelayTty + } + // Start shell err = session.Shell() if err != nil { @@ -42,6 +55,11 @@ func (c *Connect) Shell(session *ssh.Session) (err error) { // keep alive packet go c.SendKeepAlive(session) + // if tty is set, get signal winch + if c.PtyRelayTty != nil { + go c.ChangeWinSize(session) + } + err = session.Wait() if err != nil { return @@ -54,7 +72,13 @@ func (c *Connect) Shell(session *ssh.Session) (err error) { // Used to start a shell with a specified command. func (c *Connect) CmdShell(session *ssh.Session, command string) (err error) { // Input terminal Make raw - fd := int(os.Stdin.Fd()) + var fd int + if c.PtyRelayTty != nil { + fd = int(c.PtyRelayTty.Fd()) + } else { + fd = int(os.Stdin.Fd()) + } + state, err := terminal.MakeRaw(fd) if err != nil { return @@ -67,6 +91,13 @@ func (c *Connect) CmdShell(session *ssh.Session, command string) (err error) { return } + // set tty + if c.PtyRelayTty != nil { + session.Stdin = c.PtyRelayTty + session.Stdout = c.PtyRelayTty + session.Stderr = c.PtyRelayTty + } + // Start shell err = session.Start(command) if err != nil { @@ -84,6 +115,23 @@ func (c *Connect) CmdShell(session *ssh.Session, command string) (err error) { return } +func (c *Connect) ChangeWinSize(session *ssh.Session) { + // Get terminal window size + var fd int + if c.PtyRelayTty != nil { + fd = int(c.PtyRelayTty.Fd()) + } else { + fd = int(os.Stdout.Fd()) + } + width, height, err := terminal.GetSize(fd) + if err != nil { + return + } + + // Send window size + session.WindowChange(height, width) +} + // SetLog set up terminal log logging. // This only happens in Connect.Shell(). func (c *Connect) SetLog(path string, timestamp bool) {