Skip to content

Commit

Permalink
Output streaming support for the whole pipeline in GMC router (#278)
Browse files Browse the repository at this point in the history
* output streaming support for the whole pipeline in GMC.
Signed-off-by: zhlsunshine <[email protected]>

* make the output streaming in line.
Signed-off-by: zhlsunshine <[email protected]>

* change.
Signed-off-by: zhlsunshine <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
zhlsunshine and pre-commit-ci[bot] authored Aug 8, 2024
1 parent 5735dd3 commit c412aa3
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 64 deletions.
189 changes: 129 additions & 60 deletions microservices-connector/cmd/router/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
package main

import (
"bufio"
// "bufio"
"bytes"
"context"
"encoding/json"
Expand Down Expand Up @@ -43,14 +43,19 @@ var (
log = logf.Log.WithName("GMCGraphRouter")
mcGraph *mcv1alpha3.GMConnector
defaultNodeName = "root"
Prefix = []byte("data: b'")
Suffix = []byte("'\n\n")
DONE = []byte("[DONE]")
Newline = []byte("\n")
)

// Sizes used while relaying a response to the client, and string labels used
// by the router.
const (
// ChunkSize is the number of response bytes written per iteration when a
// fully buffered response is sent to the client in slices.
ChunkSize = 1024
// BufferSize is the length of the scratch buffer used to read a response
// body piece by piece.
BufferSize = 1024
// ServiceURL and ServiceNode are the step "type" labels that get logged for
// a step: ServiceNode when the step names a node, ServiceURL otherwise.
ServiceURL = "serviceUrl"
ServiceNode = "node"
// NOTE(review): the uses of DataPrep, Parameters and Llm are not visible in
// this excerpt; presumably they are step or request-field names. Confirm
// against the callers.
DataPrep = "DataPrep"
Parameters = "parameters"
Llm = "Llm"
)

type EnsembleStepOutput struct {
Expand All @@ -63,6 +68,19 @@ type GMCGraphRoutingError struct {
Cause string `json:"cause"`
}

// ReadCloser adapts an in-memory *bytes.Reader to the io.ReadCloser interface
// by pairing the embedded reader with a Close method that does nothing.
type ReadCloser struct {
	*bytes.Reader
}

// Close implements io.Closer. It always returns nil because the embedded
// bytes.Reader holds nothing that has to be released.
func (ReadCloser) Close() error { return nil }

// NewReadCloser returns an io.ReadCloser that reads from b.
func NewReadCloser(b []byte) io.ReadCloser {
	reader := bytes.NewReader(b)
	return ReadCloser{Reader: reader}
}

// Error implements the error interface. The result is the error message,
// then ". ", then the cause.
func (e *GMCGraphRoutingError) Error() string {
	const layout = "%s. %s"
	return fmt.Sprintf(layout, e.ErrorMessage, e.Cause)
}
Expand Down Expand Up @@ -127,7 +145,12 @@ func prepareErrorResponse(err error, errorMessage string) []byte {
return errorResponseBytes
}

func callService(step *mcv1alpha3.Step, serviceUrl string, input []byte, headers http.Header) ([]byte, int, error) {
func callService(
step *mcv1alpha3.Step,
serviceUrl string,
input []byte,
headers http.Header,
) (io.ReadCloser, int, error) {
defer timeTrack(time.Now(), "step", serviceUrl)
log.Info("Entering callService", "url", serviceUrl)

Expand Down Expand Up @@ -157,21 +180,7 @@ func callService(step *mcv1alpha3.Step, serviceUrl string, input []byte, headers
return nil, 500, err
}

defer func() {
if resp.Body != nil {
err := resp.Body.Close()
if err != nil {
log.Error(err, "An error has occurred while closing the response body")
}
}
}()

body, err := io.ReadAll(resp.Body)
if err != nil {
log.Error(err, "Error while reading the response")
}

return body, resp.StatusCode, err
return resp.Body, resp.StatusCode, nil
}

// Use step service name to create a K8s service if serviceURL is empty
Expand All @@ -190,7 +199,7 @@ func executeStep(
initInput []byte,
input []byte,
headers http.Header,
) ([]byte, int, error) {
) (io.ReadCloser, int, error) {
if step.NodeName != "" {
// when nodeName is specified make a recursive call for routing to next step
return routeStep(step.NodeName, graph, initInput, input, headers)
Expand Down Expand Up @@ -231,16 +240,16 @@ func handleSwitchNode(
initInput []byte,
request []byte,
headers http.Header,
) ([]byte, int, error) {
) (io.ReadCloser, int, error) {
var statusCode int
var responseBytes []byte
var responseBody io.ReadCloser
var err error
stepType := ServiceURL
if route.NodeName != "" {
stepType = ServiceNode
}
log.Info("Starting execution of step", "Node Name", route.NodeName, "type", stepType, "stepName", route.StepName)
if responseBytes, statusCode, err = executeStep(route, graph, initInput, request, headers); err != nil {
if responseBody, statusCode, err = executeStep(route, graph, initInput, request, headers); err != nil {
return nil, 500, err
}

Expand All @@ -253,18 +262,19 @@ func handleSwitchNode(
statusCode,
)
}
return responseBytes, statusCode, nil
return responseBody, statusCode, nil
}

func handleSwitchPipeline(nodeName string,
graph mcv1alpha3.GMConnector,
initInput []byte,
input []byte,
headers http.Header,
) ([]byte, int, error) {
) (io.ReadCloser, int, error) {
currentNode := graph.Spec.Nodes[nodeName]
var statusCode int
var responseBytes []byte
var responseBody io.ReadCloser
var err error

initReqData := make(map[string]interface{})
Expand All @@ -286,35 +296,48 @@ func handleSwitchPipeline(nodeName string,
}
log.Info("Current Step Information", "Node Name", nodeName, "Step Index", index)
request := input
if responseBody != nil {
responseBytes, err = io.ReadAll(responseBody)
if err != nil {
log.Error(err, "Error while reading the response body")
return nil, 500, err
}
log.Info("Print Previous Response Bytes", "Previous Response Bytes",
responseBytes, "Previous Status Code", statusCode)
err = responseBody.Close()
if err != nil {
log.Error(err, "Error while trying to close the responseBody in handleSwitchPipeline")
}
}

log.Info("Print Original Request Bytes", "Request Bytes", request)
if route.Data == "$response" && index > 0 {
request = mergeRequests(responseBytes, initReqData)
}
log.Info("Print New Request Bytes", "Request Bytes", request)
if route.Condition == "" {
responseBytes, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers)
responseBody, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers)
if err != nil {
return responseBytes, statusCode, err
return nil, statusCode, err
}
} else {
if pickupRouteByCondition(initInput, route.Condition) {
responseBytes, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers)
responseBody, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers)
if err != nil {
return responseBytes, statusCode, err
return nil, statusCode, err
}
}
}
log.Info("Print Response Bytes", "Response Bytes", responseBytes, "Status Code", statusCode)
}
return responseBytes, statusCode, err
return responseBody, statusCode, err
}

func handleEnsemblePipeline(nodeName string,
graph mcv1alpha3.GMConnector,
initInput []byte,
input []byte,
headers http.Header,
) ([]byte, int, error) {
) (io.ReadCloser, int, error) {
currentNode := graph.Spec.Nodes[nodeName]
ensembleRes := make([]chan EnsembleStepOutput, len(currentNode.Steps))
errChan := make(chan error)
Expand All @@ -328,8 +351,12 @@ func handleEnsemblePipeline(nodeName string,
resultChan := make(chan EnsembleStepOutput)
ensembleRes[i] = resultChan
go func() {
output, statusCode, err := executeStep(step, graph, initInput, input, headers)
responseBody, statusCode, err := executeStep(step, graph, initInput, input, headers)
if err == nil {
output, rerr := io.ReadAll(responseBody)
if rerr != nil {
log.Error(rerr, "Error while reading the response body")
}
var res map[string]interface{}
if err = json.Unmarshal(output, &res); err == nil {
resultChan <- EnsembleStepOutput{
Expand All @@ -339,6 +366,10 @@ func handleEnsemblePipeline(nodeName string,
return
}
}
rerr := responseBody.Close()
if rerr != nil {
log.Error(rerr, "Error while trying to close the responseBody in handleEnsemblePipeline")
}
errChan <- err
}()
}
Expand All @@ -361,7 +392,8 @@ func handleEnsemblePipeline(nodeName string,
ensembleStepOutput.StepStatusCode,
)
stepResponse, _ := json.Marshal(ensembleStepOutput.StepResponse)
return stepResponse, ensembleStepOutput.StepStatusCode, nil
stepIOReader := NewReadCloser(stepResponse)
return stepIOReader, ensembleStepOutput.StepStatusCode, nil
} else {
response[key] = ensembleStepOutput.StepResponse
}
Expand All @@ -371,17 +403,19 @@ func handleEnsemblePipeline(nodeName string,
}
// return json.Marshal(response)
combinedResponse, _ := json.Marshal(response) // TODO check if you need err handling for Marshalling
return combinedResponse, 200, nil
combinedIOReader := NewReadCloser(combinedResponse)
return combinedIOReader, 200, nil
}

func handleSequencePipeline(nodeName string,
graph mcv1alpha3.GMConnector,
initInput []byte,
input []byte,
headers http.Header,
) ([]byte, int, error) {
) (io.ReadCloser, int, error) {
currentNode := graph.Spec.Nodes[nodeName]
var statusCode int
var responseBody io.ReadCloser
var responseBytes []byte
var err error

Expand Down Expand Up @@ -409,6 +443,20 @@ func handleSequencePipeline(nodeName string,
log.Info("Starting execution of step", "type", stepType, "stepName", step.StepName)
request := input
log.Info("Print Original Request Bytes", "Request Bytes", request)
if responseBody != nil {
responseBytes, err = io.ReadAll(responseBody)
if err != nil {
log.Error(err, "Error while reading the response body")
return nil, 500, err
}
log.Info("Print Previous Response Bytes", "Previous Response Bytes",
responseBytes, "Previous Status Code", statusCode)
err := responseBody.Close()
if err != nil {
log.Error(err, "Error while trying to close the responseBody in handleSequencePipeline")
}
}

if step.Data == "$response" && i > 0 {
request = mergeRequests(responseBytes, initReqData)
}
Expand All @@ -419,13 +467,12 @@ func handleSequencePipeline(nodeName string,
}
// if the condition does not match for the step in the sequence we stop and return the response
if !gjson.GetBytes(responseBytes, step.Condition).Exists() {
return responseBytes, 500, nil
return responseBody, 500, nil
}
}
if responseBytes, statusCode, err = executeStep(step, graph, initInput, request, headers); err != nil {
if responseBody, statusCode, err = executeStep(step, graph, initInput, request, headers); err != nil {
return nil, 500, err
}
log.Info("Print Response Bytes", "Response Bytes", responseBytes, "Status Code", statusCode)
/*
Only if a step is a hard dependency, we will check for its success.
*/
Expand All @@ -439,18 +486,18 @@ func handleSequencePipeline(nodeName string,
statusCode,
)
// Stop the execution of sequence right away if step is a hard dependency and is unsuccessful
return responseBytes, statusCode, nil
return responseBody, statusCode, nil
}
}
}
return responseBytes, statusCode, nil
return responseBody, statusCode, nil
}

