From 0349f379ae1874a4e22b10c3ef5d082f2d75f315 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Fri, 30 Aug 2024 13:37:00 -0700 Subject: [PATCH] ragserver: fix other variants to be similar to ragserver-genkit Apply suggestions from go.dev/cl/608737 to the other ragserver variants Change-Id: I627cec1ff21c8f7c1b1d257b542435793ade680d Reviewed-on: https://go-review.googlesource.com/c/example/+/609305 TryBot-Bypass: Eli Bendersky Auto-Submit: Eli Bendersky Reviewed-by: Ian Lance Taylor Reviewed-by: Eli Bendersky --- ragserver/ragserver-langchaingo/json.go | 9 +++------ ragserver/ragserver-langchaingo/main.go | 6 +++--- ragserver/ragserver/json.go | 9 +++------ ragserver/ragserver/main.go | 6 +++--- 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/ragserver/ragserver-langchaingo/json.go b/ragserver/ragserver-langchaingo/json.go index c0bb7ce6..e19a6894 100644 --- a/ragserver/ragserver-langchaingo/json.go +++ b/ragserver/ragserver-langchaingo/json.go @@ -14,7 +14,7 @@ import ( // readRequestJSON expects req to have a JSON content type with a body that // contains a JSON-encoded value complying with the underlying type of target. // It populates target, or returns an error. -func readRequestJSON(target any, req *http.Request) error { +func readRequestJSON(req *http.Request, target any) error { contentType := req.Header.Get("Content-Type") mediaType, _, err := mime.ParseMediaType(contentType) if err != nil { @@ -26,14 +26,11 @@ func readRequestJSON(target any, req *http.Request) error { dec := json.NewDecoder(req.Body) dec.DisallowUnknownFields() - if err := dec.Decode(target); err != nil { - return err - } - return nil + return dec.Decode(target) } // renderJSON renders 'v' as JSON and writes it as a response into w. -func renderJSON(w http.ResponseWriter, v interface{}) { +func renderJSON(w http.ResponseWriter, v any) { js, err := json.Marshal(v) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/ragserver/ragserver-langchaingo/main.go b/ragserver/ragserver-langchaingo/main.go index 86aad2c3..f679c44b 100644 --- a/ragserver/ragserver-langchaingo/main.go +++ b/ragserver/ragserver-langchaingo/main.go @@ -60,7 +60,7 @@ func main() { mux := http.NewServeMux() mux.HandleFunc("POST /add/", server.addDocumentsHandler) - mux.HandleFunc("GET /query/", server.queryHandler) + mux.HandleFunc("POST /query/", server.queryHandler) port := cmp.Or(os.Getenv("SERVERPORT"), "9020") address := "localhost:" + port @@ -84,7 +84,7 @@ func (rs *ragServer) addDocumentsHandler(w http.ResponseWriter, req *http.Reques } ar := &addRequest{} - err := readRequestJSON(ar, req) + err := readRequestJSON(req, ar) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -108,7 +108,7 @@ func (rs *ragServer) queryHandler(w http.ResponseWriter, req *http.Request) { Content string } qr := &queryRequest{} - err := readRequestJSON(qr, req) + err := readRequestJSON(req, qr) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return diff --git a/ragserver/ragserver/json.go b/ragserver/ragserver/json.go index c0bb7ce6..e19a6894 100644 --- a/ragserver/ragserver/json.go +++ b/ragserver/ragserver/json.go @@ -14,7 +14,7 @@ import ( // readRequestJSON expects req to have a JSON content type with a body that // contains a JSON-encoded value complying with the underlying type of target. // It populates target, or returns an error. -func readRequestJSON(target any, req *http.Request) error { +func readRequestJSON(req *http.Request, target any) error { contentType := req.Header.Get("Content-Type") mediaType, _, err := mime.ParseMediaType(contentType) if err != nil { @@ -26,14 +26,11 @@ func readRequestJSON(target any, req *http.Request) error { dec := json.NewDecoder(req.Body) dec.DisallowUnknownFields() - if err := dec.Decode(target); err != nil { - return err - } - return nil + return dec.Decode(target) } // renderJSON renders 'v' as JSON and writes it as a response into w. -func renderJSON(w http.ResponseWriter, v interface{}) { +func renderJSON(w http.ResponseWriter, v any) { js, err := json.Marshal(v) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/ragserver/ragserver/main.go b/ragserver/ragserver/main.go index e55c66fe..f1a21f8d 100644 --- a/ragserver/ragserver/main.go +++ b/ragserver/ragserver/main.go @@ -52,7 +52,7 @@ func main() { mux := http.NewServeMux() mux.HandleFunc("POST /add/", server.addDocumentsHandler) - mux.HandleFunc("GET /query/", server.queryHandler) + mux.HandleFunc("POST /query/", server.queryHandler) port := cmp.Or(os.Getenv("SERVERPORT"), "9020") address := "localhost:" + port @@ -77,7 +77,7 @@ func (rs *ragServer) addDocumentsHandler(w http.ResponseWriter, req *http.Reques } ar := &addRequest{} - err := readRequestJSON(ar, req) + err := readRequestJSON(req, ar) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -127,7 +127,7 @@ func (rs *ragServer) queryHandler(w http.ResponseWriter, req *http.Request) { Content string } qr := &queryRequest{} - err := readRequestJSON(qr, req) + err := readRequestJSON(req, qr) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return