Skip to content

Commit

Permalink
api/server/middleware: NewVersionMiddleware: add validation
Browse files Browse the repository at this point in the history
Make sure the middleware cannot be initialized with out of range versions.

Signed-off-by: Sebastiaan van Stijn <[email protected]>
  • Loading branch information
thaJeztah committed Feb 6, 2024
1 parent e1897cb commit 14503cc
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 15 deletions.
16 changes: 13 additions & 3 deletions api/server/middleware/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"runtime"

"github.com/docker/docker/api"
"github.com/docker/docker/api/server/httputils"
"github.com/docker/docker/api/types/versions"
)
Expand All @@ -32,12 +33,21 @@ type VersionMiddleware struct {
}

// NewVersionMiddleware creates a VersionMiddleware with the given versions.
func NewVersionMiddleware(serverVersion, defaultAPIVersion, minAPIVersion string) VersionMiddleware {
return VersionMiddleware{
func NewVersionMiddleware(serverVersion, defaultAPIVersion, minAPIVersion string) (*VersionMiddleware, error) {
if versions.LessThan(defaultAPIVersion, api.MinSupportedAPIVersion) || versions.GreaterThan(defaultAPIVersion, api.DefaultVersion) {
return nil, fmt.Errorf("invalid default API version (%s): must be between %s and %s", defaultAPIVersion, api.MinSupportedAPIVersion, api.DefaultVersion)
}
if versions.LessThan(minAPIVersion, api.MinSupportedAPIVersion) || versions.GreaterThan(minAPIVersion, api.DefaultVersion) {
return nil, fmt.Errorf("invalid minimum API version (%s): must be between %s and %s", minAPIVersion, api.MinSupportedAPIVersion, api.DefaultVersion)
}
if versions.GreaterThan(minAPIVersion, defaultAPIVersion) {
return nil, fmt.Errorf("invalid API version: the minimum API version (%s) is higher than the default version (%s)", minAPIVersion, defaultAPIVersion)
}
return &VersionMiddleware{
serverVersion: serverVersion,
defaultAPIVersion: defaultAPIVersion,
minAPIVersion: minAPIVersion,
}
}, nil
}

type versionUnsupportedError struct {
Expand Down
62 changes: 59 additions & 3 deletions api/server/middleware/version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,60 @@ import (
is "gotest.tools/v3/assert/cmp"
)

func TestNewVersionMiddlewareValidation(t *testing.T) {
tests := []struct {
doc, defaultVersion, minVersion, expectedErr string
}{
{
doc: "defaults",
defaultVersion: api.DefaultVersion,
minVersion: api.MinSupportedAPIVersion,
},
{
doc: "invalid default lower than min",
defaultVersion: api.MinSupportedAPIVersion,
minVersion: api.DefaultVersion,
expectedErr: fmt.Sprintf("invalid API version: the minimum API version (%s) is higher than the default version (%s)", api.DefaultVersion, api.MinSupportedAPIVersion),
},
{
doc: "invalid default too low",
defaultVersion: "0.1",
minVersion: api.MinSupportedAPIVersion,
expectedErr: fmt.Sprintf("invalid default API version (0.1): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
},
{
doc: "invalid default too high",
defaultVersion: "9999.9999",
minVersion: api.DefaultVersion,
expectedErr: fmt.Sprintf("invalid default API version (9999.9999): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
},
{
doc: "invalid minimum too low",
defaultVersion: api.MinSupportedAPIVersion,
minVersion: "0.1",
expectedErr: fmt.Sprintf("invalid minimum API version (0.1): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
},
{
doc: "invalid minimum too high",
defaultVersion: api.DefaultVersion,
minVersion: "9999.9999",
expectedErr: fmt.Sprintf("invalid minimum API version (9999.9999): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.doc, func(t *testing.T) {
_, err := NewVersionMiddleware("1.2.3", tc.defaultVersion, tc.minVersion)
if tc.expectedErr == "" {
assert.Check(t, err)
} else {
assert.Check(t, is.Error(err, tc.expectedErr))
}
})
}
}

func TestVersionMiddlewareVersion(t *testing.T) {
expectedVersion := "<not set>"
handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
Expand All @@ -22,7 +76,8 @@ func TestVersionMiddlewareVersion(t *testing.T) {
return nil
}

m := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
m, err := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
assert.NilError(t, err)
h := m.WrapHandler(handler)

req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil)
Expand Down Expand Up @@ -71,15 +126,16 @@ func TestVersionMiddlewareWithErrorsReturnsHeaders(t *testing.T) {
return nil
}

m := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
m, err := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
assert.NilError(t, err)
h := m.WrapHandler(handler)

req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil)
resp := httptest.NewRecorder()
ctx := context.Background()

vars := map[string]string{"version": "0.1"}
err := h(ctx, resp, req, vars)
err = h(ctx, resp, req, vars)
assert.Check(t, is.ErrorContains(err, ""))

hdr := resp.Result().Header
Expand Down
7 changes: 5 additions & 2 deletions api/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ import (
func TestMiddlewares(t *testing.T) {
srv := &Server{}

const apiMinVersion = "1.12"
srv.UseMiddleware(middleware.NewVersionMiddleware("0.1omega2", api.DefaultVersion, apiMinVersion))
m, err := middleware.NewVersionMiddleware("0.1omega2", api.DefaultVersion, api.MinSupportedAPIVersion)
if err != nil {
t.Fatal(err)
}
srv.UseMiddleware(*m)

req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil)
resp := httptest.NewRecorder()
Expand Down
18 changes: 11 additions & 7 deletions cmd/dockerd/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,10 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
pluginStore := plugin.NewStore()

var apiServer apiserver.Server
cli.authzMiddleware = initMiddlewares(&apiServer, cli.Config, pluginStore)
cli.authzMiddleware, err = initMiddlewares(&apiServer, cli.Config, pluginStore)
if err != nil {
return errors.Wrap(err, "failed to start API server")
}

d, err := daemon.NewDaemon(ctx, cli.Config, pluginStore, cli.authzMiddleware)
if err != nil {
Expand Down Expand Up @@ -708,14 +711,15 @@ func (opts routerOptions) Build() []router.Router {
return routers
}

func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugingetter.PluginGetter) *authorization.Middleware {
v := dockerversion.Version

func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugingetter.PluginGetter) (*authorization.Middleware, error) {
exp := middleware.NewExperimentalMiddleware(cfg.Experimental)
s.UseMiddleware(exp)

vm := middleware.NewVersionMiddleware(v, api.DefaultVersion, cfg.MinAPIVersion)
s.UseMiddleware(vm)
vm, err := middleware.NewVersionMiddleware(dockerversion.Version, api.DefaultVersion, cfg.MinAPIVersion)
if err != nil {
return nil, err
}
s.UseMiddleware(*vm)

if cfg.CorsHeaders != "" {
c := middleware.NewCORSMiddleware(cfg.CorsHeaders)
Expand All @@ -724,7 +728,7 @@ func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugin

authzMiddleware := authorization.NewMiddleware(cfg.AuthorizationPlugins, pluginStore)
s.UseMiddleware(authzMiddleware)
return authzMiddleware
return authzMiddleware, nil
}

func (cli *DaemonCli) getContainerdDaemonOpts() ([]supervisor.DaemonOpt, error) {
Expand Down

0 comments on commit 14503cc

Please sign in to comment.