diff --git a/pkg/webservice/handlers.go b/pkg/webservice/handlers.go index 934685165..73afdafc9 100644 --- a/pkg/webservice/handlers.go +++ b/pkg/webservice/handlers.go @@ -127,7 +127,7 @@ func redirectDebug(w http.ResponseWriter, r *http.Request) { } func getStackInfo(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) var stack = func() []byte { buf := make([]byte, 1024) for { @@ -145,7 +145,7 @@ func getStackInfo(w http.ResponseWriter, r *http.Request) { } func getClusterInfo(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) lists := schedulerContext.Load().GetPartitionMapClone() clustersInfo := getClusterDAO(lists) @@ -167,7 +167,7 @@ func validateQueue(queuePath string) error { } func validateConf(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) requestBytes, err := io.ReadAll(r.Body) if err == nil { _, err = configs.LoadSchedulerConfigFromByteArray(requestBytes) @@ -184,11 +184,14 @@ func validateConf(w http.ResponseWriter, r *http.Request) { } } -func writeHeaders(w http.ResponseWriter) { +func writeHeaders(w http.ResponseWriter, method string) { w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Allow-Methods", "GET,POST,HEAD,OPTIONS") + methods := "GET, OPTIONS" + if method == http.MethodPost { + methods = "OPTIONS, POST" + } + w.Header().Set("Access-Control-Allow-Methods", methods) w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With,Content-Type,Accept,Origin") } @@ -233,7 +236,7 @@ func getClusterUtilJSON(partition *scheduler.PartitionContext) []*dao.ClusterUti } utils = append(utils, utilization) } - } else if !getResource { + } else { utilization := &dao.ClusterUtilDAOInfo{ ResourceType: "N/A", Total: int64(-1), @@ -446,7 +449,7 @@ func getNodesDAO(entries []*objects.Node) []*dao.NodeDAOInfo { // Only check the default partition // Deprecated - To be removed in next major release. Replaced with getNodesUtilisations func getNodeUtilisation(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) partitionContext := schedulerContext.Load().GetPartitionWithoutClusterID(configs.DefaultPartition) if partitionContext == nil { buildJSONErrorResponse(w, PartitionDoesNotExists, http.StatusInternalServerError) @@ -510,7 +513,7 @@ func getNodesUtilJSON(partition *scheduler.PartitionContext, name string) *dao.N } func getNodeUtilisations(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) var result []*dao.PartitionNodesUtilDAOInfo for _, part := range schedulerContext.Load().GetPartitionMapClone() { result = append(result, getPartitionNodesUtilJSON(part)) @@ -583,7 +586,7 @@ func getPartitionNodesUtilJSON(partition *scheduler.PartitionContext) *dao.Parti } func getApplicationHistory(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) // There is nothing to return but we did not really encounter a problem if imHistory == nil { @@ -600,7 +603,7 @@ func getApplicationHistory(w http.ResponseWriter, r *http.Request) { } func getContainerHistory(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) // There is nothing to return but we did not really encounter a problem if imHistory == nil { @@ -617,7 +620,7 @@ func getContainerHistory(w http.ResponseWriter, r *http.Request) { } func getClusterConfig(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) var marshalledConf []byte var err error @@ -653,7 +656,7 @@ func getClusterConfigDAO() *dao.ConfigDAOInfo { } func checkHealthStatus(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) // Fetch last healthCheck result result := schedulerContext.Load().GetLastHealthCheckResult() @@ -675,8 +678,8 @@ func checkHealthStatus(w http.ResponseWriter, r *http.Request) { } } -func getPartitions(w http.ResponseWriter, _ *http.Request) { - writeHeaders(w) +func getPartitions(w http.ResponseWriter, r *http.Request) { + writeHeaders(w, r.Method) lists := schedulerContext.Load().GetPartitionMapClone() partitionsInfo := getPartitionInfoDAO(lists) @@ -686,7 +689,7 @@ func getPartitions(w http.ResponseWriter, _ *http.Request) { } func getPartitionQueues(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) vars := httprouter.ParamsFromContext(r.Context()) if vars == nil { buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest) @@ -707,7 +710,7 @@ func getPartitionQueues(w http.ResponseWriter, r *http.Request) { } func getPartitionQueue(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) vars := httprouter.ParamsFromContext(r.Context()) if vars == nil { buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest) @@ -742,7 +745,7 @@ func getPartitionQueue(w http.ResponseWriter, r *http.Request) { } func getPartitionNodes(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) vars := httprouter.ParamsFromContext(r.Context()) if vars == nil { buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest) @@ -761,7 +764,7 @@ func getPartitionNodes(w http.ResponseWriter, r *http.Request) { } func getPartitionNode(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) vars := httprouter.ParamsFromContext(r.Context()) if vars == nil { buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest) @@ -786,7 +789,7 @@ func getPartitionNode(w http.ResponseWriter, r *http.Request) { } func getQueueApplications(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) vars := httprouter.ParamsFromContext(r.Context()) if vars == nil { buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest) @@ -827,7 +830,7 @@ func getQueueApplications(w http.ResponseWriter, r *http.Request) { } func getPartitionApplicationsByState(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) vars := httprouter.ParamsFromContext(r.Context()) if vars == nil { buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest) @@ -876,7 +879,7 @@ func getPartitionApplicationsByState(w http.ResponseWriter, r *http.Request) { } func getApplication(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) vars := httprouter.ParamsFromContext(r.Context()) if vars == nil { buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest) @@ -924,7 +927,7 @@ func getApplication(w http.ResponseWriter, r *http.Request) { } func getPartitionRules(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) vars := httprouter.ParamsFromContext(r.Context()) if vars == nil { buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest) @@ -943,7 +946,7 @@ func getPartitionRules(w http.ResponseWriter, r *http.Request) { } func getQueueApplicationsByState(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) vars := httprouter.ParamsFromContext(r.Context()) if vars == nil { buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest) @@ -1180,8 +1183,8 @@ func getMetrics(w http.ResponseWriter, r *http.Request) { promhttp.Handler().ServeHTTP(w, r) } -func getUsersResourceUsage(w http.ResponseWriter, _ *http.Request) { - writeHeaders(w) +func getUsersResourceUsage(w http.ResponseWriter, r *http.Request) { + writeHeaders(w, r.Method) userManager := ugm.GetUserManager() trackers := userManager.GetUserTrackers() result := make([]*dao.UserResourceUsageDAOInfo, len(trackers)) @@ -1194,7 +1197,7 @@ func getUsersResourceUsage(w http.ResponseWriter, _ *http.Request) { } func getUserResourceUsage(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) vars := httprouter.ParamsFromContext(r.Context()) if vars == nil { buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest) @@ -1222,7 +1225,7 @@ func getUserResourceUsage(w http.ResponseWriter, r *http.Request) { } func getGroupsResourceUsage(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) userManager := ugm.GetUserManager() trackers := userManager.GetGroupTrackers() result := make([]*dao.GroupResourceUsageDAOInfo, len(trackers)) @@ -1235,7 +1238,7 @@ func getGroupsResourceUsage(w http.ResponseWriter, r *http.Request) { } func getGroupResourceUsage(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) vars := httprouter.ParamsFromContext(r.Context()) if vars == nil { buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest) @@ -1263,7 +1266,7 @@ func getGroupResourceUsage(w http.ResponseWriter, r *http.Request) { } func getEvents(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) eventSystem := events.GetEventSystem() if !eventSystem.IsEventTrackingEnabled() { buildJSONErrorResponse(w, "Event tracking is disabled", http.StatusInternalServerError) @@ -1311,7 +1314,7 @@ func getEvents(w http.ResponseWriter, r *http.Request) { } func getStream(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) eventSystem := events.GetEventSystem() if !eventSystem.IsEventTrackingEnabled() { buildJSONErrorResponse(w, "Event tracking is disabled", http.StatusInternalServerError) diff --git a/pkg/webservice/handlers_test.go b/pkg/webservice/handlers_test.go index f262140b1..19cb1328c 100644 --- a/pkg/webservice/handlers_test.go +++ b/pkg/webservice/handlers_test.go @@ -1285,16 +1285,18 @@ func TestGetPartitionQueueHandler(t *testing.T) { func TestGetClusterInfo(t *testing.T) { schedulerContext.Store(&scheduler.ClusterContext{}) resp := &MockResponseWriter{} - getClusterInfo(resp, nil) + req, err := http.NewRequest("GET", "/ws/v1/clusters", strings.NewReader("")) + assert.NilError(t, err, "error while creating http request") + getClusterInfo(resp, req) var data []*dao.ClusterDAOInfo - err := json.Unmarshal(resp.outputBytes, &data) + err = json.Unmarshal(resp.outputBytes, &data) assert.NilError(t, err) assert.Equal(t, 0, len(data)) setup(t, configTwoLevelQueues, 2) resp = &MockResponseWriter{} - getClusterInfo(resp, nil) + getClusterInfo(resp, req) err = json.Unmarshal(resp.outputBytes, &data) assert.NilError(t, err) assert.Equal(t, 2, len(data)) @@ -1412,11 +1414,11 @@ func TestGetPartitionNode(t *testing.T) { _, allocCreated, err := partition.UpdateAllocation(alloc1) assert.NilError(t, err, "add alloc-1 should not have failed") assert.Check(t, allocCreated) - falloc1 := newForeignAlloc("foreign-1", "", node1ID, resAlloc1, siCommon.AllocTypeDefault, 0) + falloc1 := newForeignAlloc("foreign-1", node1ID, resAlloc1, siCommon.AllocTypeDefault, 0) _, allocCreated, err = partition.UpdateAllocation(falloc1) assert.NilError(t, err, "add falloc-1 should not have failed") assert.Check(t, allocCreated) - falloc2 := newForeignAlloc("foreign-2", "", node1ID, resAlloc2, siCommon.AllocTypeStatic, 123) + falloc2 := newForeignAlloc("foreign-2", node1ID, resAlloc2, siCommon.AllocTypeStatic, 123) _, allocCreated, err = partition.UpdateAllocation(falloc2) assert.NilError(t, err, "add falloc-2 should not have failed") assert.Check(t, allocCreated) @@ -1746,6 +1748,7 @@ func checkGetQueueAppByState(t *testing.T, partition, queue, state, status strin url = fmt.Sprintf("/ws/v1/partition/%s/queue/%s/applications/%s?status=%s", partition, queue, state, status) } req, err := http.NewRequest("GET", url, strings.NewReader("")) + assert.NilError(t, err, "unexpected error creating request") req = req.WithContext(context.WithValue(req.Context(), httprouter.ParamsKey, httprouter.Params{ httprouter.Param{Key: "partition", Value: partition}, httprouter.Param{Key: "queue", Value: queue}, @@ -1780,6 +1783,7 @@ func checkGetQueueAppByIllegalStateOrStatus(t *testing.T, partition, queue, stat url = fmt.Sprintf("/ws/v1/partition/%s/queue/%s/applications/%s?status=%s", partition, queue, state, status) } req, err := http.NewRequest("GET", url, strings.NewReader("")) + assert.NilError(t, err, "unexpected error creating request") req = req.WithContext(context.WithValue(req.Context(), httprouter.ParamsKey, httprouter.Params{ httprouter.Param{Key: "partition", Value: partition}, httprouter.Param{Key: "queue", Value: queue}, @@ -2115,9 +2119,9 @@ func TestFullStateDumpPath(t *testing.T) { prepareSchedulerContext(t) partitionContext := schedulerContext.Load().GetPartitionMapClone() - context := partitionContext[normalizedPartitionName] + ctx := partitionContext[normalizedPartitionName] app := newApplication("appID", normalizedPartitionName, "root.default", rmID, security.UserGroup{}) - err := context.AddApplication(app) + err := ctx.AddApplication(app) assert.NilError(t, err, "failed to add Application to partition") imHistory = history.NewInternalMetricsHistory(5) @@ -3053,7 +3057,7 @@ func newAlloc(allocationKey string, appID string, nodeID string, resAlloc *resou }) } -func newForeignAlloc(allocationKey string, appID string, nodeID string, resAlloc *resources.Resource, fType string, priority int32) *objects.Allocation { +func newForeignAlloc(allocationKey string, nodeID string, resAlloc *resources.Resource, fType string, priority int32) *objects.Allocation { return objects.NewAllocationFromSI(&si.Allocation{ AllocationKey: allocationKey, NodeID: nodeID, diff --git a/pkg/webservice/state_dump.go b/pkg/webservice/state_dump.go index 5cb7efc0e..a37ef2cb4 100644 --- a/pkg/webservice/state_dump.go +++ b/pkg/webservice/state_dump.go @@ -54,7 +54,7 @@ type AggregatedStateInfo struct { } func getFullStateDump(w http.ResponseWriter, r *http.Request) { - writeHeaders(w) + writeHeaders(w, r.Method) if err := doStateDump(w); err != nil { buildJSONErrorResponse(w, err.Error(), http.StatusInternalServerError) } diff --git a/pkg/webservice/webservice_test.go b/pkg/webservice/webservice_test.go index 1b723f4a0..5c8cb66bc 100644 --- a/pkg/webservice/webservice_test.go +++ b/pkg/webservice/webservice_test.go @@ -30,6 +30,8 @@ import ( "github.com/apache/yunikorn-core/pkg/scheduler" ) +const base = "http://localhost:9080" + func Test_RedirectDebugHandler(t *testing.T) { defer ResetIMHistory() s := NewWebApp(&scheduler.ClusterContext{}, history.NewInternalMetricsHistory(5)) @@ -40,7 +42,6 @@ func Test_RedirectDebugHandler(t *testing.T) { t.Fatal("failed to stop webapp") } }(s) - base := "http://localhost:9080" tests := []struct { name string reqURL string @@ -76,7 +77,6 @@ func Test_RouterHandling(t *testing.T) { t.Fatal("failed to stop webapp") } }(s) - base := "http://localhost:9080" client := &http.Client{} // unsupported POST resp, err := client.Post(base+"/ws/v1/clusters", "application/json; charset=UTF-8", nil) @@ -105,3 +105,48 @@ func Test_RouterHandling(t *testing.T) { _ = resp.Body.Close() assert.Equal(t, resp.StatusCode, http.StatusOK, "expected OK") } + +func Test_HeaderChecks(t *testing.T) { + s := NewWebApp(&scheduler.ClusterContext{}, nil) + s.StartWebApp() + defer func(s *WebService) { + err := s.StopWebApp() + if err != nil { + t.Fatal("failed to stop webapp") + } + }(s) + client := http.DefaultClient + tests := []struct { + name string + reqURL string + method string + expected string + }{ + {"get options", "/ws/v1/clusters", http.MethodOptions, "GET, OPTIONS"}, + {"get", "/ws/v1/clusters", http.MethodGet, "GET, OPTIONS"}, + {"post options", "/ws/v1/validate-conf", http.MethodOptions, "OPTIONS, POST"}, + {"post", "/ws/v1/validate-conf", http.MethodPost, "OPTIONS, POST"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequest(tt.method, base+tt.reqURL, nil) + assert.NilError(t, err, "unexpected error creating request") + var resp *http.Response + resp, err = client.Do(req) + assert.NilError(t, err, "unexpected error executing request") + assert.Equal(t, resp.StatusCode, http.StatusOK, "expected OK") + switch tt.method { + case http.MethodGet, http.MethodPost: + assert.Equal(t, resp.Header.Get("Access-Control-Allow-Methods"), tt.expected, "wrong methods returned") + case http.MethodOptions: + // OPTIONS requests are handled by default via httpdrouter, not defined in the routes + assert.Equal(t, resp.Header.Get("Allow"), tt.expected, "expected only get and options to be returned") + } + var body []byte + body, err = io.ReadAll(resp.Body) + _ = resp.Body.Close() + assert.NilError(t, err, "unexpected error reading body") + assert.Assert(t, body != nil, "expected body with status text") + }) + } +}