Fix router bugs on max_new_tokens and dataprep gaudi yaml file #273
Changes from commits 9c65ecd, aa4140d, d9838a0, 0ec4446, 98e05cb
@@ -50,6 +50,7 @@ const (
 	ServiceURL  = "serviceUrl"
 	ServiceNode = "node"
 	DataPrep    = "DataPrep"
+	Parameters  = "parameters"
 )

 type EnsembleStepOutput struct {
@@ -198,6 +199,32 @@ func executeStep(
 	return callService(step, serviceURL, input, headers)
 }

+func mergeRequests(respReq []byte, initReqData map[string]interface{}) []byte {
+	var respReqData map[string]interface{}
+
+	if _, exists := initReqData[Parameters]; exists {
+		if err := json.Unmarshal(respReq, &respReqData); err != nil {
+			log.Error(err, "Error unmarshaling respReqData:")
+			return nil
+		}
+		// Merge init request into respReq
+		for key, value := range initReqData[Parameters].(map[string]interface{}) {
+			/*if _, exists := respReqData[key]; !exists {
Review comment (on the commented-out block above): remove this if not needed.
Reply: Hi @KfreeZ, let's keep it there for now; we may still need to keep the response value instead of overwriting it with the initial request.
+				respReqData[key] = value
+			}*/
+			// overwrite the respReq by initial request
+			respReqData[key] = value
+		}
+		mergedBytes, err := json.Marshal(respReqData)
+		if err != nil {
+			log.Error(err, "Error marshaling merged data:")
+			return nil
+		}
+		return mergedBytes
+	}
+	return respReq
+}
+
 func handleSwitchNode(
 	route *mcv1alpha3.Step,
 	graph mcv1alpha3.GMConnector,
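The inline review thread above debates whether the initial request's parameters should overwrite values already present in the intermediate response, or only fill in keys the response does not set (the commented-out branch). The following minimal, self-contained sketch shows both strategies side by side; mergeParams, the field names, and the sample payloads are illustrative assumptions rather than the router's actual code, but they show why the merge matters for max_new_tokens, the bug named in this PR's title.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// mergeParams re-applies the "parameters" object from the initial request onto
// the intermediate response that is about to be forwarded to the next step.
// With overwrite=true it behaves like the merge added in this PR; with
// overwrite=false it behaves like the commented-out variant, keeping any value
// the previous step already set.
func mergeParams(resp []byte, initReq map[string]interface{}, overwrite bool) ([]byte, error) {
	params, ok := initReq["parameters"].(map[string]interface{})
	if !ok {
		return resp, nil // nothing to merge
	}
	var respData map[string]interface{}
	if err := json.Unmarshal(resp, &respData); err != nil {
		return nil, err
	}
	for k, v := range params {
		if _, exists := respData[k]; exists && !overwrite {
			continue // keep the value produced by the previous step
		}
		respData[k] = v
	}
	return json.Marshal(respData)
}

func main() {
	// Initial client request, carrying generation parameters such as max_new_tokens.
	initReq := map[string]interface{}{
		"text":       "What is OPEA?",
		"parameters": map[string]interface{}{"max_new_tokens": 512},
	}
	// Intermediate "$response" from the previous step; the parameters were dropped.
	resp := []byte(`{"query":"What is OPEA?","reranked_docs":["..."]}`)

	merged, err := mergeParams(resp, initReq, true)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(merged)) // max_new_tokens travels on to the next step
}
```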
@@ -239,6 +266,13 @@ func handleSwitchPipeline(nodeName string,
 	var statusCode int
 	var responseBytes []byte
 	var err error
+
+	initReqData := make(map[string]interface{})
+	if err = json.Unmarshal(initInput, &initReqData); err != nil {
+		log.Error(err, "Error unmarshaling initReqData:")
+		return nil, 500, err
+	}
+
 	for index, route := range currentNode.Steps {
 		if route.InternalService.IsDownstreamService {
 			log.Info(
@@ -252,9 +286,11 @@ func handleSwitchPipeline(nodeName string,
 		}
 		log.Info("Current Step Information", "Node Name", nodeName, "Step Index", index)
 		request := input
+		log.Info("Print Original Request Bytes", "Request Bytes", request)
 		if route.Data == "$response" && index > 0 {
-			request = responseBytes
+			request = mergeRequests(responseBytes, initReqData)
 		}
+		log.Info("Print New Request Bytes", "Request Bytes", request)
 		if route.Condition == "" {
 			responseBytes, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers)
 			if err != nil {
@@ -348,6 +384,12 @@ func handleSequencePipeline(nodeName string,
 	var statusCode int
 	var responseBytes []byte
 	var err error
+
+	initReqData := make(map[string]interface{})
+	if err = json.Unmarshal(initInput, &initReqData); err != nil {
+		log.Error(err, "Error unmarshaling initReqData:")
+		return nil, 500, err
+	}
 	for i := range currentNode.Steps {
 		step := &currentNode.Steps[i]
 		stepType := ServiceURL
@@ -366,9 +408,11 @@ func handleSequencePipeline(nodeName string,
 		}
 		log.Info("Starting execution of step", "type", stepType, "stepName", step.StepName)
 		request := input
+		log.Info("Print Original Request Bytes", "Request Bytes", request)
 		if step.Data == "$response" && i > 0 {
-			request = responseBytes
+			request = mergeRequests(responseBytes, initReqData)
 		}
+		log.Info("Print New Request Bytes", "Request Bytes", request)
 		if step.Condition != "" {
 			if !gjson.ValidBytes(responseBytes) {
 				return nil, 500, fmt.Errorf("invalid response")
@@ -467,6 +511,10 @@ func mcGraphHandler(w http.ResponseWriter, req *http.Request) {
 			log.Error(err, "failed to write mcGraphHandler response")
 			return
 		}
+
+		if err := writer.Flush(); err != nil {
+			log.Error(err, "error flushing writer when processing response")
+		}
 	}
 }()
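The writer.Flush() added at the end of mcGraphHandler makes sure that bytes still sitting in the buffered writer reach the client before the surrounding function returns. A minimal sketch of that pattern, assuming writer is a bufio.Writer wrapped around the response stream as the Flush call suggests (hypothetical handler, not the real mcGraphHandler):

```go
package main

import (
	"bufio"
	"log"
	"net/http"
)

func handler(w http.ResponseWriter, req *http.Request) {
	writer := bufio.NewWriter(w)
	if _, err := writer.Write([]byte(`{"status":"ok"}`)); err != nil {
		log.Printf("failed to write response: %v", err)
		return
	}
	// If the handler returns without Flush, anything still sitting in the
	// bufio buffer (4 KiB by default) is silently dropped and never reaches
	// the underlying http.ResponseWriter.
	if err := writer.Flush(); err != nil {
		log.Printf("error flushing writer: %v", err)
	}
}

func main() {
	http.HandleFunc("/", handler)
	log.Fatal(http.ListenAndServe(":8080", nil))
}
```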
@@ -1,77 +1,76 @@
-# Copyright (C) 2024 Intel Corporation
-# SPDX-License-Identifier: Apache-2.0
-
-apiVersion: gmc.opea.io/v1alpha3
-kind: GMConnector
-metadata:
-  labels:
-    app.kubernetes.io/name: gmconnector
-    app.kubernetes.io/managed-by: kustomize
-    gmc/platform: xeon
-  name: chatqa
-  namespace: chatqa
-spec:
-  routerConfig:
-    name: router
-    serviceName: router-service
-  nodes:
-    root:
-      routerType: Sequence
-      steps:
-        - name: Embedding
-          internalService:
-            serviceName: embedding-svc
-            config:
-              endpoint: /v1/embeddings
-              TEI_EMBEDDING_ENDPOINT: tei-embedding-svc
-        - name: TeiEmbedding
-          internalService:
-            serviceName: tei-embedding-svc
-            isDownstreamService: true
-        - name: Retriever
-          data: $response
-          internalService:
-            serviceName: retriever-svc
-            config:
-              endpoint: /v1/retrieval
-              REDIS_URL: redis-vector-db
-              TEI_EMBEDDING_ENDPOINT: tei-embedding-svc
-        - name: VectorDB
-          internalService:
-            serviceName: redis-vector-db
-            isDownstreamService: true
-        - name: Reranking
-          data: $response
-          internalService:
-            serviceName: reranking-svc
-            config:
-              endpoint: /v1/reranking
-              TEI_RERANKING_ENDPOINT: tei-reranking-svc
-        - name: TeiReranking
-          internalService:
-            serviceName: tei-reranking-svc
-            config:
-              endpoint: /rerank
-            isDownstreamService: true
-        - name: Llm
-          data: $response
-          internalService:
-            serviceName: llm-svc
-            config:
-              endpoint: /v1/chat/completions
-              TGI_LLM_ENDPOINT: tgi-service-m
-        - name: Tgi
-          internalService:
-            serviceName: tgi-service-m
-            config:
-              endpoint: /generate
-            isDownstreamService: true
-        - name: DataPrep
-          internalService:
-            serviceName: data-prep-svc
-            config:
-              endpoint: /v1/dataprep
-              REDIS_URL: redis-vector-db
-              INDEX_NAME: data-prep
-              TEI_ENDPOINT: tei-embedding-svc
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+apiVersion: gmc.opea.io/v1alpha3
+kind: GMConnector
+metadata:
+  labels:
+    app.kubernetes.io/name: gmconnector
+    app.kubernetes.io/managed-by: kustomize
+    gmc/platform: xeon
+  name: chatqa
+  namespace: chatqa
+spec:
+  routerConfig:
+    name: router
+    serviceName: router-service
+  nodes:
+    root:
+      routerType: Sequence
+      steps:
+        - name: Embedding
+          internalService:
+            serviceName: embedding-svc
+            config:
+              endpoint: /v1/embeddings
+              TEI_EMBEDDING_ENDPOINT: tei-embedding-svc
+        - name: TeiEmbedding
+          internalService:
+            serviceName: tei-embedding-svc
+            isDownstreamService: true
+        - name: Retriever
+          data: $response
+          internalService:
+            serviceName: retriever-svc
+            config:
+              endpoint: /v1/retrieval
+              REDIS_URL: redis-vector-db
+              TEI_EMBEDDING_ENDPOINT: tei-embedding-svc
+        - name: VectorDB
+          internalService:
+            serviceName: redis-vector-db
+            isDownstreamService: true
+        - name: Reranking
+          data: $response
+          internalService:
+            serviceName: reranking-svc
+            config:
+              endpoint: /v1/reranking
+              TEI_RERANKING_ENDPOINT: tei-reranking-svc
+        - name: TeiReranking
+          internalService:
+            serviceName: tei-reranking-svc
+            config:
+              endpoint: /rerank
+            isDownstreamService: true
+        - name: Llm
+          data: $response
+          internalService:
+            serviceName: llm-svc
+            config:
+              endpoint: /v1/chat/completions
+              TGI_LLM_ENDPOINT: tgi-service-m
+        - name: Tgi
+          internalService:
+            serviceName: tgi-service-m
+            config:
+              endpoint: /generate
+            isDownstreamService: true
+        - name: DataPrep
+          internalService:
+            serviceName: data-prep-svc
+            config:
+              endpoint: /v1/dataprep
+              REDIS_URL: redis-vector-db
+              TEI_ENDPOINT: tei-embedding-svc
+            isDownstreamService: true
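For context, a request flowing through the pipeline defined above carries max_new_tokens inside the parameters object, which is exactly what mergeRequests now preserves between steps. A hedged usage sketch follows; the in-cluster service URL and request shape are assumptions based on the GMC ChatQnA examples, so adjust them to your deployment.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	// Generation parameters ride along in "parameters" and should survive every hop.
	body := []byte(`{"text":"What is OPEA?","parameters":{"max_new_tokens":128}}`)
	resp, err := http.Post(
		"http://router-service.chatqa.svc.cluster.local:8080", // assumed router accessUrl
		"application/json",
		bytes.NewReader(body),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.StatusCode, string(out))
}
```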
Review comment: Just curious, why doesn't handleEnsemblePipeline need this mergeRequests call?
Reply: Hi @KfreeZ, in the current implementation the ensemble pipeline has no $response case, so it is left as is.