func routeStep(nodeName string,
graph mcv1alpha3.GMConnector,
initInput, input []byte,
headers http.Header,
) ([]byte, int, error) {
) (io.ReadCloser, int, error) {
defer timeTrack(time.Now(), "node", nodeName)
currentNode := graph.Spec.Nodes[nodeName]
log.Info("Current Node", "Node Name", nodeName)
Expand Down Expand Up @@ -478,9 +525,14 @@ func mcGraphHandler(w http.ResponseWriter, req *http.Request) {
go func() {
defer close(done)

inputBytes, _ := io.ReadAll(req.Body)
response, statusCode, err := routeStep(defaultNodeName, *mcGraph, inputBytes, inputBytes, req.Header)
inputBytes, err := io.ReadAll(req.Body)
if err != nil {
log.Error(err, "failed to read request body")
http.Error(w, "failed to read request body", http.StatusBadRequest)
return
}

responseBody, statusCode, err := routeStep(defaultNodeName, *mcGraph, inputBytes, inputBytes, req.Header)
if err != nil {
log.Error(err, "failed to process request")
w.Header().Set("Content-Type", "application/json")
Expand All @@ -490,37 +542,54 @@ func mcGraphHandler(w http.ResponseWriter, req *http.Request) {
}
return
}
if json.Valid(response) {
w.Header().Set("Content-Type", "application/json")
}
w.WriteHeader(statusCode)

writer := bufio.NewWriter(w)
defer func() {
if err := writer.Flush(); err != nil {
log.Error(err, "error flushing writer when processing response")
err := responseBody.Close()
if err != nil {
log.Error(err, "Error while trying to close the responseBody in mcGraphHandler")
}
}()

for start := 0; start < len(response); start += ChunkSize {
end := start + ChunkSize
if end > len(response) {
end = len(response)
w.Header().Set("Content-Type", "application/json")
buffer := make([]byte, BufferSize)
for {
n, err := responseBody.Read(buffer)
if err != nil && err != io.EOF {
log.Error(err, "failed to read from response body")
http.Error(w, "failed to read from response body", http.StatusInternalServerError)
return
}
if _, err := writer.Write(response[start:end]); err != nil {
log.Error(err, "failed to write mcGraphHandler response")
if n == 0 {
break
}

sliceBF := buffer[:n]
if !bytes.HasPrefix(sliceBF, DONE) {
sliceBF = bytes.TrimPrefix(sliceBF, Prefix)
sliceBF = bytes.TrimSuffix(sliceBF, Suffix)
} else {
sliceBF = bytes.Join([][]byte{Newline, sliceBF}, nil)
}

log.Info("[llm - chat_stream] chunk:", "Buffer", string(sliceBF))
// Write the chunk to the ResponseWriter
if _, err := w.Write(sliceBF); err != nil {
log.Error(err, "failed to write to ResponseWriter")
return
}

if err := writer.Flush(); err != nil {
log.Error(err, "error flushing writer when processing response")
// Flush the data to the client immediately
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
} else {
log.Error(errors.New("unable to flush data"), "ResponseWriter does not support flushing")
return
}
}
}()

select {
case <-ctx.Done():
log.Error(errors.New("failed to process request"), "request timed out")
log.Error(errors.New("request timed out"), "failed to process request")
http.Error(w, "request timed out", http.StatusGatewayTimeout)
case <-done:
log.Info("mcGraphHandler is done")
Expand Down
Loading

0 comments on commit c412aa3

Please sign in to comment.