From ff90e444665a77a340d6b8a7420b6e7d8a73c135 Mon Sep 17 00:00:00 2001 From: EwenQuim Date: Wed, 3 Jul 2024 22:39:35 +0200 Subject: [PATCH] Multi-return --- ctx.go | 57 +----- examples/full-app-gourmet/views/partials.go | 4 +- examples/full-app-gourmet/views/recipe.go | 27 +-- html.go | 64 ++++++- html_test.go | 10 +- multi_return.go | 61 +++++++ multi_return_test.go | 118 ++++++++++++ options.go | 22 ++- options_test.go | 30 ++- serialization.go | 191 ++++++++++++++++++-- serialization_test.go | 70 ++++++- serve.go | 75 ++------ serve_test.go | 29 ++- 13 files changed, 601 insertions(+), 157 deletions(-) create mode 100644 multi_return.go create mode 100644 multi_return_test.go diff --git a/ctx.go b/ctx.go index a0826086..9c008152 100644 --- a/ctx.go +++ b/ctx.go @@ -64,7 +64,7 @@ type ctx[B any] interface { // and you want to override one above the other, you can do: // c.Render("admin.page.html", recipes, "partials/aaa/nav.partial.html") // By default, [templateToExecute] is added to the list of templates to override. - Render(templateToExecute string, data any, templateGlobsToOverride ...string) (HTML, error) + Render(templateToExecute string, data any, templateGlobsToOverride ...string) (CtxRenderer, error) Cookie(name string) (*http.Cookie, error) // Get request cookie SetCookie(cookie http.Cookie) // Sets response cookie @@ -222,51 +222,14 @@ func (c ContextNoBody) SetCookie(cookie http.Cookie) { // that the templates will be parsed only once, removing // the need to parse the templates on each request but also preventing // to dynamically use new templates. -func (c ContextNoBody) Render(templateToExecute string, data any, layoutsGlobs ...string) (HTML, error) { - if strings.Contains(templateToExecute, "/") || strings.Contains(templateToExecute, "*") { - - layoutsGlobs = append(layoutsGlobs, templateToExecute) // To override all blocks defined in the main template - cloned := template.Must(c.templates.Clone()) - tmpl, err := cloned.ParseFS(c.fs, layoutsGlobs...) - if err != nil { - return "", HTTPError{ - Err: err, - Status: http.StatusInternalServerError, - Title: "Error parsing template", - Detail: fmt.Errorf("error parsing template '%s': %w", layoutsGlobs, err).Error(), - Errors: []ErrorItem{ - { - Name: "templates", - Reason: "Check that the template exists and have the correct extension. Globs: " + strings.Join(layoutsGlobs, ", "), - }, - }, - } - } - c.templates = template.Must(tmpl.Clone()) - } - - // Get only last template name (for example, with partials/nav/main/nav.partial.html, get nav.partial.html) - myTemplate := strings.Split(templateToExecute, "/") - templateToExecute = myTemplate[len(myTemplate)-1] - - c.Res.Header().Set("Content-Type", "text/html; charset=utf-8") - err := c.templates.ExecuteTemplate(c.Res, templateToExecute, data) - if err != nil { - return "", HTTPError{ - Err: err, - Status: http.StatusInternalServerError, - Title: "Error rendering template", - Detail: fmt.Errorf("error executing template '%s': %w", templateToExecute, err).Error(), - Errors: []ErrorItem{ - { - Name: "templates", - Reason: "Check that the template exists and have the correct extension. Template: " + templateToExecute, - }, - }, - } - } - - return "", err +func (c ContextNoBody) Render(templateToExecute string, data any, layoutsGlobs ...string) (CtxRenderer, error) { + return &StdRenderer{ + templateToExecute: templateToExecute, + templates: c.templates, + layoutsGlobs: layoutsGlobs, + fs: c.fs, + data: data, + }, nil } // PathParams returns the path parameters of the request. @@ -427,7 +390,7 @@ func body[B any](c ContextNoBody) (B, error) { body, err = readURLEncoded[B](c.Req, c.readOptions) case "application/xml": body, err = readXML[B](c.Req.Context(), c.Req.Body, c.readOptions) - case "application/x-yaml": + case "application/x-yaml", "text/yaml; charset=utf-8", "application/yaml": // https://www.rfc-editor.org/rfc/rfc9512.html body, err = readYAML[B](c.Req.Context(), c.Req.Body, c.readOptions) case "application/octet-stream": // Read c.Req Body to bytes diff --git a/examples/full-app-gourmet/views/partials.go b/examples/full-app-gourmet/views/partials.go index 5a92efa4..0ddd2dca 100644 --- a/examples/full-app-gourmet/views/partials.go +++ b/examples/full-app-gourmet/views/partials.go @@ -5,12 +5,12 @@ import ( "github.com/go-fuego/fuego/examples/full-app-gourmet/store/types" ) -func (rs Ressource) unitPreselected(c fuego.ContextNoBody) (fuego.HTML, error) { +func (rs Ressource) unitPreselected(c fuego.ContextNoBody) (fuego.CtxRenderer, error) { id := c.QueryParam("IngredientID") ingredient, err := rs.IngredientsQueries.GetIngredient(c.Context(), id) if err != nil { - return "", err + return nil, err } return c.Render("preselected-unit.partial.html", fuego.H{ diff --git a/examples/full-app-gourmet/views/recipe.go b/examples/full-app-gourmet/views/recipe.go index f2a93666..97d29665 100644 --- a/examples/full-app-gourmet/views/recipe.go +++ b/examples/full-app-gourmet/views/recipe.go @@ -104,15 +104,18 @@ func (rs Ressource) showIndex(c fuego.ContextNoBody) (fuego.Templ, error) { }), nil } -func (rs Ressource) showRecipes(c fuego.ContextNoBody) (fuego.Templ, error) { +func (rs Ressource) showRecipes(c fuego.ContextNoBody) (*fuego.DataOrTemplate[[]store.Recipe], error) { recipes, err := rs.RecipesQueries.GetRecipes(c.Context()) if err != nil { return nil, err } - return templa.SearchPage(templa.SearchProps{ - Recipes: recipes, - }), nil + return fuego.DataOrHTML( + recipes, + templa.SearchPage(templa.SearchProps{ + Recipes: recipes, + }), + ), nil } func (rs Ressource) relatedRecipes(c fuego.ContextNoBody) (fuego.Templ, error) { @@ -220,7 +223,7 @@ func (rs Ressource) healthyRecipes(c fuego.ContextNoBody) (fuego.Templ, error) { }), nil } -func (rs Ressource) showRecipesList(c fuego.ContextNoBody) (fuego.HTML, error) { +func (rs Ressource) showRecipesList(c fuego.ContextNoBody) (fuego.CtxRenderer, error) { search := c.QueryParam("search") recipes, err := rs.RecipesQueries.SearchRecipes(c.Context(), store.SearchRecipesParams{ Search: sql.NullString{ @@ -229,28 +232,28 @@ func (rs Ressource) showRecipesList(c fuego.ContextNoBody) (fuego.HTML, error) { }, }) if err != nil { - return "", err + return nil, err } return c.Render("partials/recipes-list.partial.html", recipes) } -func (rs Ressource) addRecipe(c *fuego.ContextWithBody[store.CreateRecipeParams]) (fuego.HTML, error) { +func (rs Ressource) addRecipe(c *fuego.ContextWithBody[store.CreateRecipeParams]) (fuego.CtxRenderer, error) { body, err := c.Body() if err != nil { - return "", err + return nil, err } body.ID = uuid.NewString() _, err = rs.RecipesQueries.CreateRecipe(c.Context(), body) if err != nil { - return "", err + return nil, err } recipes, err := rs.RecipesQueries.GetRecipes(c.Context()) if err != nil { - return "", err + return nil, err } return c.Render("pages/admin.page.html", fuego.H{ @@ -258,12 +261,12 @@ func (rs Ressource) addRecipe(c *fuego.ContextWithBody[store.CreateRecipeParams] }) } -func (rs Ressource) RecipePage(c fuego.ContextNoBody) (fuego.HTML, error) { +func (rs Ressource) RecipePage(c fuego.ContextNoBody) (fuego.CtxRenderer, error) { id := c.PathParam("id") recipe, err := rs.RecipesQueries.GetRecipe(c.Context(), id) if err != nil { - return "", fmt.Errorf("error getting recipe %s: %w", id, err) + return nil, fmt.Errorf("error getting recipe %s: %w", id, err) } ingredients, err := rs.IngredientsQueries.GetIngredientsOfRecipe(c.Context(), id) diff --git a/html.go b/html.go index fdffa571..802aeaf2 100644 --- a/html.go +++ b/html.go @@ -5,9 +5,14 @@ import ( "fmt" "html/template" "io" + "io/fs" + "net/http" + "strings" ) -// CtxRenderer can be used with [github.com/a-h/templ] +// CtxRenderer is an interface that can be used to render a response. +// It is used with standard library templating engine, by using fuego.ContextXXX.Render +// It is compatible with [github.com/a-h/templ] out of the box. // Example: // // func getRecipes(ctx fuego.ContextNoBody) (fuego.CtxRenderer, error) { @@ -50,6 +55,63 @@ type HTML string // H is a shortcut for map[string]any type H map[string]any +// StdRenderer renders a template using the standard library templating engine. +type StdRenderer struct { + templateToExecute string + templates *template.Template + layoutsGlobs []string + fs fs.FS + data any +} + +var _ CtxRenderer = StdRenderer{} + +func (s StdRenderer) Render(ctx context.Context, w io.Writer) error { + if strings.Contains(s.templateToExecute, "/") || strings.Contains(s.templateToExecute, "*") { + + s.layoutsGlobs = append(s.layoutsGlobs, s.templateToExecute) // To override all blocks defined in the main template + cloned := template.Must(s.templates.Clone()) + tmpl, err := cloned.ParseFS(s.fs, s.layoutsGlobs...) + if err != nil { + return HTTPError{ + Err: err, + Status: http.StatusInternalServerError, + Title: "Error parsing template", + Detail: fmt.Errorf("error parsing template '%s': %w", s.layoutsGlobs, err).Error(), + Errors: []ErrorItem{ + { + Name: "templates", + Reason: "Check that the template exists and have the correct extension. Globs: " + strings.Join(s.layoutsGlobs, ", "), + }, + }, + } + } + s.templates = template.Must(tmpl.Clone()) + } + + // Get only last template name (for example, with partials/nav/main/nav.partial.html, get nav.partial.html) + myTemplate := strings.Split(s.templateToExecute, "/") + s.templateToExecute = myTemplate[len(myTemplate)-1] + + err := s.templates.ExecuteTemplate(w, s.templateToExecute, s.data) + if err != nil { + return HTTPError{ + Err: err, + Status: http.StatusInternalServerError, + Title: "Error rendering template", + Detail: fmt.Errorf("error executing template '%s': %w", s.templateToExecute, err).Error(), + Errors: []ErrorItem{ + { + Name: "templates", + Reason: "Check that the template exists and have the correct extension. Template: " + s.templateToExecute, + }, + }, + } + } + + return err +} + // loadTemplates func (s *Server) loadTemplates(patterns ...string) error { tmpl, err := template.ParseFS(s.fs, patterns...) diff --git a/html_test.go b/html_test.go index 990c5be8..7e07f37f 100644 --- a/html_test.go +++ b/html_test.go @@ -18,7 +18,7 @@ func TestRender(t *testing.T) { WithTemplateGlobs("testdata/*.html"), ) - Get(s, "/test", func(ctx *ContextNoBody) (HTML, error) { + Get(s, "/test", func(ctx *ContextNoBody) (CtxRenderer, error) { return ctx.Render("testdata/test.html", H{"Name": "test"}) }) @@ -43,7 +43,7 @@ func TestRender(t *testing.T) { }) t.Run("cannot parse unexisting file", func(t *testing.T) { - Get(s, "/file-not-found", func(ctx ContextNoBody) (HTML, error) { + Get(s, "/file-not-found", func(ctx ContextNoBody) (CtxRenderer, error) { return ctx.Render("testdata/not-found.html", H{"Name": "test"}) }) @@ -56,7 +56,7 @@ func TestRender(t *testing.T) { }) t.Run("can execute template with missing variable in map", func(t *testing.T) { - Get(s, "/impossible", func(ctx ContextNoBody) (HTML, error) { + Get(s, "/impossible", func(ctx ContextNoBody) (CtxRenderer, error) { return ctx.Render("testdata/test.html", H{"NotName": "test"}) }) @@ -71,7 +71,7 @@ func TestRender(t *testing.T) { }) t.Run("cannot execute template with missing variable in struct", func(t *testing.T) { - Get(s, "/impossible-struct", func(ctx ContextNoBody) (HTML, error) { + Get(s, "/impossible-struct", func(ctx ContextNoBody) (CtxRenderer, error) { return ctx.Render("testdata/test.html", struct{}{}) }) @@ -93,7 +93,7 @@ func BenchmarkRender(b *testing.B) { WithTemplateGlobs("testdata/*.html"), ) - Get(s, "/test", func(ctx ContextNoBody) (HTML, error) { + Get(s, "/test", func(ctx ContextNoBody) (CtxRenderer, error) { return ctx.Render("testdata/test.html", H{"Name": "test"}) }) diff --git a/multi_return.go b/multi_return.go new file mode 100644 index 00000000..38def6d1 --- /dev/null +++ b/multi_return.go @@ -0,0 +1,61 @@ +package fuego + +import ( + "context" + "encoding/json" + "encoding/xml" + "fmt" + "io" + + "gopkg.in/yaml.v3" +) + +// DataOrTemplate is a struct that can return either data or a template +// depending on the asked type. +type DataOrTemplate[T any] struct { + Data T + Template any +} + +var ( + _ CtxRenderer = DataOrTemplate[any]{} // Can render HTML (template) + _ json.Marshaler = DataOrTemplate[any]{} // Can render JSON (data) + _ xml.Marshaler = DataOrTemplate[any]{} // Can render XML (data) + _ yaml.Marshaler = DataOrTemplate[any]{} // Can render YAML (data) + _ fmt.Stringer = DataOrTemplate[any]{} // Can render string (data) +) + +func (m DataOrTemplate[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(m.Data) +} + +func (m DataOrTemplate[T]) MarshalXML(e *xml.Encoder, _ xml.StartElement) error { + return e.Encode(m.Data) +} + +func (m DataOrTemplate[T]) MarshalYAML() (interface{}, error) { + return m.Data, nil +} + +func (m DataOrTemplate[T]) String() string { + return fmt.Sprintf("%v", m.Data) +} + +func (m DataOrTemplate[T]) Render(c context.Context, w io.Writer) error { + switch m.Template.(type) { + case CtxRenderer: + return m.Template.(CtxRenderer).Render(c, w) + case Renderer: + return m.Template.(Renderer).Render(w) + default: + panic("template must be either CtxRenderer or Renderer") + } +} + +// Helper function to create a DataOrTemplate return item without specifying the type. +func DataOrHTML[T any](data T, template any) *DataOrTemplate[T] { + return &DataOrTemplate[T]{ + Data: data, + Template: template, + } +} diff --git a/multi_return_test.go b/multi_return_test.go new file mode 100644 index 00000000..5b3dcd80 --- /dev/null +++ b/multi_return_test.go @@ -0,0 +1,118 @@ +package fuego + +import ( + "context" + "errors" + "io" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +type MockCtxRenderer struct { + RenderFunc func(context.Context, io.Writer) error +} + +func RenderString(s string) MockCtxRenderer { + return MockCtxRenderer{ + RenderFunc: func(c context.Context, w io.Writer) error { + _, err := w.Write([]byte(s)) + return err + }, + } +} + +var _ CtxRenderer = MockCtxRenderer{} + +func (m MockCtxRenderer) Render(c context.Context, w io.Writer) error { + if m.RenderFunc == nil { + return errors.New("RenderFunc is nil") + } + return m.RenderFunc(c, w) +} + +type MyType struct { + Name string +} + +func TestMultiReturn(t *testing.T) { + s := NewServer() + + Get(s, "/data", func(c ContextNoBody) (DataOrTemplate[MyType], error) { + entity := MyType{Name: "Ewen"} + + return DataOrTemplate[MyType]{ + Data: entity, + Template: RenderString(`
` + entity.Name + `
`), + }, nil + }) + + Get(s, "/other", func(c ContextNoBody) (*DataOrTemplate[MyType], error) { + entity := MyType{Name: "Ewen"} + + return DataOrHTML( + entity, + RenderString(`
`+entity.Name+`
`), + ), nil + }) + + t.Run("requests HTML by default", func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/data", nil) + + s.Mux.ServeHTTP(recorder, req) + + require.Equal(t, 200, recorder.Code) + require.Equal(t, "text/html; charset=utf-8", recorder.Header().Get("Content-Type")) + require.Equal(t, `
Ewen
`, recorder.Body.String()) + }) + + t.Run("requests JSON", func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/data", nil) + req.Header.Set("Accept", "application/json") + + s.Mux.ServeHTTP(recorder, req) + + require.Equal(t, 200, recorder.Code) + require.Equal(t, "application/json", recorder.Header().Get("Content-Type")) + require.Equal(t, crlf(`{"Name":"Ewen"}`), recorder.Body.String()) + }) + + t.Run("requests JSON, using the shortcut", func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/other", nil) + req.Header.Set("Accept", "application/json") + + s.Mux.ServeHTTP(recorder, req) + + require.Equal(t, 200, recorder.Code) + require.Equal(t, "application/json", recorder.Header().Get("Content-Type")) + require.Equal(t, crlf(`{"Name":"Ewen"}`), recorder.Body.String()) + }) + + t.Run("requests XML", func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/data", nil) + req.Header.Set("Accept", "application/xml") + + s.Mux.ServeHTTP(recorder, req) + + require.Equal(t, 200, recorder.Code) + require.Equal(t, "application/xml", recorder.Header().Get("Content-Type")) + require.Equal(t, `Ewen`, recorder.Body.String()) + }) + + t.Run("requests HTML", func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/data", nil) + req.Header.Set("Accept", "text/html") + + s.Mux.ServeHTTP(recorder, req) + + require.Equal(t, 200, recorder.Code) + require.Contains(t, recorder.Header().Get("Content-Type"), "text/html") + require.Equal(t, `
Ewen
`, recorder.Body.String()) + }) +} diff --git a/options.go b/options.go index f6d43c18..8080016d 100644 --- a/options.go +++ b/options.go @@ -70,10 +70,11 @@ type Server struct { DisallowUnknownFields bool // If true, the server will return an error if the request body contains unknown fields. Useful for quick debugging in development. DisableOpenapi bool // If true, the routes within the server will not generate an openapi spec. maxBodySize int64 - Serialize func(w http.ResponseWriter, ans any) // Used to serialize the response. Defaults to [SendJSON]. - SerializeError func(w http.ResponseWriter, err error) // Used to serialize the error response. Defaults to [SendJSONError]. - ErrorHandler func(err error) error // Used to transform any error into a unified error type structure with status code. Defaults to [ErrorHandler] - startTime time.Time + + Serialize func(w http.ResponseWriter, r *http.Request, ans any) error // Custom serializer that overrides the default one. + SerializeError func(w http.ResponseWriter, r *http.Request, err error) // Used to serialize the error response. Defaults to [SendError]. + ErrorHandler func(err error) error // Used to transform any error into a unified error type structure with status code. Defaults to [ErrorHandler] + startTime time.Time OpenAPIConfig OpenAPIConfig @@ -109,8 +110,8 @@ func NewServer(options ...func(*Server)) *Server { defaultOptions := [...]func(*Server){ WithAddr("localhost:9999"), WithDisallowUnknownFields(true), - WithSerializer(SendJSON), - WithErrorSerializer(SendJSONError), + WithSerializer(Send), + WithErrorSerializer(SendError), WithErrorHandler(ErrorHandler), } @@ -258,6 +259,9 @@ func WithAddr(addr string) func(*Server) { return func(c *Server) { c.Server.Addr = addr } } +// WithXML sets the serializer to XML +// +// Deprecated: fuego supports automatic XML serialization when using the header "Accept: application/xml". func WithXML() func(*Server) { return func(c *Server) { c.Serialize = SendXML @@ -274,11 +278,13 @@ func WithLogHandler(handler slog.Handler) func(*Server) { } } -func WithSerializer(serializer func(w http.ResponseWriter, ans any)) func(*Server) { +// WithSerializer sets a custom serializer that overrides the default one. +// Please send a PR if you think the default serializer should be improved, instead of jumping to this option. +func WithSerializer(serializer func(w http.ResponseWriter, r *http.Request, ans any) error) func(*Server) { return func(c *Server) { c.Serialize = serializer } } -func WithErrorSerializer(serializer func(w http.ResponseWriter, err error)) func(*Server) { +func WithErrorSerializer(serializer func(w http.ResponseWriter, r *http.Request, err error)) func(*Server) { return func(c *Server) { c.SerializeError = serializer } } diff --git a/options_test.go b/options_test.go index 16e1788f..888ab7df 100644 --- a/options_test.go +++ b/options_test.go @@ -5,6 +5,7 @@ import ( "html/template" "io" "log/slog" + "net/http" "net/http/httptest" "testing" @@ -37,15 +38,14 @@ func TestNewServer(t *testing.T) { } func TestWithXML(t *testing.T) { - s := NewServer( - WithXML(), - ) + s := NewServer() Get(s, "/", controller) Get(s, "/error", controllerWithError) t.Run("response is XML", func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Accept", "application/xml") s.Mux.ServeHTTP(recorder, req) @@ -57,12 +57,13 @@ func TestWithXML(t *testing.T) { t.Run("error response is XML", func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest("GET", "/error", nil) + req.Header.Set("Accept", "application/xml") s.Mux.ServeHTTP(recorder, req) require.Equal(t, 500, recorder.Code) - require.Equal(t, "application/xml", recorder.Header().Get("Content-Type")) require.Equal(t, "Internal Server Error500", recorder.Body.String()) + require.Equal(t, "application/xml", recorder.Header().Get("Content-Type")) }) } @@ -373,3 +374,24 @@ func TestServerTags(t *testing.T) { require.Equal(t, subGroup.tags, []string{"my-server-tag"}) }) } + +func TestCustomSerialization(t *testing.T) { + s := NewServer( + WithSerializer(func(w http.ResponseWriter, r *http.Request, a any) error { + w.WriteHeader(202) + _, err := w.Write([]byte("custom serialization")) + return err + }), + ) + + Get(s, "/", func(c *ContextNoBody) (ans, error) { + return ans{Ans: "Hello World"}, nil + }) + + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + s.Mux.ServeHTTP(w, req) + + require.Equal(t, 202, w.Code) + require.Equal(t, "custom serialization", w.Body.String()) +} diff --git a/serialization.go b/serialization.go index 7fda640d..4d053fdd 100644 --- a/serialization.go +++ b/serialization.go @@ -5,9 +5,13 @@ import ( "encoding/json" "encoding/xml" "errors" + "fmt" "log/slog" "net/http" "reflect" + "strings" + + "gopkg.in/yaml.v3" ) // OutTransformer is an interface for entities that can be transformed. @@ -71,13 +75,43 @@ func transformOut[T any](ctx context.Context, ans T) (T, error) { return ans, nil } -// Send sends a string response. -func Send(w http.ResponseWriter, text string) { - _, _ = w.Write([]byte(text)) +// Send sends a response. +// The format is determined by the Accept header. +func Send(w http.ResponseWriter, r *http.Request, ans any) error { + switch parseAcceptHeader(r.Header.Get("Accept"), ans) { + case "application/xml": + return SendXML(w, r, ans) + case "text/html": + return SendHTML(r.Context(), w, ans) + case "text/plain": + return SendText(w, ans) + case "application/json": + SendJSON(w, ans) + case "application/yaml": + SendYAML(w, ans) + default: + return errors.New("unsupported Accept header") + } + + return nil +} + +// SendYAML sends a YAML response. +// Declared as a variable to be able to override it for clients that need to customize serialization. +var SendYAML = func(w http.ResponseWriter, ans any) { + w.Header().Set("Content-Type", "application/x-yaml") + err := yaml.NewEncoder(w).Encode(ans) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Cannot serialize YAML", "error", err) + _, _ = w.Write([]byte(`{"error":"Cannot serialize YAML"}`)) + return + } } // SendJSON sends a JSON response. -func SendJSON(w http.ResponseWriter, ans any) { +// Declared as a variable to be able to override it for clients that need to customize serialization. +var SendJSON = func(w http.ResponseWriter, ans any) { w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(ans) if err != nil { @@ -88,6 +122,28 @@ func SendJSON(w http.ResponseWriter, ans any) { } } +// SendError sends an error. +// Declared as a variable to be able to override it for clients that need to customize serialization. +var SendError = func(w http.ResponseWriter, r *http.Request, err error) { + accept := parseAcceptHeader(r.Header.Get("Accept"), nil) + if accept == "" { + accept = "application/json" + } + + switch accept { + case "application/xml": + SendXMLError(w, r, err) + case "text/html": + _ = SendHTMLError(r.Context(), w, err) + case "text/plain": + _ = SendText(w, err) + case "application/json": + SendJSONError(w, err) + default: + SendJSONError(w, err) + } +} + // SendJSONError sends a JSON error response. // If the error implements ErrorWithStatus, the status code will be set. func SendJSONError(w http.ResponseWriter, err error) { @@ -109,26 +165,137 @@ func SendJSONError(w http.ResponseWriter, err error) { } // SendXML sends a XML response. -func SendXML(w http.ResponseWriter, ans any) { +// Declared as a variable to be able to override it for clients that need to customize serialization. +var SendXML = func(w http.ResponseWriter, r *http.Request, ans any) error { w.Header().Set("Content-Type", "application/xml") - err := xml.NewEncoder(w).Encode(ans) + return xml.NewEncoder(w).Encode(ans) +} + +// SendXMLError sends a XML error response. +// If the error implements ErrorWithStatus, the status code will be set. +func SendXMLError(w http.ResponseWriter, r *http.Request, err error) { + status := http.StatusInternalServerError + var errorStatus ErrorWithStatus + if errors.As(err, &errorStatus) { + status = errorStatus.StatusCode() + } + + w.WriteHeader(status) + err = SendXML(w, r, err) if err != nil { - w.WriteHeader(http.StatusInternalServerError) slog.Error("Cannot serialize XML", "error", err) _, _ = w.Write([]byte(`{"error":"Cannot serialize XML"}`)) - return } } -// SendXMLError sends a XML error response. -// If the error implements ErrorWithStatus, the status code will be set. -func SendXMLError(w http.ResponseWriter, err error) { +func SendHTMLError(ctx context.Context, w http.ResponseWriter, err error) error { status := http.StatusInternalServerError var errorStatus ErrorWithStatus if errors.As(err, &errorStatus) { status = errorStatus.StatusCode() } + w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(status) - SendXML(w, err) + return SendHTML(ctx, w, err.Error()) +} + +// SendHTML sends a HTML response. +// Declared as a variable to be able to override it for clients that need to customize serialization. +var SendHTML = func(ctx context.Context, w http.ResponseWriter, ans any) error { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + + ctxRenderer, ok := any(ans).(CtxRenderer) + if ok { + return ctxRenderer.Render(ctx, w) + } + + renderer, ok := any(ans).(Renderer) + if ok { + return renderer.Render(w) + } + + html, ok := any(ans).(HTML) + if ok { + _, err := w.Write([]byte(html)) + return err + } + + htmlString, ok := any(ans).(string) + if ok { + _, err := w.Write([]byte(htmlString)) + return err + } + + // The type cannot be converted to HTML + return fmt.Errorf("cannot serialize HTML from type %T (not string, fuego.HTML and does not implement fuego.CtxRenderer or fuego.Renderer)", ans) +} + +func SendText(w http.ResponseWriter, ans any) error { + var err error + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + stringToWrite, ok := any(ans).(string) + if !ok { + stringToWritePtr, okPtr := any(ans).(*string) + if okPtr { + stringToWrite = *stringToWritePtr + } else { + stringToWrite = fmt.Sprintf("%v", ans) + } + } + _, err = w.Write([]byte(stringToWrite)) + + return err +} + +func InferAcceptHeaderFromType(ans any) string { + _, ok := any(ans).(string) + if ok { + return "text/plain" + } + + _, ok = any(ans).(*string) + if ok { + return "text/plain" + } + + _, ok = any(ans).(HTML) + if ok { + return "text/html" + } + + _, ok = any(ans).(CtxRenderer) + if ok { + return "text/html" + } + + _, ok = any(&ans).(*CtxRenderer) + if ok { + return "text/html" + } + + _, ok = any(ans).(Renderer) + if ok { + return "text/html" + } + + _, ok = any(&ans).(*Renderer) + if ok { + return "text/html" + } + + return "application/json" +} + +func parseAcceptHeader(accept string, ans any) string { + if strings.Index(accept, ",") > 0 { + accept = accept[:strings.Index(accept, ",")] + } + if accept == "*/*" { + accept = "" + } + if accept == "" { + accept = InferAcceptHeaderFromType(ans) + } + return accept } diff --git a/serialization_test.go b/serialization_test.go index be9f21d7..12573cd0 100644 --- a/serialization_test.go +++ b/serialization_test.go @@ -2,6 +2,7 @@ package fuego import ( "context" + "io" "net/http/httptest" "testing" @@ -46,7 +47,8 @@ func TestJSON(t *testing.T) { func TestXML(t *testing.T) { t.Run("can serialize xml", func(t *testing.T) { w := httptest.NewRecorder() - SendXML(w, response{Message: "Hello World", Code: 200}) + err := SendXML(w, nil, response{Message: "Hello World", Code: 200}) + require.NoError(t, err) body := w.Body.String() require.Equal(t, `Hello World200`, body) @@ -55,7 +57,7 @@ func TestXML(t *testing.T) { t.Run("can serialize xml error", func(t *testing.T) { w := httptest.NewRecorder() err := HTTPError{Detail: "Hello World"} - SendXMLError(w, err) + SendXMLError(w, nil, err) body := w.Body.String() require.Equal(t, `Hello World`, body) @@ -235,7 +237,69 @@ func TestJSONError(t *testing.T) { func TestSend(t *testing.T) { w := httptest.NewRecorder() - Send(w, "Hello World") + SendText(w, "Hello World") require.Equal(t, "Hello World", w.Body.String()) } + +type templateMock struct{} + +func (t templateMock) Render(w io.Writer) error { + return nil +} + +var _ Renderer = templateMock{} + +func TestInferAcceptHeaderFromType(t *testing.T) { + t.Run("can infer json", func(t *testing.T) { + accept := InferAcceptHeaderFromType(response{}) + require.Equal(t, "application/json", accept) + }) + + t.Run("can infer that type is a template (implements Renderer)", func(t *testing.T) { + accept := InferAcceptHeaderFromType(templateMock{}) + require.Equal(t, "text/html", accept) + }) + + t.Run("can infer that type is a template (implements CtxRenderer)", func(t *testing.T) { + accept := InferAcceptHeaderFromType(MockCtxRenderer{}) + require.Equal(t, "text/html", accept) + }) +} + +func TestParseAcceptHeader(t *testing.T) { + t.Run("can parse text/plain", func(t *testing.T) { + accept := parseAcceptHeader("text/plain", "Hello World") + require.Equal(t, "text/plain", accept) + }) + + t.Run("can parse text/html", func(t *testing.T) { + accept := parseAcceptHeader("text/html", "

Hello World

") + require.Equal(t, "text/html", accept) + }) + + t.Run("can parse text/html from multiple options", func(t *testing.T) { + accept := parseAcceptHeader("text/html, text/plain", "

Hello World

") + require.Equal(t, "text/html", accept) + }) + + t.Run("can parse application/json", func(t *testing.T) { + accept := parseAcceptHeader("application/json", ans{}) + require.Equal(t, "application/json", accept) + }) + + t.Run("can infer json", func(t *testing.T) { + accept := parseAcceptHeader("", response{}) + require.Equal(t, "application/json", accept) + }) + + t.Run("can infer json", func(t *testing.T) { + accept := parseAcceptHeader("*/*", response{}) + require.Equal(t, "application/json", accept) + }) + + t.Run("can infer text/html from a real browser", func(t *testing.T) { + accept := parseAcceptHeader("text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", "

Hello World

") + require.Equal(t, "text/html", accept) + }) +} diff --git a/serve.go b/serve.go index 219ff12c..b938e8db 100644 --- a/serve.go +++ b/serve.go @@ -76,13 +76,7 @@ func initContext[Contextable ctx[Body], Body any](baseContext ContextNoBody) Con // HTTPHandler converts a Fuego controller into a http.HandlerFunc. func HTTPHandler[ReturnType, Body any, Contextable ctx[Body]](s *Server, controller func(c Contextable) (ReturnType, error)) http.HandlerFunc { - returnsHTML := reflect.TypeOf(controller).Out(0).Name() == "HTML" - var r ReturnType - _, returnsString := any(r).(*string) - if !returnsString { - _, returnsString = any(r).(string) - } - + // Just a check, not used at request time baseContext := *new(Contextable) if reflect.TypeOf(baseContext) == nil { slog.Info(fmt.Sprintf("context is nil: %v %T", baseContext, baseContext)) @@ -91,10 +85,10 @@ func HTTPHandler[ReturnType, Body any, Contextable ctx[Body]](s *Server, control return func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Powered-By", "Fuego") - w.Header().Set("Trailer", "Server-Timing") - timeCtxInit := time.Now() + // CONTEXT INITIALIZATION + timeCtxInit := time.Now() var templates *template.Template if s.template != nil { templates = template.Must(s.template.Clone()) @@ -113,10 +107,11 @@ func HTTPHandler[ReturnType, Body any, Contextable ctx[Body]](s *Server, control timeController := time.Now() w.Header().Set("Server-Timing", Timing{"fuegoReqInit", timeController.Sub(timeCtxInit), ""}.String()) + // CONTROLLER ans, err := controller(ctx) if err != nil { err = s.ErrorHandler(err) - s.SerializeError(w, err) + s.SerializeError(w, r, err) return } timeAfterController := time.Now() @@ -126,66 +121,24 @@ func HTTPHandler[ReturnType, Body any, Contextable ctx[Body]](s *Server, control return } - ctxRenderer, ok := any(ans).(CtxRenderer) - if ok { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - err = ctxRenderer.Render(r.Context(), w) - if err != nil { - err = s.ErrorHandler(err) - s.SerializeError(w, err) - } - w.Header().Set("Server-Timing", Timing{"render", time.Since(timeAfterController), ""}.String()) - return - } - - renderer, ok := any(ans).(Renderer) - if ok { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - err = renderer.Render(w) - if err != nil { - err = s.ErrorHandler(err) - s.SerializeError(w, err) - } - w.Header().Add("Server-Timing", Timing{"render", time.Since(timeAfterController), ""}.String()) - return - } - - if returnsHTML { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - _, err = w.Write([]byte(any(ans).(HTML))) - if err != nil { - s.SerializeError(w, err) - } - w.Header().Add("Server-Timing", Timing{"render", time.Since(timeAfterController), ""}.String()) - return - } - + // TRANSFORM OUT timeTransformOut := time.Now() ans, err = transformOut(r.Context(), ans) if err != nil { err = s.ErrorHandler(err) - s.SerializeError(w, err) - return - } - - if returnsString { - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - stringToWrite, ok := any(ans).(string) - if !ok { - stringToWrite = *any(ans).(*string) - } - _, err = w.Write([]byte(stringToWrite)) - if err != nil { - s.SerializeError(w, err) - } - w.Header().Add("Server-Timing", Timing{"write", time.Since(timeTransformOut), "transformOut"}.String()) + s.SerializeError(w, r, err) return } - timeAfterTransformOut := time.Now() w.Header().Add("Server-Timing", Timing{"transformOut", timeAfterTransformOut.Sub(timeTransformOut), "transformOut"}.String()) - s.Serialize(w, ans) + // SERIALIZATION + err = s.Serialize(w, r, ans) + // FINAL ERROR HANDLING + if err != nil { + err = s.ErrorHandler(err) + s.SerializeError(w, r, err) + } w.Header().Add("Server-Timing", Timing{"serialize", time.Since(timeAfterTransformOut), ""}.String()) } } diff --git a/serve_test.go b/serve_test.go index 2e147d12..6fb97b45 100644 --- a/serve_test.go +++ b/serve_test.go @@ -165,6 +165,7 @@ func TestHttpHandler(t *testing.T) { handler := HTTPHandler(s, testControllerReturningPtrToString) req := httptest.NewRequest("GET", "/testing", nil) + req.Header.Set("Accept", "text/plain") w := httptest.NewRecorder() handler(w, req) @@ -234,7 +235,7 @@ func (t testCtxErrorRenderer) Render(ctx context.Context, w io.Writer) error { func TestServeRenderer(t *testing.T) { s := NewServer( - WithErrorSerializer(func(w http.ResponseWriter, err error) { + WithErrorSerializer(func(w http.ResponseWriter, r *http.Request, err error) { w.WriteHeader(500) w.Write([]byte("

error

")) }), @@ -319,6 +320,24 @@ func TestServeRenderer(t *testing.T) { }) } +func TestServeError(t *testing.T) { + s := NewServer() + + Get(s, "/ctx/error-in-controller", func(c *ContextNoBody) (CtxRenderer, error) { + return nil, errors.New("error") + }) + + t.Run("error return, asking for HTML", func(t *testing.T) { + req := httptest.NewRequest("GET", "/ctx/error-in-controller", nil) + req.Header.Set("Accept", "text/html") + w := httptest.NewRecorder() + s.Mux.ServeHTTP(w, req) + + require.Equal(t, 500, w.Code) + require.Equal(t, "Internal Server Error (500): ", w.Body.String()) + }) +} + func TestIni(t *testing.T) { t.Run("can initialize ContextNoBody", func(t *testing.T) { req := httptest.NewRequest("GET", "/ctx/error-in-rendering", nil) @@ -483,7 +502,13 @@ func TestServer_RunTLS(t *testing.T) { defer conn.Close() client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}} - resp, err := client.Get(fmt.Sprintf("https://%s/test", s.Server.Addr)) + req, err := http.NewRequest("GET", fmt.Sprintf("https://%s/test", s.Server.Addr), nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "text/plain") + + resp, err := client.Do(req) if err != nil { t.Fatal(err) }