From c412aa3168164340e39f390c9484a4cedc4025f1 Mon Sep 17 00:00:00 2001 From: Steve Zhang Date: Fri, 9 Aug 2024 00:18:03 +0800 Subject: [PATCH] Output streaming support for the whole pipeline in GMC router (#278) * output streaming support for the whole pipeline in GMC. Signed-off-by: zhlsunshine * make the output streaming in line. Signed-off-by: zhlsunshine * change. Signed-off-by: zhlsunshine * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- microservices-connector/cmd/router/main.go | 189 ++++++++++++------ .../cmd/router/main_test.go | 28 ++- 2 files changed, 153 insertions(+), 64 deletions(-) diff --git a/microservices-connector/cmd/router/main.go b/microservices-connector/cmd/router/main.go index b5508007..391a0c2c 100644 --- a/microservices-connector/cmd/router/main.go +++ b/microservices-connector/cmd/router/main.go @@ -11,7 +11,7 @@ package main import ( - "bufio" + // "bufio" "bytes" "context" "encoding/json" @@ -43,14 +43,19 @@ var ( log = logf.Log.WithName("GMCGraphRouter") mcGraph *mcv1alpha3.GMConnector defaultNodeName = "root" + Prefix = []byte("data: b'") + Suffix = []byte("'\n\n") + DONE = []byte("[DONE]") + Newline = []byte("\n") ) const ( - ChunkSize = 1024 + BufferSize = 1024 ServiceURL = "serviceUrl" ServiceNode = "node" DataPrep = "DataPrep" Parameters = "parameters" + Llm = "Llm" ) type EnsembleStepOutput struct { @@ -63,6 +68,19 @@ type GMCGraphRoutingError struct { Cause string `json:"cause"` } +type ReadCloser struct { + *bytes.Reader +} + +func (ReadCloser) Close() error { + // Typically, you would release resources here, but for bytes.Reader, there's nothing to do. + return nil +} + +func NewReadCloser(b []byte) io.ReadCloser { + return ReadCloser{bytes.NewReader(b)} +} + func (e *GMCGraphRoutingError) Error() string { return fmt.Sprintf("%s. %s", e.ErrorMessage, e.Cause) } @@ -127,7 +145,12 @@ func prepareErrorResponse(err error, errorMessage string) []byte { return errorResponseBytes } -func callService(step *mcv1alpha3.Step, serviceUrl string, input []byte, headers http.Header) ([]byte, int, error) { +func callService( + step *mcv1alpha3.Step, + serviceUrl string, + input []byte, + headers http.Header, +) (io.ReadCloser, int, error) { defer timeTrack(time.Now(), "step", serviceUrl) log.Info("Entering callService", "url", serviceUrl) @@ -157,21 +180,7 @@ func callService(step *mcv1alpha3.Step, serviceUrl string, input []byte, headers return nil, 500, err } - defer func() { - if resp.Body != nil { - err := resp.Body.Close() - if err != nil { - log.Error(err, "An error has occurred while closing the response body") - } - } - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Error(err, "Error while reading the response") - } - - return body, resp.StatusCode, err + return resp.Body, resp.StatusCode, nil } // Use step service name to create a K8s service if serviceURL is empty @@ -190,7 +199,7 @@ func executeStep( initInput []byte, input []byte, headers http.Header, -) ([]byte, int, error) { +) (io.ReadCloser, int, error) { if step.NodeName != "" { // when nodeName is specified make a recursive call for routing to next step return routeStep(step.NodeName, graph, initInput, input, headers) @@ -231,16 +240,16 @@ func handleSwitchNode( initInput []byte, request []byte, headers http.Header, -) ([]byte, int, error) { +) (io.ReadCloser, int, error) { var statusCode int - var responseBytes []byte + var responseBody io.ReadCloser var err error stepType := ServiceURL if route.NodeName != "" { stepType = ServiceNode } log.Info("Starting execution of step", "Node Name", route.NodeName, "type", stepType, "stepName", route.StepName) - if responseBytes, statusCode, err = executeStep(route, graph, initInput, request, headers); err != nil { + if responseBody, statusCode, err = executeStep(route, graph, initInput, request, headers); err != nil { return nil, 500, err } @@ -253,7 +262,7 @@ func handleSwitchNode( statusCode, ) } - return responseBytes, statusCode, nil + return responseBody, statusCode, nil } func handleSwitchPipeline(nodeName string, @@ -261,10 +270,11 @@ func handleSwitchPipeline(nodeName string, initInput []byte, input []byte, headers http.Header, -) ([]byte, int, error) { +) (io.ReadCloser, int, error) { currentNode := graph.Spec.Nodes[nodeName] var statusCode int var responseBytes []byte + var responseBody io.ReadCloser var err error initReqData := make(map[string]interface{}) @@ -286,27 +296,40 @@ func handleSwitchPipeline(nodeName string, } log.Info("Current Step Information", "Node Name", nodeName, "Step Index", index) request := input + if responseBody != nil { + responseBytes, err = io.ReadAll(responseBody) + if err != nil { + log.Error(err, "Error while reading the response body") + return nil, 500, err + } + log.Info("Print Previous Response Bytes", "Previous Response Bytes", + responseBytes, "Previous Status Code", statusCode) + err = responseBody.Close() + if err != nil { + log.Error(err, "Error while trying to close the responseBody in handleSwitchPipeline") + } + } + log.Info("Print Original Request Bytes", "Request Bytes", request) if route.Data == "$response" && index > 0 { request = mergeRequests(responseBytes, initReqData) } log.Info("Print New Request Bytes", "Request Bytes", request) if route.Condition == "" { - responseBytes, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers) + responseBody, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers) if err != nil { - return responseBytes, statusCode, err + return nil, statusCode, err } } else { if pickupRouteByCondition(initInput, route.Condition) { - responseBytes, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers) + responseBody, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers) if err != nil { - return responseBytes, statusCode, err + return nil, statusCode, err } } } - log.Info("Print Response Bytes", "Response Bytes", responseBytes, "Status Code", statusCode) } - return responseBytes, statusCode, err + return responseBody, statusCode, err } func handleEnsemblePipeline(nodeName string, @@ -314,7 +337,7 @@ func handleEnsemblePipeline(nodeName string, initInput []byte, input []byte, headers http.Header, -) ([]byte, int, error) { +) (io.ReadCloser, int, error) { currentNode := graph.Spec.Nodes[nodeName] ensembleRes := make([]chan EnsembleStepOutput, len(currentNode.Steps)) errChan := make(chan error) @@ -328,8 +351,12 @@ func handleEnsemblePipeline(nodeName string, resultChan := make(chan EnsembleStepOutput) ensembleRes[i] = resultChan go func() { - output, statusCode, err := executeStep(step, graph, initInput, input, headers) + responseBody, statusCode, err := executeStep(step, graph, initInput, input, headers) if err == nil { + output, rerr := io.ReadAll(responseBody) + if rerr != nil { + log.Error(rerr, "Error while reading the response body") + } var res map[string]interface{} if err = json.Unmarshal(output, &res); err == nil { resultChan <- EnsembleStepOutput{ @@ -339,6 +366,10 @@ func handleEnsemblePipeline(nodeName string, return } } + rerr := responseBody.Close() + if rerr != nil { + log.Error(rerr, "Error while trying to close the responseBody in handleEnsemblePipeline") + } errChan <- err }() } @@ -361,7 +392,8 @@ func handleEnsemblePipeline(nodeName string, ensembleStepOutput.StepStatusCode, ) stepResponse, _ := json.Marshal(ensembleStepOutput.StepResponse) - return stepResponse, ensembleStepOutput.StepStatusCode, nil + stepIOReader := NewReadCloser(stepResponse) + return stepIOReader, ensembleStepOutput.StepStatusCode, nil } else { response[key] = ensembleStepOutput.StepResponse } @@ -371,7 +403,8 @@ func handleEnsemblePipeline(nodeName string, } // return json.Marshal(response) combinedResponse, _ := json.Marshal(response) // TODO check if you need err handling for Marshalling - return combinedResponse, 200, nil + combinedIOReader := NewReadCloser(combinedResponse) + return combinedIOReader, 200, nil } func handleSequencePipeline(nodeName string, @@ -379,9 +412,10 @@ func handleSequencePipeline(nodeName string, initInput []byte, input []byte, headers http.Header, -) ([]byte, int, error) { +) (io.ReadCloser, int, error) { currentNode := graph.Spec.Nodes[nodeName] var statusCode int + var responseBody io.ReadCloser var responseBytes []byte var err error @@ -409,6 +443,20 @@ func handleSequencePipeline(nodeName string, log.Info("Starting execution of step", "type", stepType, "stepName", step.StepName) request := input log.Info("Print Original Request Bytes", "Request Bytes", request) + if responseBody != nil { + responseBytes, err = io.ReadAll(responseBody) + if err != nil { + log.Error(err, "Error while reading the response body") + return nil, 500, err + } + log.Info("Print Previous Response Bytes", "Previous Response Bytes", + responseBytes, "Previous Status Code", statusCode) + err := responseBody.Close() + if err != nil { + log.Error(err, "Error while trying to close the responseBody in handleSequencePipeline") + } + } + if step.Data == "$response" && i > 0 { request = mergeRequests(responseBytes, initReqData) } @@ -419,13 +467,12 @@ func handleSequencePipeline(nodeName string, } // if the condition does not match for the step in the sequence we stop and return the response if !gjson.GetBytes(responseBytes, step.Condition).Exists() { - return responseBytes, 500, nil + return responseBody, 500, nil } } - if responseBytes, statusCode, err = executeStep(step, graph, initInput, request, headers); err != nil { + if responseBody, statusCode, err = executeStep(step, graph, initInput, request, headers); err != nil { return nil, 500, err } - log.Info("Print Response Bytes", "Response Bytes", responseBytes, "Status Code", statusCode) /* Only if a step is a hard dependency, we will check for its success. */ @@ -439,18 +486,18 @@ func handleSequencePipeline(nodeName string, statusCode, ) // Stop the execution of sequence right away if step is a hard dependency and is unsuccessful - return responseBytes, statusCode, nil + return responseBody, statusCode, nil } } } - return responseBytes, statusCode, nil + return responseBody, statusCode, nil } func routeStep(nodeName string, graph mcv1alpha3.GMConnector, initInput, input []byte, headers http.Header, -) ([]byte, int, error) { +) (io.ReadCloser, int, error) { defer timeTrack(time.Now(), "node", nodeName) currentNode := graph.Spec.Nodes[nodeName] log.Info("Current Node", "Node Name", nodeName) @@ -478,9 +525,14 @@ func mcGraphHandler(w http.ResponseWriter, req *http.Request) { go func() { defer close(done) - inputBytes, _ := io.ReadAll(req.Body) - response, statusCode, err := routeStep(defaultNodeName, *mcGraph, inputBytes, inputBytes, req.Header) + inputBytes, err := io.ReadAll(req.Body) + if err != nil { + log.Error(err, "failed to read request body") + http.Error(w, "failed to read request body", http.StatusBadRequest) + return + } + responseBody, statusCode, err := routeStep(defaultNodeName, *mcGraph, inputBytes, inputBytes, req.Header) if err != nil { log.Error(err, "failed to process request") w.Header().Set("Content-Type", "application/json") @@ -490,37 +542,54 @@ func mcGraphHandler(w http.ResponseWriter, req *http.Request) { } return } - if json.Valid(response) { - w.Header().Set("Content-Type", "application/json") - } - w.WriteHeader(statusCode) - - writer := bufio.NewWriter(w) defer func() { - if err := writer.Flush(); err != nil { - log.Error(err, "error flushing writer when processing response") + err := responseBody.Close() + if err != nil { + log.Error(err, "Error while trying to close the responseBody in mcGraphHandler") } }() - for start := 0; start < len(response); start += ChunkSize { - end := start + ChunkSize - if end > len(response) { - end = len(response) + w.Header().Set("Content-Type", "application/json") + buffer := make([]byte, BufferSize) + for { + n, err := responseBody.Read(buffer) + if err != nil && err != io.EOF { + log.Error(err, "failed to read from response body") + http.Error(w, "failed to read from response body", http.StatusInternalServerError) + return } - if _, err := writer.Write(response[start:end]); err != nil { - log.Error(err, "failed to write mcGraphHandler response") + if n == 0 { + break + } + + sliceBF := buffer[:n] + if !bytes.HasPrefix(sliceBF, DONE) { + sliceBF = bytes.TrimPrefix(sliceBF, Prefix) + sliceBF = bytes.TrimSuffix(sliceBF, Suffix) + } else { + sliceBF = bytes.Join([][]byte{Newline, sliceBF}, nil) + } + + log.Info("[llm - chat_stream] chunk:", "Buffer", string(sliceBF)) + // Write the chunk to the ResponseWriter + if _, err := w.Write(sliceBF); err != nil { + log.Error(err, "failed to write to ResponseWriter") return } - if err := writer.Flush(); err != nil { - log.Error(err, "error flushing writer when processing response") + // Flush the data to the client immediately + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } else { + log.Error(errors.New("unable to flush data"), "ResponseWriter does not support flushing") + return } } }() select { case <-ctx.Done(): - log.Error(errors.New("failed to process request"), "request timed out") + log.Error(errors.New("request timed out"), "failed to process request") http.Error(w, "request timed out", http.StatusGatewayTimeout) case <-done: log.Info("mcGraphHandler is done") diff --git a/microservices-connector/cmd/router/main_test.go b/microservices-connector/cmd/router/main_test.go index 25afa912..6f2d7e54 100644 --- a/microservices-connector/cmd/router/main_test.go +++ b/microservices-connector/cmd/router/main_test.go @@ -116,7 +116,12 @@ func TestSimpleModelChainer(t *testing.T) { return } var response map[string]interface{} - err = json.Unmarshal(res, &response) + responseBytes, rerr := io.ReadAll(res) + if rerr != nil { + t.Fatalf("Error while reading the response body: %v", rerr) + return + } + err = json.Unmarshal(responseBytes, &response) if err != nil { return } @@ -217,7 +222,12 @@ func TestSimpleServiceEnsemble(t *testing.T) { return } var response map[string]interface{} - err = json.Unmarshal(res, &response) + responseBytes, rerr := io.ReadAll(res) + if rerr != nil { + t.Fatalf("Error while reading the response body") + return + } + err = json.Unmarshal(responseBytes, &response) if err != nil { return } @@ -452,7 +462,12 @@ func TestMCWithCondition(t *testing.T) { return } var response map[string]interface{} - err = json.Unmarshal(res, &response) + responseBytes, rerr := io.ReadAll(res) + if rerr != nil { + t.Fatalf("Error while reading the response body") + return + } + err = json.Unmarshal(responseBytes, &response) if err != nil { return } @@ -536,7 +551,12 @@ func TestCallServiceWhenNoneHeadersToPropagateIsEmpty(t *testing.T) { return } var response map[string]interface{} - err = json.Unmarshal(res, &response) + responseBytes, rerr := io.ReadAll(res) + if rerr != nil { + t.Fatalf("Error while reading the response body") + return + } + err = json.Unmarshal(responseBytes, &response) if err != nil { return }