Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix devlxd image export #13730

Merged
merged 12 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lxd/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ func (d *Daemon) Authenticate(w http.ResponseWriter, r *http.Request) (trusted b
}

// Devlxd unix socket credentials on main API.
if r.RemoteAddr == "@devlxd" {
if r.RemoteAddr == devlxdRemoteAddress {
return false, "", "", nil, fmt.Errorf("Main API query can't come from /dev/lxd socket")
}

Expand Down
111 changes: 83 additions & 28 deletions lxd/devlxd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -30,8 +31,12 @@ import (
"github.com/canonical/lxd/shared/ws"
)

const devlxdRemoteAddress = "@devlxd"

type hoistFunc func(f func(*Daemon, instance.Instance, http.ResponseWriter, *http.Request) response.Response, d *Daemon) func(http.ResponseWriter, *http.Request)

type devlxdHandlerFunc func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response

// DevLxdServer creates an http.Server capable of handling requests against the
// /dev/lxd Unix socket endpoint created inside containers.
func devLxdServer(d *Daemon) *http.Server {
Expand All @@ -51,10 +56,15 @@ type devLxdHandler struct {
* server side right now either, I went the simple route to avoid
* needless noise.
*/
f func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response
handlerFunc devlxdHandlerFunc
}

var devlxdConfigGet = devLxdHandler{
path: "/1.0/config",
handlerFunc: devlxdConfigGetHandler,
}

var devlxdConfigGet = devLxdHandler{"/1.0/config", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdConfigGetHandler(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
if shared.IsFalse(c.ExpandedConfig()["security.devlxd"]) {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusForbidden, "not authorized"), c.Type() == instancetype.VM)
}
Expand All @@ -67,9 +77,14 @@ var devlxdConfigGet = devLxdHandler{"/1.0/config", func(d *Daemon, c instance.In
}

return response.DevLxdResponse(http.StatusOK, filtered, "json", c.Type() == instancetype.VM)
}}
}

var devlxdConfigKeyGet = devLxdHandler{
path: "/1.0/config/{key}",
handlerFunc: devlxdConfigKeyGetHandler,
}

var devlxdConfigKeyGet = devLxdHandler{"/1.0/config/{key}", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdConfigKeyGetHandler(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
if shared.IsFalse(c.ExpandedConfig()["security.devlxd"]) {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusForbidden, "not authorized"), c.Type() == instancetype.VM)
}
Expand All @@ -89,9 +104,14 @@ var devlxdConfigKeyGet = devLxdHandler{"/1.0/config/{key}", func(d *Daemon, c in
}

return response.DevLxdResponse(http.StatusOK, value, "raw", c.Type() == instancetype.VM)
}}
}

var devlxdImageExport = devLxdHandler{
path: "/1.0/images/{fingerprint}/export",
handlerFunc: devlxdImageExportHandler,
}

var devlxdImageExport = devLxdHandler{"/1.0/images/{fingerprint}/export", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdImageExportHandler(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
if shared.IsFalse(c.ExpandedConfig()["security.devlxd"]) {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusForbidden, "not authorized"), c.Type() == instancetype.VM)
}
Expand All @@ -101,7 +121,7 @@ var devlxdImageExport = devLxdHandler{"/1.0/images/{fingerprint}/export", func(d
}

// Use by security checks to distinguish devlxd vs lxd APIs
r.RemoteAddr = "@devlxd"
r.RemoteAddr = devlxdRemoteAddress

resp := imageExport(d, r)

Expand All @@ -111,19 +131,29 @@ var devlxdImageExport = devLxdHandler{"/1.0/images/{fingerprint}/export", func(d
}

return response.DevLxdResponse(http.StatusOK, "", "raw", c.Type() == instancetype.VM)
}}
}

var devlxdMetadataGet = devLxdHandler{
path: "/1.0/meta-data",
handlerFunc: devlxdMetadataGetHandler,
}

var devlxdMetadataGet = devLxdHandler{"/1.0/meta-data", func(d *Daemon, inst instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdMetadataGetHandler(d *Daemon, inst instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
if shared.IsFalse(inst.ExpandedConfig()["security.devlxd"]) {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusForbidden, "not authorized"), inst.Type() == instancetype.VM)
}

value := inst.ExpandedConfig()["user.meta-data"]

return response.DevLxdResponse(http.StatusOK, fmt.Sprintf("#cloud-config\ninstance-id: %s\nlocal-hostname: %s\n%s", inst.CloudInitID(), inst.Name(), value), "raw", inst.Type() == instancetype.VM)
}}
}

var devlxdEventsGet = devLxdHandler{
path: "/1.0/events",
handlerFunc: devlxdEventsGetHandler,
}

var devlxdEventsGet = devLxdHandler{"/1.0/events", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdEventsGetHandler(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
if shared.IsFalse(c.ExpandedConfig()["security.devlxd"]) {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusForbidden, "not authorized"), c.Type() == instancetype.VM)
}
Expand Down Expand Up @@ -178,9 +208,14 @@ var devlxdEventsGet = devLxdHandler{"/1.0/events", func(d *Daemon, c instance.In
listener.Wait(r.Context())

return resp
}}
}

var devlxdAPIHandler = devLxdHandler{
path: "/1.0",
handlerFunc: devlxdAPIHandlerFunc,
}

var devlxdAPIHandler = devLxdHandler{"/1.0", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdAPIHandlerFunc(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
s := d.State()

if r.Method == "GET" {
Expand Down Expand Up @@ -236,10 +271,14 @@ var devlxdAPIHandler = devLxdHandler{"/1.0", func(d *Daemon, c instance.Instance
}

return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusMethodNotAllowed, fmt.Sprintf("method %q not allowed", r.Method)), c.Type() == instancetype.VM)
}

}}
var devlxdDevicesGet = devLxdHandler{
path: "/1.0/devices",
handlerFunc: devlxdDevicesGetHandler,
}

