Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mockssh: expose default command handler for reuse #233

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions pkg/mockssh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"io"
"net"
"net/http"
"os"
"os/exec"
"sync"
"testing"
Expand All @@ -31,9 +30,9 @@ type Server struct {
CertAuthorityKeys []ssh.PublicKey
CertChecker ssh.CertChecker

// RemoteEnv, RemoteDir and CommandHandler are optional configuration.
RemoteEnv []string
RemoteDir string
// An optional CommandHandler, which responds to commands sent over SSH.
// NewServer will give this a default using ExecHandler, which can also
// be reused from custom handlers.
CommandHandler CommandHandler

// listener and port are set after Start.
Expand All @@ -47,7 +46,7 @@ type CommandIO struct {
StdErr io.Writer
}

type CommandHandler func(conn ssh.ConnMetadata, command string, io CommandIO) int
type CommandHandler func(conn ssh.ConnMetadata, command string, commandIO CommandIO) int

// NewServer creates and starts a local SSH server for a test.
// It must be stopped with the Server.Stop method.
Expand All @@ -65,9 +64,8 @@ func NewServer(t *testing.T, authorityEndpoint string) (*Server, error) {
}

s := &Server{t: t, hostKey: hk}
s.CommandHandler = s.defaultCommandHandler
s.CommandHandler = ExecHandler("", nil)
s.CertChecker = s.defaultCertChecker()
s.RemoteDir = t.TempDir()
s.CertAuthorityKeys = keys

if err := s.start(); err != nil {
Expand All @@ -89,6 +87,10 @@ func (s *Server) HostKeyConfig() string {
)
}

func (s *Server) HostKey() ssh.PublicKey {
return s.hostKey.PublicKey()
}

func (s *Server) start() error {
t := s.t

Expand Down Expand Up @@ -148,22 +150,25 @@ func (s *Server) Stop() error {
return nil
}

func (s *Server) defaultCommandHandler(_ ssh.ConnMetadata, command string, commandIO CommandIO) int {
c := exec.Command("bash", "-c", command)
c.Stdout = commandIO.StdOut
c.Stderr = commandIO.StdErr
c.Stdin = commandIO.StdIn
c.Dir = s.RemoteDir
c.Env = append(os.Environ(), s.RemoteEnv...)
if err := c.Run(); err != nil {
exitErr := &exec.ExitError{}
if errors.As(err, &exitErr) {
return exitErr.ExitCode()
// ExecHandler returns a CommandHandler to execute a command in the given environment.
func ExecHandler(workingDir string, env []string) CommandHandler {
return func(_ ssh.ConnMetadata, command string, commandIO CommandIO) int {
c := exec.Command("bash", "-c", command)
c.Stdout = commandIO.StdOut
c.Stderr = commandIO.StdErr
c.Stdin = commandIO.StdIn
c.Dir = workingDir
c.Env = env
if err := c.Run(); err != nil {
exitErr := &exec.ExitError{}
if errors.As(err, &exitErr) {
return exitErr.ExitCode()
}
_, _ = fmt.Fprintf(commandIO.StdErr, "Failed to execute command: %v", err)
return 1
}
_, _ = fmt.Fprintf(commandIO.StdErr, "Failed to execute command: %v", err)
return 1
return 0
}
return 0
}

func (s *Server) defaultCertChecker() ssh.CertChecker {
Expand Down Expand Up @@ -253,9 +258,9 @@ func (s *Server) handleChannels(conn ssh.ConnMetadata, chans <-chan ssh.NewChann
for {
select {
case s := <-exitWithStatus:
_, err = channel.SendRequest("exit-status", false, ssh.Marshal(struct{ Status int }{s}))
_, err = channel.SendRequest("exit-status", false, ssh.Marshal(struct{ Status uint32 }{uint32(s)})) //nolint: gosec
if err != nil {
t.Errorf("Failed to send exit status: %v", err)
t.Fatalf("Failed to send exit status: %v", err)
}
goto closeChannel
case <-timer.C:
Expand Down
105 changes: 105 additions & 0 deletions pkg/mockssh/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package mockssh_test

import (
"bytes"
"crypto/ed25519"
"crypto/rand"
"encoding/json"
"fmt"
"net"
"net/http"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"

"github.com/platformsh/cli/pkg/mockapi"
"github.com/platformsh/cli/pkg/mockssh"
)

func TestServer(t *testing.T) {
authServer := mockapi.NewAuthServer(t)
defer authServer.Close()

sshServer, err := mockssh.NewServer(t, authServer.URL+"/ssh/authority")
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
_ = sshServer.Stop()
})

tempDir := t.TempDir()
sshServer.CommandHandler = mockssh.ExecHandler(tempDir, []string{})

cert := getTestSSHAuth(t, authServer.URL)

// Create the SSH client configuration
address := fmt.Sprintf("127.0.0.1:%d", sshServer.Port())
config := &ssh.ClientConfig{
User: "test",
Auth: []ssh.AuthMethod{ssh.PublicKeys(cert)},
HostKeyCallback: func(_ string, remote net.Addr, key ssh.PublicKey) error {
if remote.String() != address {
return fmt.Errorf("unexpected address: %s", remote.String())
}
if bytes.Equal(sshServer.HostKey().Marshal(), key.Marshal()) {
return nil
}
return fmt.Errorf("host key mismatch")
},
}

client, err := ssh.Dial("tcp", address, config)
require.NoError(t, err)
defer client.Close()

session, err := client.NewSession()
require.NoError(t, err)
defer session.Close()

stdOutBuffer := &bytes.Buffer{}
session.Stdout = stdOutBuffer

require.NoError(t, session.Run("pwd"))
assert.Equal(t, tempDir, strings.TrimRight(stdOutBuffer.String(), "\n"))

session2, err := client.NewSession()
require.NoError(t, err)
defer session2.Close()
err = session2.Run("false")
assert.Error(t, err)
var exitErr *ssh.ExitError
assert.ErrorAs(t, err, &exitErr)
assert.Equal(t, 1, exitErr.ExitStatus())
}

func getTestSSHAuth(t *testing.T, authServerURL string) ssh.Signer {
t.Helper()

// Generate a keypair
_, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
s, err := ssh.NewSignerFromKey(priv)
require.NoError(t, err)

b, err := json.Marshal(struct{ Key string }{string(ssh.MarshalAuthorizedKey(s.PublicKey()))})
require.NoError(t, err)
resp, err := http.DefaultClient.Post(authServerURL+"/ssh", "application/json", bytes.NewReader(b))
require.NoError(t, err)
defer resp.Body.Close()

var rs struct{ Certificate string }
require.NoError(t, json.NewDecoder(resp.Body).Decode(&rs))

parsed, _, _, _, err := ssh.ParseAuthorizedKey([]byte(rs.Certificate)) //nolint: dogsled
require.NoError(t, err)

cert, _ := parsed.(*ssh.Certificate)
certSigner, err := ssh.NewCertSigner(cert, s)
require.NoError(t, err)

return certSigner
}
Loading