Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[YUNIKORN-2967] Cleanup REST response headers #994

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 34 additions & 31 deletions pkg/webservice/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func redirectDebug(w http.ResponseWriter, r *http.Request) {
}

func getStackInfo(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
var stack = func() []byte {
buf := make([]byte, 1024)
for {
Expand All @@ -145,7 +145,7 @@ func getStackInfo(w http.ResponseWriter, r *http.Request) {
}

func getClusterInfo(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)

lists := schedulerContext.Load().GetPartitionMapClone()
clustersInfo := getClusterDAO(lists)
Expand All @@ -167,7 +167,7 @@ func validateQueue(queuePath string) error {
}

func validateConf(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
requestBytes, err := io.ReadAll(r.Body)
if err == nil {
_, err = configs.LoadSchedulerConfigFromByteArray(requestBytes)
Expand All @@ -184,11 +184,14 @@ func validateConf(w http.ResponseWriter, r *http.Request) {
}
}

func writeHeaders(w http.ResponseWriter) {
func writeHeaders(w http.ResponseWriter, method string) {
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,HEAD,OPTIONS")
methods := "GET, OPTIONS"
if method == http.MethodPost {
methods = "OPTIONS, POST"
}
w.Header().Set("Access-Control-Allow-Methods", methods)
w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With,Content-Type,Accept,Origin")
}

Expand Down Expand Up @@ -233,7 +236,7 @@ func getClusterUtilJSON(partition *scheduler.PartitionContext) []*dao.ClusterUti
}
utils = append(utils, utilization)
}
} else if !getResource {
} else {
utilization := &dao.ClusterUtilDAOInfo{
ResourceType: "N/A",
Total: int64(-1),
Expand Down Expand Up @@ -446,7 +449,7 @@ func getNodesDAO(entries []*objects.Node) []*dao.NodeDAOInfo {
// Only check the default partition
// Deprecated - To be removed in next major release. Replaced with getNodesUtilisations
func getNodeUtilisation(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
partitionContext := schedulerContext.Load().GetPartitionWithoutClusterID(configs.DefaultPartition)
if partitionContext == nil {
buildJSONErrorResponse(w, PartitionDoesNotExists, http.StatusInternalServerError)
Expand Down Expand Up @@ -510,7 +513,7 @@ func getNodesUtilJSON(partition *scheduler.PartitionContext, name string) *dao.N
}

func getNodeUtilisations(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
var result []*dao.PartitionNodesUtilDAOInfo
for _, part := range schedulerContext.Load().GetPartitionMapClone() {
result = append(result, getPartitionNodesUtilJSON(part))
Expand Down Expand Up @@ -583,7 +586,7 @@ func getPartitionNodesUtilJSON(partition *scheduler.PartitionContext) *dao.Parti
}

func getApplicationHistory(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)

// There is nothing to return but we did not really encounter a problem
if imHistory == nil {
Expand All @@ -600,7 +603,7 @@ func getApplicationHistory(w http.ResponseWriter, r *http.Request) {
}

func getContainerHistory(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)

// There is nothing to return but we did not really encounter a problem
if imHistory == nil {
Expand All @@ -617,7 +620,7 @@ func getContainerHistory(w http.ResponseWriter, r *http.Request) {
}

func getClusterConfig(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)

var marshalledConf []byte
var err error
Expand Down Expand Up @@ -653,7 +656,7 @@ func getClusterConfigDAO() *dao.ConfigDAOInfo {
}

func checkHealthStatus(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)

// Fetch last healthCheck result
result := schedulerContext.Load().GetLastHealthCheckResult()
Expand All @@ -675,8 +678,8 @@ func checkHealthStatus(w http.ResponseWriter, r *http.Request) {
}
}

func getPartitions(w http.ResponseWriter, _ *http.Request) {
writeHeaders(w)
func getPartitions(w http.ResponseWriter, r *http.Request) {
writeHeaders(w, r.Method)

lists := schedulerContext.Load().GetPartitionMapClone()
partitionsInfo := getPartitionInfoDAO(lists)
Expand All @@ -686,7 +689,7 @@ func getPartitions(w http.ResponseWriter, _ *http.Request) {
}

func getPartitionQueues(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand All @@ -707,7 +710,7 @@ func getPartitionQueues(w http.ResponseWriter, r *http.Request) {
}

func getPartitionQueue(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -742,7 +745,7 @@ func getPartitionQueue(w http.ResponseWriter, r *http.Request) {
}

func getPartitionNodes(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand All @@ -761,7 +764,7 @@ func getPartitionNodes(w http.ResponseWriter, r *http.Request) {
}

func getPartitionNode(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand All @@ -786,7 +789,7 @@ func getPartitionNode(w http.ResponseWriter, r *http.Request) {
}

func getQueueApplications(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -827,7 +830,7 @@ func getQueueApplications(w http.ResponseWriter, r *http.Request) {
}

func getPartitionApplicationsByState(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -876,7 +879,7 @@ func getPartitionApplicationsByState(w http.ResponseWriter, r *http.Request) {
}

func getApplication(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -924,7 +927,7 @@ func getApplication(w http.ResponseWriter, r *http.Request) {
}

func getPartitionRules(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand All @@ -943,7 +946,7 @@ func getPartitionRules(w http.ResponseWriter, r *http.Request) {
}

func getQueueApplicationsByState(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -1180,8 +1183,8 @@ func getMetrics(w http.ResponseWriter, r *http.Request) {
promhttp.Handler().ServeHTTP(w, r)
}

func getUsersResourceUsage(w http.ResponseWriter, _ *http.Request) {
writeHeaders(w)
func getUsersResourceUsage(w http.ResponseWriter, r *http.Request) {
writeHeaders(w, r.Method)
userManager := ugm.GetUserManager()
trackers := userManager.GetUserTrackers()
result := make([]*dao.UserResourceUsageDAOInfo, len(trackers))
Expand All @@ -1194,7 +1197,7 @@ func getUsersResourceUsage(w http.ResponseWriter, _ *http.Request) {
}

func getUserResourceUsage(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -1222,7 +1225,7 @@ func getUserResourceUsage(w http.ResponseWriter, r *http.Request) {
}

func getGroupsResourceUsage(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
userManager := ugm.GetUserManager()
trackers := userManager.GetGroupTrackers()
result := make([]*dao.GroupResourceUsageDAOInfo, len(trackers))
Expand All @@ -1235,7 +1238,7 @@ func getGroupsResourceUsage(w http.ResponseWriter, r *http.Request) {
}

func getGroupResourceUsage(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -1263,7 +1266,7 @@ func getGroupResourceUsage(w http.ResponseWriter, r *http.Request) {
}

func getEvents(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
eventSystem := events.GetEventSystem()
if !eventSystem.IsEventTrackingEnabled() {
buildJSONErrorResponse(w, "Event tracking is disabled", http.StatusInternalServerError)
Expand Down Expand Up @@ -1311,7 +1314,7 @@ func getEvents(w http.ResponseWriter, r *http.Request) {
}

func getStream(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
eventSystem := events.GetEventSystem()
if !eventSystem.IsEventTrackingEnabled() {
buildJSONErrorResponse(w, "Event tracking is disabled", http.StatusInternalServerError)
Expand Down
20 changes: 12 additions & 8 deletions pkg/webservice/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1285,16 +1285,18 @@ func TestGetPartitionQueueHandler(t *testing.T) {
func TestGetClusterInfo(t *testing.T) {
schedulerContext.Store(&scheduler.ClusterContext{})
resp := &MockResponseWriter{}
getClusterInfo(resp, nil)
req, err := http.NewRequest("GET", "/ws/v1/clusters", strings.NewReader(""))
assert.NilError(t, err, "error while creating http request")
getClusterInfo(resp, req)
var data []*dao.ClusterDAOInfo
err := json.Unmarshal(resp.outputBytes, &data)
err = json.Unmarshal(resp.outputBytes, &data)
assert.NilError(t, err)
assert.Equal(t, 0, len(data))

setup(t, configTwoLevelQueues, 2)

resp = &MockResponseWriter{}
getClusterInfo(resp, nil)
getClusterInfo(resp, req)
err = json.Unmarshal(resp.outputBytes, &data)
assert.NilError(t, err)
assert.Equal(t, 2, len(data))
Expand Down Expand Up @@ -1412,11 +1414,11 @@ func TestGetPartitionNode(t *testing.T) {
_, allocCreated, err := partition.UpdateAllocation(alloc1)
assert.NilError(t, err, "add alloc-1 should not have failed")
assert.Check(t, allocCreated)
falloc1 := newForeignAlloc("foreign-1", "", node1ID, resAlloc1, siCommon.AllocTypeDefault, 0)
falloc1 := newForeignAlloc("foreign-1", node1ID, resAlloc1, siCommon.AllocTypeDefault, 0)
_, allocCreated, err = partition.UpdateAllocation(falloc1)
assert.NilError(t, err, "add falloc-1 should not have failed")
assert.Check(t, allocCreated)
falloc2 := newForeignAlloc("foreign-2", "", node1ID, resAlloc2, siCommon.AllocTypeStatic, 123)
falloc2 := newForeignAlloc("foreign-2", node1ID, resAlloc2, siCommon.AllocTypeStatic, 123)
_, allocCreated, err = partition.UpdateAllocation(falloc2)
assert.NilError(t, err, "add falloc-2 should not have failed")
assert.Check(t, allocCreated)
Expand Down Expand Up @@ -1746,6 +1748,7 @@ func checkGetQueueAppByState(t *testing.T, partition, queue, state, status strin
url = fmt.Sprintf("/ws/v1/partition/%s/queue/%s/applications/%s?status=%s", partition, queue, state, status)
}
req, err := http.NewRequest("GET", url, strings.NewReader(""))
assert.NilError(t, err, "unexpected error creating request")
req = req.WithContext(context.WithValue(req.Context(), httprouter.ParamsKey, httprouter.Params{
httprouter.Param{Key: "partition", Value: partition},
httprouter.Param{Key: "queue", Value: queue},
Expand Down Expand Up @@ -1780,6 +1783,7 @@ func checkGetQueueAppByIllegalStateOrStatus(t *testing.T, partition, queue, stat
url = fmt.Sprintf("/ws/v1/partition/%s/queue/%s/applications/%s?status=%s", partition, queue, state, status)
}
req, err := http.NewRequest("GET", url, strings.NewReader(""))
assert.NilError(t, err, "unexpected error creating request")
req = req.WithContext(context.WithValue(req.Context(), httprouter.ParamsKey, httprouter.Params{
httprouter.Param{Key: "partition", Value: partition},
httprouter.Param{Key: "queue", Value: queue},
Expand Down Expand Up @@ -2115,9 +2119,9 @@ func TestFullStateDumpPath(t *testing.T) {
prepareSchedulerContext(t)

partitionContext := schedulerContext.Load().GetPartitionMapClone()
context := partitionContext[normalizedPartitionName]
ctx := partitionContext[normalizedPartitionName]
app := newApplication("appID", normalizedPartitionName, "root.default", rmID, security.UserGroup{})
err := context.AddApplication(app)
err := ctx.AddApplication(app)
assert.NilError(t, err, "failed to add Application to partition")

imHistory = history.NewInternalMetricsHistory(5)
Expand Down Expand Up @@ -3053,7 +3057,7 @@ func newAlloc(allocationKey string, appID string, nodeID string, resAlloc *resou
})
}

func newForeignAlloc(allocationKey string, appID string, nodeID string, resAlloc *resources.Resource, fType string, priority int32) *objects.Allocation {
func newForeignAlloc(allocationKey string, nodeID string, resAlloc *resources.Resource, fType string, priority int32) *objects.Allocation {
return objects.NewAllocationFromSI(&si.Allocation{
AllocationKey: allocationKey,
NodeID: nodeID,
Expand Down
2 changes: 1 addition & 1 deletion pkg/webservice/state_dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ type AggregatedStateInfo struct {
}

func getFullStateDump(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
if err := doStateDump(w); err != nil {
buildJSONErrorResponse(w, err.Error(), http.StatusInternalServerError)
}
Expand Down
Loading
Loading