From 9bfafaca7118719c6ad907a3c56d83c2bb11ec78 Mon Sep 17 00:00:00 2001 From: Tianze Shan Date: Thu, 3 Oct 2024 16:11:48 -0700 Subject: [PATCH 1/2] Remove internal error messages from fault injection HTTP response (#4381) * Remove internal error messages from fault injection HTTP response --------- Co-authored-by: mye956 Co-authored-by: xingzhen Co-authored-by: Tianze Shan --- .../handlers/fault/v1/handlers/handlers.go | 41 ++++++++++--------- .../handlers/fault/v1/handlers/handlers.go | 41 ++++++++++--------- .../fault/v1/handlers/handlers_test.go | 7 ++-- 3 files changed, 47 insertions(+), 42 deletions(-) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 4c5b5a3a520..f3f3bf3acef 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -42,12 +42,15 @@ import ( ) const ( + // Request types startFaultRequestType = "start %s" stopFaultRequestType = "stop %s" checkStatusFaultRequestType = "check status %s" - invalidNetworkModeError = "%s mode is not supported. Please use either host or awsvpc mode." - faultInjectionEnabledError = "enableFaultInjection is not enabled for task: %s" - requestTimedOutError = "%s: request timed out" + // Error messages + internalError = "internal error" + invalidNetworkModeError = "%s mode is not supported. Please use either host or awsvpc mode." + faultInjectionEnabledError = "enableFaultInjection is not enabled for task: %s" + requestTimedOutError = "%s: request timed out" // This is our initial assumption of how much time it would take for the Linux commands used to inject faults // to finish. This will be confirmed/updated after more testing. requestTimeoutDuration = 5 * time.Second @@ -147,14 +150,14 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht insertTable = "OUTPUT" } - cmdOutput, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, + _, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, networkMode, networkNSPath, insertTable, taskArn) if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) { statusCode = http.StatusInternalServerError responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) } else if cmdErr != nil { statusCode = http.StatusInternalServerError - responseBody = types.NewNetworkFaultInjectionErrorResponse(cmdOutput) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) } else { statusCode = http.StatusOK responseBody = types.NewNetworkFaultInjectionSuccessResponse("running") @@ -293,7 +296,7 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt insertTable = "OUTPUT" } - cmdOutput, cmdErr := h.stopNetworkBlackHolePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, + _, cmdErr := h.stopNetworkBlackHolePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, networkMode, networkNSPath, insertTable, taskArn) if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) { @@ -301,7 +304,7 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) } else if cmdErr != nil { statusCode = http.StatusInternalServerError - responseBody = types.NewNetworkFaultInjectionErrorResponse(cmdOutput) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) } else { statusCode = http.StatusOK responseBody = types.NewNetworkFaultInjectionSuccessResponse("stopped") @@ -436,7 +439,7 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht stringToBeLogged := "Failed to check fault" port := strconv.FormatUint(uint64(aws.Uint16Value(request.Port)), 10) chainName := fmt.Sprintf("%s-%s-%s", aws.StringValue(request.TrafficType), aws.StringValue(request.Protocol), port) - running, cmdOutput, cmdErr := h.checkNetworkBlackHolePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, + running, _, cmdErr := h.checkNetworkBlackHolePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, networkMode, networkNSPath, taskArn) // We've timed out trying to check if the black hole port fault injection is running @@ -445,7 +448,7 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) } else if cmdErr != nil { statusCode = http.StatusInternalServerError - responseBody = types.NewNetworkFaultInjectionErrorResponse(cmdOutput) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) } else { statusCode = http.StatusOK if running { @@ -553,7 +556,7 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { // If there already exists a fault in the task network namespace. @@ -570,7 +573,7 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { stringToBeLogged = "Successfully started fault" @@ -625,7 +628,7 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { // If there doesn't already exist a network-latency fault @@ -640,7 +643,7 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { stringToBeLogged = "Successfully stopped fault" @@ -696,7 +699,7 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { stringToBeLogged = "Successfully checked fault status" @@ -766,7 +769,7 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { // If there already exists a fault in the task network namespace. @@ -783,7 +786,7 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { stringToBeLogged = "Successfully started fault" @@ -838,7 +841,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { // If there doesn't already exist a network-packet-loss fault @@ -853,7 +856,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { stringToBeLogged = "Successfully stopped fault" @@ -909,7 +912,7 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { stringToBeLogged = "Successfully checked fault status" diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 4c5b5a3a520..f3f3bf3acef 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -42,12 +42,15 @@ import ( ) const ( + // Request types startFaultRequestType = "start %s" stopFaultRequestType = "stop %s" checkStatusFaultRequestType = "check status %s" - invalidNetworkModeError = "%s mode is not supported. Please use either host or awsvpc mode." - faultInjectionEnabledError = "enableFaultInjection is not enabled for task: %s" - requestTimedOutError = "%s: request timed out" + // Error messages + internalError = "internal error" + invalidNetworkModeError = "%s mode is not supported. Please use either host or awsvpc mode." + faultInjectionEnabledError = "enableFaultInjection is not enabled for task: %s" + requestTimedOutError = "%s: request timed out" // This is our initial assumption of how much time it would take for the Linux commands used to inject faults // to finish. This will be confirmed/updated after more testing. requestTimeoutDuration = 5 * time.Second @@ -147,14 +150,14 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht insertTable = "OUTPUT" } - cmdOutput, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, + _, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, networkMode, networkNSPath, insertTable, taskArn) if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) { statusCode = http.StatusInternalServerError responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) } else if cmdErr != nil { statusCode = http.StatusInternalServerError - responseBody = types.NewNetworkFaultInjectionErrorResponse(cmdOutput) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) } else { statusCode = http.StatusOK responseBody = types.NewNetworkFaultInjectionSuccessResponse("running") @@ -293,7 +296,7 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt insertTable = "OUTPUT" } - cmdOutput, cmdErr := h.stopNetworkBlackHolePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, + _, cmdErr := h.stopNetworkBlackHolePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, networkMode, networkNSPath, insertTable, taskArn) if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) { @@ -301,7 +304,7 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) } else if cmdErr != nil { statusCode = http.StatusInternalServerError - responseBody = types.NewNetworkFaultInjectionErrorResponse(cmdOutput) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) } else { statusCode = http.StatusOK responseBody = types.NewNetworkFaultInjectionSuccessResponse("stopped") @@ -436,7 +439,7 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht stringToBeLogged := "Failed to check fault" port := strconv.FormatUint(uint64(aws.Uint16Value(request.Port)), 10) chainName := fmt.Sprintf("%s-%s-%s", aws.StringValue(request.TrafficType), aws.StringValue(request.Protocol), port) - running, cmdOutput, cmdErr := h.checkNetworkBlackHolePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, + running, _, cmdErr := h.checkNetworkBlackHolePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName, networkMode, networkNSPath, taskArn) // We've timed out trying to check if the black hole port fault injection is running @@ -445,7 +448,7 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) } else if cmdErr != nil { statusCode = http.StatusInternalServerError - responseBody = types.NewNetworkFaultInjectionErrorResponse(cmdOutput) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) } else { statusCode = http.StatusOK if running { @@ -553,7 +556,7 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { // If there already exists a fault in the task network namespace. @@ -570,7 +573,7 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { stringToBeLogged = "Successfully started fault" @@ -625,7 +628,7 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { // If there doesn't already exist a network-latency fault @@ -640,7 +643,7 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { stringToBeLogged = "Successfully stopped fault" @@ -696,7 +699,7 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { stringToBeLogged = "Successfully checked fault status" @@ -766,7 +769,7 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { // If there already exists a fault in the task network namespace. @@ -783,7 +786,7 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { stringToBeLogged = "Successfully started fault" @@ -838,7 +841,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { // If there doesn't already exist a network-packet-loss fault @@ -853,7 +856,7 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { stringToBeLogged = "Successfully stopped fault" @@ -909,7 +912,7 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. responseBody = types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf(requestTimedOutError, requestType)) httpStatusCode = http.StatusInternalServerError } else if err != nil { - responseBody = types.NewNetworkFaultInjectionErrorResponse(err.Error()) + responseBody = types.NewNetworkFaultInjectionErrorResponse(internalError) httpStatusCode = http.StatusInternalServerError } else { stringToBeLogged = "Successfully checked fault status" diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index 590af933acd..f350cce4090 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -54,7 +54,6 @@ const ( awsvpcNetworkMode = "awsvpc" deviceName = "eth0" invalidNetworkMode = "invalid" - internalError = "internal error" iptablesChainAlreadyExistError = "iptables: Chain already exists." iptablesChainNotFoundError = "iptables: Bad rule (does a matching rule exist in that chain?)." tcLatencyFaultExistsCommandOutput = `[{"kind":"netem","handle":"10:","parent":"1:1","options":{"limit":1000,"delay":{"delay":123456789,"jitter":4567,"correlation":0},"ecn":false,"gap":0}}]` @@ -581,7 +580,7 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase name: fmt.Sprintf("%s fail duplicate chain", startNetworkBlackHolePortTestPrefix), expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, - expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(iptablesChainAlreadyExistError), + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) { agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient). Return(happyTaskResponse, nil). @@ -1009,7 +1008,7 @@ func generateCommonNetworkLatencyTestCases(name string) []networkFaultInjectionT "Sources": ipSources, "Unknown": "", }, - expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("failed to check existing network fault: failed to unmarshal tc command output: unexpected end of JSON input. TaskArn: taskArn"), + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) { agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).Return(happyTaskResponse, nil) }, @@ -1564,7 +1563,7 @@ func generateCommonNetworkPacketLossTestCases(name string) []networkFaultInjecti "Unknown": "", "SourcesToFilter": []string{}, }, - expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("failed to check existing network fault: failed to unmarshal tc command output: unexpected end of JSON input. TaskArn: taskArn"), + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) { agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).Return(happyTaskResponse, nil) }, From dcd2dd5e32a492da6fe50aec341ac79c23cbdc20 Mon Sep 17 00:00:00 2001 From: mye956 Date: Mon, 30 Sep 2024 17:35:17 +0000 Subject: [PATCH 2/2] Incorporating telemetry middleware into fault handlers --- agent/handlers/task_server_setup.go | 90 +++++++++++++++++-- .../handlers/task_server_setup_integ_test.go | 4 +- agent/handlers/task_server_setup_test.go | 28 +++--- 3 files changed, 101 insertions(+), 21 deletions(-) diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go index 9633ac746af..47718bd5c00 100644 --- a/agent/handlers/task_server_setup.go +++ b/agent/handlers/task_server_setup.go @@ -211,43 +211,115 @@ func registerFaultHandlers( // Setting up handler endpoints for network blackhole port fault injections muxRouter.Handle( fault.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StartNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkBlackholePort()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.StartNetworkBlackholePort(), + ), + metricsFactory, + faulttype.StartNetworkFaultPostfix, + faulttype.BlackHolePortFaultType, + ), ).Methods("POST") muxRouter.Handle( fault.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StopNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkBlackHolePort()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.StopNetworkBlackHolePort(), + ), + metricsFactory, + faulttype.StopNetworkFaultPostfix, + faulttype.BlackHolePortFaultType, + ), ).Methods("POST") muxRouter.Handle( fault.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.CheckNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkBlackHolePort()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.CheckNetworkBlackHolePort(), + ), + metricsFactory, + faulttype.CheckNetworkFaultPostfix, + faulttype.BlackHolePortFaultType, + ), ).Methods("POST") // Setting up handler endpoints for network latency fault injections muxRouter.Handle( fault.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StartNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkLatency()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.StartNetworkLatency(), + ), + metricsFactory, + faulttype.StartNetworkFaultPostfix, + faulttype.LatencyFaultType, + ), ).Methods("POST") muxRouter.Handle( fault.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StopNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkLatency()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.StopNetworkLatency(), + ), + metricsFactory, + faulttype.StopNetworkFaultPostfix, + faulttype.LatencyFaultType, + ), ).Methods("POST") muxRouter.Handle( fault.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.CheckNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkLatency()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.CheckNetworkLatency(), + ), + metricsFactory, + faulttype.CheckNetworkFaultPostfix, + faulttype.LatencyFaultType, + ), ).Methods("POST") // Setting up handler endpoints for network packet loss fault injections muxRouter.Handle( fault.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StartNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.StartNetworkPacketLoss()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.StartNetworkPacketLoss(), + ), + metricsFactory, + faulttype.StartNetworkFaultPostfix, + faulttype.PacketLossFaultType, + ), ).Methods("POST") muxRouter.Handle( fault.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StopNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.StopNetworkPacketLoss()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.StopNetworkPacketLoss(), + ), + metricsFactory, + faulttype.StopNetworkFaultPostfix, + faulttype.PacketLossFaultType, + ), ).Methods("POST") muxRouter.Handle( fault.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.CheckNetworkFaultPostfix), - tollbooth.LimitFuncHandler(createRateLimiter(), handler.CheckNetworkPacketLoss()), + fault.TelemetryMiddleware( + tollbooth.LimitFuncHandler( + createRateLimiter(), + handler.CheckNetworkPacketLoss(), + ), + metricsFactory, + faulttype.CheckNetworkFaultPostfix, + faulttype.PacketLossFaultType, + ), ).Methods("POST") seelog.Debug("Successfully set up Fault TMDS handlers") diff --git a/agent/handlers/task_server_setup_integ_test.go b/agent/handlers/task_server_setup_integ_test.go index 63e057d3174..109e1f08d22 100644 --- a/agent/handlers/task_server_setup_integ_test.go +++ b/agent/handlers/task_server_setup_integ_test.go @@ -28,7 +28,7 @@ import ( agentV4 "github.com/aws/amazon-ecs-agent/agent/handlers/v4" mock_stats "github.com/aws/amazon-ecs-agent/agent/stats/mock" mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" - mock_metrics "github.com/aws/amazon-ecs-agent/ecs-agent/metrics/mocks" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks" "github.com/golang/mock/gomock" "github.com/gorilla/mux" @@ -56,7 +56,7 @@ func startServer(t *testing.T) (*http.Server, int) { ecsClient := mock_ecs.NewMockECSClient(ctrl) agentState := agentV4.NewTMDSAgentState(state, statsEngine, ecsClient, clusterName, availabilityzone, vpcID, containerInstanceArn) - metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) + metricsFactory := metrics.NewNopEntryFactory() execWrapper := mock_execwrapper.NewMockExec(ctrl) registerFaultHandlers(router, agentState, metricsFactory, execWrapper) diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index 4ae6597b922..c328cd077a9 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -132,6 +132,7 @@ const ( tcLatencyFaultExistsCommandOutput = `[{"kind":"netem","handle":"10:","parent":"1:1","options":{"limit":1000,"delay":{"delay":123456789,"jitter":4567,"correlation":0},"ecn":false,"gap":0}}]` tcCommandEmptyOutput = `[]` requestTimeoutDuration = 5 * time.Second + durationMetricPrefix = "MetadataServer.%s%sDuration" ) var ( @@ -3808,7 +3809,7 @@ func TestRegisterStartBlackholePortFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StartNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StartNetworkFaultPostfix), faulttype.StartNetworkFaultPostfix, faulttype.BlackHolePortFaultType) } func TestRegisterStopBlackholePortFaultHandler(t *testing.T) { @@ -3828,7 +3829,7 @@ func TestRegisterStopBlackholePortFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("stop blackhole port", "stopped", setExecExpectations, happyBlackHolePortReqBody) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StopNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.StopNetworkFaultPostfix), faulttype.StopNetworkFaultPostfix, faulttype.BlackHolePortFaultType) } func TestRegisterCheckBlackholePortFaultHandler(t *testing.T) { @@ -3842,7 +3843,7 @@ func TestRegisterCheckBlackholePortFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("check blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.CheckNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.BlackHolePortFaultType, faulttype.CheckNetworkFaultPostfix), faulttype.CheckNetworkFaultPostfix, faulttype.BlackHolePortFaultType) } func TestRegisterStartLatencyFaultHandler(t *testing.T) { @@ -3858,7 +3859,7 @@ func TestRegisterStartLatencyFaultHandler(t *testing.T) { mockCMD.EXPECT().CombinedOutput().Times(5).Return([]byte(tcCommandEmptyOutput), nil) } tcs := generateCommonNetworkFaultInjectionTestCases("start latency", "running", setExecExpectations, happyNetworkLatencyReqBody) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StartNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StartNetworkFaultPostfix), faulttype.StartNetworkFaultPostfix, faulttype.LatencyFaultType) } func TestRegisterStopLatencyFaultHandler(t *testing.T) { @@ -3872,7 +3873,7 @@ func TestRegisterStopLatencyFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("stop latency", "stopped", setExecExpectations, nil) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StopNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StopNetworkFaultPostfix), faulttype.StopNetworkFaultPostfix, faulttype.LatencyFaultType) } func TestRegisterCheckLatencyFaultHandler(t *testing.T) { @@ -3886,7 +3887,7 @@ func TestRegisterCheckLatencyFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("check latency", "running", setExecExpectations, nil) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.CheckNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.CheckNetworkFaultPostfix), faulttype.CheckNetworkFaultPostfix, faulttype.LatencyFaultType) } func TestRegisterStartPacketLossFaultHandler(t *testing.T) { @@ -3902,7 +3903,7 @@ func TestRegisterStartPacketLossFaultHandler(t *testing.T) { mockCMD.EXPECT().CombinedOutput().Times(5).Return([]byte(tcCommandEmptyOutput), nil) } tcs := generateCommonNetworkFaultInjectionTestCases("start packet loss", "running", setExecExpectations, happyNetworkPacketLossReqBody) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StartNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StartNetworkFaultPostfix), faulttype.StartNetworkFaultPostfix, faulttype.PacketLossFaultType) } func TestRegisterStopPacketLossFaultHandler(t *testing.T) { @@ -3916,7 +3917,7 @@ func TestRegisterStopPacketLossFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("stop packet loss", "stopped", setExecExpectations, nil) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StopNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StopNetworkFaultPostfix), faulttype.StopNetworkFaultPostfix, faulttype.PacketLossFaultType) } func TestRegisterCheckPacketLossFaultHandler(t *testing.T) { @@ -3930,10 +3931,10 @@ func TestRegisterCheckPacketLossFaultHandler(t *testing.T) { ) } tcs := generateCommonNetworkFaultInjectionTestCases("check packet loss", "running", setExecExpectations, nil) - testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.CheckNetworkFaultPostfix)) + testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.CheckNetworkFaultPostfix), faulttype.CheckNetworkFaultPostfix, faulttype.PacketLossFaultType) } -func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, tmdsEndpoint string) { +func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, tmdsEndpoint, faultOperation, faultType string) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { // Mocks @@ -3946,6 +3947,13 @@ func testRegisterFaultHandler(t *testing.T, tcs []networkFaultTestCase, tmdsEndp agentState := agentV4.NewTMDSAgentState(state, statsEngine, ecsClient, clusterName, availabilityzone, vpcID, containerInstanceArn) metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) + durationMetricEntry := mock_metrics.NewMockEntry(ctrl) + gomock.InOrder( + metricsFactory.EXPECT().New(fmt.Sprintf(durationMetricPrefix, faultOperation, faultType)).Return(durationMetricEntry).Times(1), + durationMetricEntry.EXPECT().WithFields(gomock.Any()).Return(durationMetricEntry).Times(1), + durationMetricEntry.EXPECT().WithGauge(gomock.Any()).Return(durationMetricEntry).Times(1), + durationMetricEntry.EXPECT().Done(nil).Times(1), + ) execWrapper := mock_execwrapper.NewMockExec(ctrl) if tc.setStateExpectations != nil {