var devlxdDevicesGet = devLxdHandler{"/1.0/devices", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
func devlxdDevicesGetHandler(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
if shared.IsFalse(c.ExpandedConfig()["security.devlxd"]) {
return response.DevLxdErrorResponse(api.StatusErrorf(http.StatusForbidden, "not authorized"), c.Type() == instancetype.VM)
}
Expand All @@ -256,12 +295,15 @@ var devlxdDevicesGet = devLxdHandler{"/1.0/devices", func(d *Daemon, c instance.
}

return response.DevLxdResponse(http.StatusOK, c.ExpandedDevices(), "json", c.Type() == instancetype.VM)
}}
}

var handlers = []devLxdHandler{
{"/", func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
return response.DevLxdResponse(http.StatusOK, []string{"/1.0"}, "json", c.Type() == instancetype.VM)
}},
{
path: "/",
handlerFunc: func(d *Daemon, c instance.Instance, w http.ResponseWriter, r *http.Request) response.Response {
return response.DevLxdResponse(http.StatusOK, []string{"/1.0"}, "json", c.Type() == instancetype.VM)
},
},
devlxdAPIHandler,
devlxdConfigGet,
devlxdConfigKeyGet,
Expand All @@ -276,7 +318,7 @@ func hoistReq(f func(*Daemon, instance.Instance, http.ResponseWriter, *http.Requ
conn := ucred.GetConnFromContext(r.Context())
cred, ok := pidMapper.m[conn.(*net.UnixConn)]
if !ok {
http.Error(w, pidNotInContainerErr.Error(), http.StatusInternalServerError)
http.Error(w, errPIDNotInContainer.Error(), http.StatusInternalServerError)
return
}

Expand Down Expand Up @@ -312,7 +354,7 @@ func devLxdAPI(d *Daemon, f hoistFunc) http.Handler {
m.UseEncodedPath() // Allow encoded values in path segments.

for _, handler := range handlers {
m.HandleFunc(handler.path, f(handler.f, d))
m.HandleFunc(handler.path, f(handler.handlerFunc, d))
}

return m
Expand Down Expand Up @@ -345,18 +387,27 @@ func devLxdAPI(d *Daemon, f hoistFunc) http.Handler {
*/
var pidMapper = ConnPidMapper{m: map[*net.UnixConn]*unix.Ucred{}}

// ConnPidMapper is threadsafe cache of unix connections to process IDs. We use this in hoistReq to determine
// the instance that the connection has been made from.
type ConnPidMapper struct {
m map[*net.UnixConn]*unix.Ucred
mLock sync.Mutex
}

// ConnStateHandler is used in the `ConnState` field of the devlxd http.Server so that we can cache the process ID of the
// caller when a new connection is made and delete it when the connection is closed.
func (m *ConnPidMapper) ConnStateHandler(conn net.Conn, state http.ConnState) {
unixConn := conn.(*net.UnixConn)
unixConn, _ := conn.(*net.UnixConn)
if unixConn == nil {
logger.Error("Invalid type for devlxd connection", logger.Ctx{"conn_type": fmt.Sprintf("%T", conn)})
return
tomponline marked this conversation as resolved.
Show resolved Hide resolved
}

switch state {
case http.StateNew:
cred, err := ucred.GetCred(unixConn)
if err != nil {
logger.Debugf("Error getting ucred for conn %s", err)
logger.Debug("Error getting ucred for devlxd connection", logger.Ctx{"error": err})
} else {
m.mLock.Lock()
m.m[unixConn] = cred
Expand Down Expand Up @@ -384,11 +435,11 @@ func (m *ConnPidMapper) ConnStateHandler(conn net.Conn, state http.ConnState) {
delete(m.m, unixConn)
m.mLock.Unlock()
default:
logger.Debugf("Unknown state for connection %s", state)
logger.Debug("Unknown state for devlxd connection", logger.Ctx{"state": state.String()})
}
}

var pidNotInContainerErr = fmt.Errorf("pid not in container?")
var errPIDNotInContainer = errors.New("Process ID not found in container")

func findContainerForPid(pid int32, s *state.State) (instance.Container, error) {
/*
Expand Down Expand Up @@ -437,7 +488,9 @@ func findContainerForPid(pid int32, s *state.State) (instance.Container, error)
return nil, fmt.Errorf("Instance is not container type")
}

return inst.(instance.Container), nil
// Explicitly ignore type assertion check. We've just checked that it's a container.
c, _ := inst.(instance.Container)
return c, nil
}

status, err := os.ReadFile(fmt.Sprintf("/proc/%d/status", pid))
Expand Down Expand Up @@ -490,9 +543,11 @@ func findContainerForPid(pid int32, s *state.State) (instance.Container, error)
}

if origPidNs == pidNs {
return inst.(instance.Container), nil
// Explicitly ignore type assertion check. The instance must be a container if we've found it via the process ID.
c, _ := inst.(instance.Container)
return c, nil
}
}

return nil, pidNotInContainerErr
return nil, errPIDNotInContainer
}
2 changes: 1 addition & 1 deletion lxd/devlxd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func TestHttpRequest(t *testing.T) {
t.Fatal(err)
}

if !strings.Contains(string(resp), pidNotInContainerErr.Error()) {
if !strings.Contains(string(resp), errPIDNotInContainer.Error()) {
t.Fatal("resp error not expected: ", string(resp))
}
}
Loading
Loading