Skip to content

Commit

Permalink
Merge pull request #12542 from tomponline/tp-exec
Browse files Browse the repository at this point in the history
Exec cleanup improvements
  • Loading branch information
tomponline authored Nov 23, 2023
2 parents e31e65a + d654ae6 commit a4a8f23
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 54 deletions.
44 changes: 25 additions & 19 deletions client/lxd_containers.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,8 @@ func (r *ProtocolLXD) ExecContainer(containerName string, exec api.ContainerExec
dones[0] = ws.MirrorRead(conn, args.Stdin)
}

waitConns := 0 // Used for keeping track of when stdout and stderr have finished.

// Handle stdout
if fds["1"] != "" {
conn, err := r.GetOperationWebsocket(opAPI.ID, fds["1"])
Expand All @@ -710,6 +712,7 @@ func (r *ProtocolLXD) ExecContainer(containerName string, exec api.ContainerExec

conns = append(conns, conn)
dones[1] = ws.MirrorWrite(conn, args.Stdout)
waitConns++
}

// Handle stderr
Expand All @@ -721,33 +724,36 @@ func (r *ProtocolLXD) ExecContainer(containerName string, exec api.ContainerExec

conns = append(conns, conn)
dones[2] = ws.MirrorWrite(conn, args.Stderr)
waitConns++
}

// Wait for everything to be done
go func() {
for i, chDone := range dones {
// Skip stdin, dealing with it separately below
if i == 0 {
continue
for {
select {
case <-dones[0]:
// Handle stdin finish, but don't wait for it if output channels
// have all finished.
dones[0] = nil
_ = conns[0].Close()
case <-dones[1]:
dones[1] = nil
_ = conns[1].Close()
waitConns--
case <-dones[2]:
dones[2] = nil
_ = conns[2].Close()
waitConns--
}

<-chDone
}
if waitConns <= 0 {
// Close stdin websocket if defined and not already closed.
if dones[0] != nil {
conns[0].Close()
}

if fds["0"] != "" {
if args.Stdin != nil {
_ = args.Stdin.Close()
break
}

// Empty the stdin channel but don't block on it as
// stdin may be stuck in Read()
go func() {
<-dones[0]
}()
}

for _, conn := range conns {
_ = conn.Close()
}

if args.DataDone != nil {
Expand Down
50 changes: 34 additions & 16 deletions client/lxd_instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,10 @@ func (r *ProtocolLXD) ExecInstance(instanceName string, exec api.InstanceExecPos
return nil, err
}

go func() {
_, _, _ = conn.ReadMessage() // Consume pings from server.
}()

go args.Control(conn)
}

Expand Down Expand Up @@ -1249,10 +1253,16 @@ func (r *ProtocolLXD) ExecInstance(instanceName string, exec api.InstanceExecPos
return nil, err
}

go func() {
_, _, _ = conn.ReadMessage() // Consume pings from server.
}()

conns = append(conns, conn)
dones[0] = ws.MirrorRead(conn, args.Stdin)
}

waitConns := 0 // Used for keeping track of when stdout and stderr have finished.

// Handle stdout
if fds["1"] != "" {
conn, err := r.GetOperationWebsocket(opAPI.ID, fds["1"])
Expand All @@ -1262,6 +1272,7 @@ func (r *ProtocolLXD) ExecInstance(instanceName string, exec api.InstanceExecPos

conns = append(conns, conn)
dones[1] = ws.MirrorWrite(conn, args.Stdout)
waitConns++
}

// Handle stderr
Expand All @@ -1273,29 +1284,36 @@ func (r *ProtocolLXD) ExecInstance(instanceName string, exec api.InstanceExecPos

conns = append(conns, conn)
dones[2] = ws.MirrorWrite(conn, args.Stderr)
waitConns++
}

// Wait for everything to be done
go func() {
for i, chDone := range dones {
// Skip stdin, dealing with it separately below
if i == 0 {
continue
for {
select {
case <-dones[0]:
// Handle stdin finish, but don't wait for it if output channels
// have all finished.
dones[0] = nil
_ = conns[0].Close()
case <-dones[1]:
dones[1] = nil
_ = conns[1].Close()
waitConns--
case <-dones[2]:
dones[2] = nil
_ = conns[2].Close()
waitConns--
}

<-chDone
}

if fds["0"] != "" {
// Empty the stdin channel but don't block on it as
// stdin may be stuck in Read()
go func() {
<-dones[0]
}()
}
if waitConns <= 0 {
// Close stdin websocket if defined and not already closed.
if dones[0] != nil {
conns[0].Close()
}

for _, conn := range conns {
_ = conn.Close()
break
}
}

if args.DataDone != nil {
Expand Down
4 changes: 2 additions & 2 deletions lxc/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ func (c *cmdExec) Run(cmd *cobra.Command, args []string) error {
}
}

var stdin io.ReadCloser
var stdin io.Reader
stdin = os.Stdin
if c.flagDisableStdin {
stdin = io.NopCloser(bytes.NewReader(nil))
stdin = bytes.NewReader(nil)
}

stdout := getStdout()
Expand Down
28 changes: 14 additions & 14 deletions lxd/instance_exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,23 @@ func (s *execWs) Connect(op *operations.Operation, r *http.Request, w http.Respo
if err != nil {
logger.Warn("Failed setting TCP timeouts on remote connection", logger.Ctx{"err": err})
}
}

// Start channel keep alive to run until channel is closed.
go func() {
pingInterval := time.Second * 10
t := time.NewTicker(pingInterval)
defer t.Stop()

for {
err := conn.WriteControl(websocket.PingMessage, []byte("keepalive"), time.Now().Add(5*time.Second))
if err != nil {
return
}
// Start channel keep alive to run until channel is closed.
go func() {
pingInterval := time.Second * 10
t := time.NewTicker(pingInterval)
defer t.Stop()

<-t.C
for {
err := conn.WriteControl(websocket.PingMessage, []byte("keepalive"), time.Now().Add(5*time.Second))
if err != nil {
return
}
}()
}

<-t.C
}
}()

if fd == execWSControl {
s.waitControlConnected.Cancel() // Control connection connected.
Expand Down
5 changes: 2 additions & 3 deletions shared/ws/mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ func MirrorRead(conn *websocket.Conn, rc io.Reader) chan error {
connRWC := NewWrapper(conn)

go func() {
defer close(chDone)

_, err := io.Copy(connRWC, rc)

logger.Debug("Websocket: Stopped read mirror", logger.Ctx{"address": conn.RemoteAddr().String(), "err": err})
Expand All @@ -40,6 +38,7 @@ func MirrorRead(conn *websocket.Conn, rc io.Reader) chan error {
connRWC.Close()

chDone <- err
close(chDone)
}()

return chDone
Expand All @@ -58,11 +57,11 @@ func MirrorWrite(conn *websocket.Conn, wc io.Writer) chan error {
connRWC := NewWrapper(conn)

go func() {
defer close(chDone)
_, err := io.Copy(wc, connRWC)

logger.Debug("Websocket: Stopped write mirror", logger.Ctx{"address": conn.RemoteAddr().String(), "err": err})
chDone <- err
close(chDone)
}()

return chDone
Expand Down

0 comments on commit a4a8f23

Please sign in to comment.