From de264ba0575f62d04a5ada4cde0fd40990136bc6 Mon Sep 17 00:00:00 2001 From: EwenQuim Date: Wed, 3 Jul 2024 22:39:35 +0200 Subject: [PATCH] Multi-return --- ctx.go | 57 ++------ examples/full-app-gourmet/views/partials.go | 4 +- examples/full-app-gourmet/views/recipe.go | 27 ++-- html.go | 64 ++++++++- html_test.go | 10 +- multi_return.go | 61 +++++++++ multi_return_test.go | 106 +++++++++++++++ options.go | 16 ++- options_test.go | 29 +++- serialization.go | 143 +++++++++++++++++++- serialization_test.go | 27 ++++ serve.go | 107 +++++++-------- serve_test.go | 28 +++- 13 files changed, 537 insertions(+), 142 deletions(-) create mode 100644 multi_return.go create mode 100644 multi_return_test.go diff --git a/ctx.go b/ctx.go index a0826086..9c008152 100644 --- a/ctx.go +++ b/ctx.go @@ -64,7 +64,7 @@ type ctx[B any] interface { // and you want to override one above the other, you can do: // c.Render("admin.page.html", recipes, "partials/aaa/nav.partial.html") // By default, [templateToExecute] is added to the list of templates to override. - Render(templateToExecute string, data any, templateGlobsToOverride ...string) (HTML, error) + Render(templateToExecute string, data any, templateGlobsToOverride ...string) (CtxRenderer, error) Cookie(name string) (*http.Cookie, error) // Get request cookie SetCookie(cookie http.Cookie) // Sets response cookie @@ -222,51 +222,14 @@ func (c ContextNoBody) SetCookie(cookie http.Cookie) { // that the templates will be parsed only once, removing // the need to parse the templates on each request but also preventing // to dynamically use new templates. -func (c ContextNoBody) Render(templateToExecute string, data any, layoutsGlobs ...string) (HTML, error) { - if strings.Contains(templateToExecute, "/") || strings.Contains(templateToExecute, "*") { - - layoutsGlobs = append(layoutsGlobs, templateToExecute) // To override all blocks defined in the main template - cloned := template.Must(c.templates.Clone()) - tmpl, err := cloned.ParseFS(c.fs, layoutsGlobs...) - if err != nil { - return "", HTTPError{ - Err: err, - Status: http.StatusInternalServerError, - Title: "Error parsing template", - Detail: fmt.Errorf("error parsing template '%s': %w", layoutsGlobs, err).Error(), - Errors: []ErrorItem{ - { - Name: "templates", - Reason: "Check that the template exists and have the correct extension. Globs: " + strings.Join(layoutsGlobs, ", "), - }, - }, - } - } - c.templates = template.Must(tmpl.Clone()) - } - - // Get only last template name (for example, with partials/nav/main/nav.partial.html, get nav.partial.html) - myTemplate := strings.Split(templateToExecute, "/") - templateToExecute = myTemplate[len(myTemplate)-1] - - c.Res.Header().Set("Content-Type", "text/html; charset=utf-8") - err := c.templates.ExecuteTemplate(c.Res, templateToExecute, data) - if err != nil { - return "", HTTPError{ - Err: err, - Status: http.StatusInternalServerError, - Title: "Error rendering template", - Detail: fmt.Errorf("error executing template '%s': %w", templateToExecute, err).Error(), - Errors: []ErrorItem{ - { - Name: "templates", - Reason: "Check that the template exists and have the correct extension. Template: " + templateToExecute, - }, - }, - } - } - - return "", err +func (c ContextNoBody) Render(templateToExecute string, data any, layoutsGlobs ...string) (CtxRenderer, error) { + return &StdRenderer{ + templateToExecute: templateToExecute, + templates: c.templates, + layoutsGlobs: layoutsGlobs, + fs: c.fs, + data: data, + }, nil } // PathParams returns the path parameters of the request. @@ -427,7 +390,7 @@ func body[B any](c ContextNoBody) (B, error) { body, err = readURLEncoded[B](c.Req, c.readOptions) case "application/xml": body, err = readXML[B](c.Req.Context(), c.Req.Body, c.readOptions) - case "application/x-yaml": + case "application/x-yaml", "text/yaml; charset=utf-8", "application/yaml": // https://www.rfc-editor.org/rfc/rfc9512.html body, err = readYAML[B](c.Req.Context(), c.Req.Body, c.readOptions) case "application/octet-stream": // Read c.Req Body to bytes diff --git a/examples/full-app-gourmet/views/partials.go b/examples/full-app-gourmet/views/partials.go index 5a92efa4..0ddd2dca 100644 --- a/examples/full-app-gourmet/views/partials.go +++ b/examples/full-app-gourmet/views/partials.go @@ -5,12 +5,12 @@ import ( "github.com/go-fuego/fuego/examples/full-app-gourmet/store/types" ) -func (rs Ressource) unitPreselected(c fuego.ContextNoBody) (fuego.HTML, error) { +func (rs Ressource) unitPreselected(c fuego.ContextNoBody) (fuego.CtxRenderer, error) { id := c.QueryParam("IngredientID") ingredient, err := rs.IngredientsQueries.GetIngredient(c.Context(), id) if err != nil { - return "", err + return nil, err } return c.Render("preselected-unit.partial.html", fuego.H{ diff --git a/examples/full-app-gourmet/views/recipe.go b/examples/full-app-gourmet/views/recipe.go index f2a93666..97d29665 100644 --- a/examples/full-app-gourmet/views/recipe.go +++ b/examples/full-app-gourmet/views/recipe.go @@ -104,15 +104,18 @@ func (rs Ressource) showIndex(c fuego.ContextNoBody) (fuego.Templ, error) { }), nil } -func (rs Ressource) showRecipes(c fuego.ContextNoBody) (fuego.Templ, error) { +func (rs Ressource) showRecipes(c fuego.ContextNoBody) (*fuego.DataOrTemplate[[]store.Recipe], error) { recipes, err := rs.RecipesQueries.GetRecipes(c.Context()) if err != nil { return nil, err } - return templa.SearchPage(templa.SearchProps{ - Recipes: recipes, - }), nil + return fuego.DataOrHTML( + recipes, + templa.SearchPage(templa.SearchProps{ + Recipes: recipes, + }), + ), nil } func (rs Ressource) relatedRecipes(c fuego.ContextNoBody) (fuego.Templ, error) { @@ -220,7 +223,7 @@ func (rs Ressource) healthyRecipes(c fuego.ContextNoBody) (fuego.Templ, error) { }), nil } -func (rs Ressource) showRecipesList(c fuego.ContextNoBody) (fuego.HTML, error) { +func (rs Ressource) showRecipesList(c fuego.ContextNoBody) (fuego.CtxRenderer, error) { search := c.QueryParam("search") recipes, err := rs.RecipesQueries.SearchRecipes(c.Context(), store.SearchRecipesParams{ Search: sql.NullString{ @@ -229,28 +232,28 @@ func (rs Ressource) showRecipesList(c fuego.ContextNoBody) (fuego.HTML, error) { }, }) if err != nil { - return "", err + return nil, err } return c.Render("partials/recipes-list.partial.html", recipes) } -func (rs Ressource) addRecipe(c *fuego.ContextWithBody[store.CreateRecipeParams]) (fuego.HTML, error) { +func (rs Ressource) addRecipe(c *fuego.ContextWithBody[store.CreateRecipeParams]) (fuego.CtxRenderer, error) { body, err := c.Body() if err != nil { - return "", err + return nil, err } body.ID = uuid.NewString() _, err = rs.RecipesQueries.CreateRecipe(c.Context(), body) if err != nil { - return "", err + return nil, err } recipes, err := rs.RecipesQueries.GetRecipes(c.Context()) if err != nil { - return "", err + return nil, err } return c.Render("pages/admin.page.html", fuego.H{ @@ -258,12 +261,12 @@ func (rs Ressource) addRecipe(c *fuego.ContextWithBody[store.CreateRecipeParams] }) } -func (rs Ressource) RecipePage(c fuego.ContextNoBody) (fuego.HTML, error) { +func (rs Ressource) RecipePage(c fuego.ContextNoBody) (fuego.CtxRenderer, error) { id := c.PathParam("id") recipe, err := rs.RecipesQueries.GetRecipe(c.Context(), id) if err != nil { - return "", fmt.Errorf("error getting recipe %s: %w", id, err) + return nil, fmt.Errorf("error getting recipe %s: %w", id, err) } ingredients, err := rs.IngredientsQueries.GetIngredientsOfRecipe(c.Context(), id) diff --git a/html.go b/html.go index fdffa571..802aeaf2 100644 --- a/html.go +++ b/html.go @@ -5,9 +5,14 @@ import ( "fmt" "html/template" "io" + "io/fs" + "net/http" + "strings" ) -// CtxRenderer can be used with [github.com/a-h/templ] +// CtxRenderer is an interface that can be used to render a response. +// It is used with standard library templating engine, by using fuego.ContextXXX.Render +// It is compatible with [github.com/a-h/templ] out of the box. // Example: // // func getRecipes(ctx fuego.ContextNoBody) (fuego.CtxRenderer, error) { @@ -50,6 +55,63 @@ type HTML string // H is a shortcut for map[string]any type H map[string]any +// StdRenderer renders a template using the standard library templating engine. +type StdRenderer struct { + templateToExecute string + templates *template.Template + layoutsGlobs []string + fs fs.FS + data any +} + +var _ CtxRenderer = StdRenderer{} + +func (s StdRenderer) Render(ctx context.Context, w io.Writer) error { + if strings.Contains(s.templateToExecute, "/") || strings.Contains(s.templateToExecute, "*") { + + s.layoutsGlobs = append(s.layoutsGlobs, s.templateToExecute) // To override all blocks defined in the main template + cloned := template.Must(s.templates.Clone()) + tmpl, err := cloned.ParseFS(s.fs, s.layoutsGlobs...) + if err != nil { + return HTTPError{ + Err: err, + Status: http.StatusInternalServerError, + Title: "Error parsing template", + Detail: fmt.Errorf("error parsing template '%s': %w", s.layoutsGlobs, err).Error(), + Errors: []ErrorItem{ + { + Name: "templates", + Reason: "Check that the template exists and have the correct extension. Globs: " + strings.Join(s.layoutsGlobs, ", "), + }, + }, + } + } + s.templates = template.Must(tmpl.Clone()) + } + + // Get only last template name (for example, with partials/nav/main/nav.partial.html, get nav.partial.html) + myTemplate := strings.Split(s.templateToExecute, "/") + s.templateToExecute = myTemplate[len(myTemplate)-1] + + err := s.templates.ExecuteTemplate(w, s.templateToExecute, s.data) + if err != nil { + return HTTPError{ + Err: err, + Status: http.StatusInternalServerError, + Title: "Error rendering template", + Detail: fmt.Errorf("error executing template '%s': %w", s.templateToExecute, err).Error(), + Errors: []ErrorItem{ + { + Name: "templates", + Reason: "Check that the template exists and have the correct extension. Template: " + s.templateToExecute, + }, + }, + } + } + + return err +} + // loadTemplates func (s *Server) loadTemplates(patterns ...string) error { tmpl, err := template.ParseFS(s.fs, patterns...) diff --git a/html_test.go b/html_test.go index 990c5be8..7e07f37f 100644 --- a/html_test.go +++ b/html_test.go @@ -18,7 +18,7 @@ func TestRender(t *testing.T) { WithTemplateGlobs("testdata/*.html"), ) - Get(s, "/test", func(ctx *ContextNoBody) (HTML, error) { + Get(s, "/test", func(ctx *ContextNoBody) (CtxRenderer, error) { return ctx.Render("testdata/test.html", H{"Name": "test"}) }) @@ -43,7 +43,7 @@ func TestRender(t *testing.T) { }) t.Run("cannot parse unexisting file", func(t *testing.T) { - Get(s, "/file-not-found", func(ctx ContextNoBody) (HTML, error) { + Get(s, "/file-not-found", func(ctx ContextNoBody) (CtxRenderer, error) { return ctx.Render("testdata/not-found.html", H{"Name": "test"}) }) @@ -56,7 +56,7 @@ func TestRender(t *testing.T) { }) t.Run("can execute template with missing variable in map", func(t *testing.T) { - Get(s, "/impossible", func(ctx ContextNoBody) (HTML, error) { + Get(s, "/impossible", func(ctx ContextNoBody) (CtxRenderer, error) { return ctx.Render("testdata/test.html", H{"NotName": "test"}) }) @@ -71,7 +71,7 @@ func TestRender(t *testing.T) { }) t.Run("cannot execute template with missing variable in struct", func(t *testing.T) { - Get(s, "/impossible-struct", func(ctx ContextNoBody) (HTML, error) { + Get(s, "/impossible-struct", func(ctx ContextNoBody) (CtxRenderer, error) { return ctx.Render("testdata/test.html", struct{}{}) }) @@ -93,7 +93,7 @@ func BenchmarkRender(b *testing.B) { WithTemplateGlobs("testdata/*.html"), ) - Get(s, "/test", func(ctx ContextNoBody) (HTML, error) { + Get(s, "/test", func(ctx ContextNoBody) (CtxRenderer, error) { return ctx.Render("testdata/test.html", H{"Name": "test"}) }) diff --git a/multi_return.go b/multi_return.go new file mode 100644 index 00000000..38def6d1 --- /dev/null +++ b/multi_return.go @@ -0,0 +1,61 @@ +package fuego + +import ( + "context" + "encoding/json" + "encoding/xml" + "fmt" + "io" + + "gopkg.in/yaml.v3" +) + +// DataOrTemplate is a struct that can return either data or a template +// depending on the asked type. +type DataOrTemplate[T any] struct { + Data T + Template any +} + +var ( + _ CtxRenderer = DataOrTemplate[any]{} // Can render HTML (template) + _ json.Marshaler = DataOrTemplate[any]{} // Can render JSON (data) + _ xml.Marshaler = DataOrTemplate[any]{} // Can render XML (data) + _ yaml.Marshaler = DataOrTemplate[any]{} // Can render YAML (data) + _ fmt.Stringer = DataOrTemplate[any]{} // Can render string (data) +) + +func (m DataOrTemplate[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(m.Data) +} + +func (m DataOrTemplate[T]) MarshalXML(e *xml.Encoder, _ xml.StartElement) error { + return e.Encode(m.Data) +} + +func (m DataOrTemplate[T]) MarshalYAML() (interface{}, error) { + return m.Data, nil +} + +func (m DataOrTemplate[T]) String() string { + return fmt.Sprintf("%v", m.Data) +} + +func (m DataOrTemplate[T]) Render(c context.Context, w io.Writer) error { + switch m.Template.(type) { + case CtxRenderer: + return m.Template.(CtxRenderer).Render(c, w) + case Renderer: + return m.Template.(Renderer).Render(w) + default: + panic("template must be either CtxRenderer or Renderer") + } +} + +// Helper function to create a DataOrTemplate return item without specifying the type. +func DataOrHTML[T any](data T, template any) *DataOrTemplate[T] { + return &DataOrTemplate[T]{ + Data: data, + Template: template, + } +} diff --git a/multi_return_test.go b/multi_return_test.go new file mode 100644 index 00000000..83240b43 --- /dev/null +++ b/multi_return_test.go @@ -0,0 +1,106 @@ +package fuego + +import ( + "context" + "errors" + "io" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +type MockCtxRenderer struct { + RenderFunc func(context.Context, io.Writer) error +} + +func RenderString(s string) MockCtxRenderer { + return MockCtxRenderer{ + RenderFunc: func(c context.Context, w io.Writer) error { + _, err := w.Write([]byte(s)) + return err + }, + } +} + +var _ CtxRenderer = MockCtxRenderer{} + +func (m MockCtxRenderer) Render(c context.Context, w io.Writer) error { + if m.RenderFunc == nil { + return errors.New("RenderFunc is nil") + } + return m.RenderFunc(c, w) +} + +type MyType struct { + Name string +} + +func TestMultiReturn(t *testing.T) { + s := NewServer() + + Get(s, "/data", func(c ContextNoBody) (DataOrTemplate[MyType], error) { + entity := MyType{Name: "Ewen"} + + return DataOrTemplate[MyType]{ + Data: entity, + Template: RenderString(`
` + entity.Name + `
`), + }, nil + }) + + Get(s, "/other", func(c ContextNoBody) (*DataOrTemplate[MyType], error) { + entity := MyType{Name: "Ewen"} + + return DataOrHTML( + entity, + RenderString(`
`+entity.Name+`
`), + ), nil + }) + + t.Run("requests HTML by default", func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/data", nil) + + s.Mux.ServeHTTP(recorder, req) + + require.Equal(t, 200, recorder.Code) + require.Equal(t, "text/html; charset=utf-8", recorder.Header().Get("Content-Type")) + require.Equal(t, `
Ewen
`, recorder.Body.String()) + }) + + t.Run("requests JSON", func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/data", nil) + req.Header.Set("Accept", "application/json") + + s.Mux.ServeHTTP(recorder, req) + + require.Equal(t, 200, recorder.Code) + require.Equal(t, "application/json", recorder.Header().Get("Content-Type")) + require.Equal(t, crlf(`{"Name":"Ewen"}`), recorder.Body.String()) + }) + + t.Run("requests XML", func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/data", nil) + req.Header.Set("Accept", "application/xml") + + s.Mux.ServeHTTP(recorder, req) + + require.Equal(t, 200, recorder.Code) + require.Equal(t, "application/xml", recorder.Header().Get("Content-Type")) + require.Equal(t, `Ewen`, recorder.Body.String()) + }) + + t.Run("requests HTML", func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/data", nil) + req.Header.Set("Accept", "text/html") + + s.Mux.ServeHTTP(recorder, req) + + require.Equal(t, 200, recorder.Code) + require.Contains(t, recorder.Header().Get("Content-Type"), "text/html") + require.Equal(t, `
Ewen
`, recorder.Body.String()) + }) +} diff --git a/options.go b/options.go index f6d43c18..58d29efc 100644 --- a/options.go +++ b/options.go @@ -70,10 +70,12 @@ type Server struct { DisallowUnknownFields bool // If true, the server will return an error if the request body contains unknown fields. Useful for quick debugging in development. DisableOpenapi bool // If true, the routes within the server will not generate an openapi spec. maxBodySize int64 - Serialize func(w http.ResponseWriter, ans any) // Used to serialize the response. Defaults to [SendJSON]. - SerializeError func(w http.ResponseWriter, err error) // Used to serialize the error response. Defaults to [SendJSONError]. - ErrorHandler func(err error) error // Used to transform any error into a unified error type structure with status code. Defaults to [ErrorHandler] - startTime time.Time + + // Custom serializer that overrides the default one. + Serialize func(w http.ResponseWriter, ans any) + SerializeError func(w http.ResponseWriter, err error) // Used to serialize the error response. Defaults to [SendJSONError]. + ErrorHandler func(err error) error // Used to transform any error into a unified error type structure with status code. Defaults to [ErrorHandler] + startTime time.Time OpenAPIConfig OpenAPIConfig @@ -109,7 +111,6 @@ func NewServer(options ...func(*Server)) *Server { defaultOptions := [...]func(*Server){ WithAddr("localhost:9999"), WithDisallowUnknownFields(true), - WithSerializer(SendJSON), WithErrorSerializer(SendJSONError), WithErrorHandler(ErrorHandler), } @@ -258,6 +259,9 @@ func WithAddr(addr string) func(*Server) { return func(c *Server) { c.Server.Addr = addr } } +// WithXML sets the serializer to XML +// +// Deprecated: fuego supports automatic XML serialization when using the header "Accept: application/xml". func WithXML() func(*Server) { return func(c *Server) { c.Serialize = SendXML @@ -274,6 +278,8 @@ func WithLogHandler(handler slog.Handler) func(*Server) { } } +// WithSerializer sets a custom serializer that overrides the default one. +// Please send a PR if you think the default serializer should be improved, instead of jumping to this option. func WithSerializer(serializer func(w http.ResponseWriter, ans any)) func(*Server) { return func(c *Server) { c.Serialize = serializer } } diff --git a/options_test.go b/options_test.go index 16e1788f..e0d186f6 100644 --- a/options_test.go +++ b/options_test.go @@ -5,6 +5,7 @@ import ( "html/template" "io" "log/slog" + "net/http" "net/http/httptest" "testing" @@ -37,15 +38,14 @@ func TestNewServer(t *testing.T) { } func TestWithXML(t *testing.T) { - s := NewServer( - WithXML(), - ) + s := NewServer() Get(s, "/", controller) Get(s, "/error", controllerWithError) t.Run("response is XML", func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Accept", "application/xml") s.Mux.ServeHTTP(recorder, req) @@ -57,12 +57,13 @@ func TestWithXML(t *testing.T) { t.Run("error response is XML", func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest("GET", "/error", nil) + req.Header.Set("Accept", "application/xml") s.Mux.ServeHTTP(recorder, req) require.Equal(t, 500, recorder.Code) - require.Equal(t, "application/xml", recorder.Header().Get("Content-Type")) require.Equal(t, "Internal Server Error500", recorder.Body.String()) + require.Equal(t, "application/xml", recorder.Header().Get("Content-Type")) }) } @@ -373,3 +374,23 @@ func TestServerTags(t *testing.T) { require.Equal(t, subGroup.tags, []string{"my-server-tag"}) }) } + +func TestCustomSerialization(t *testing.T) { + s := NewServer( + WithSerializer(func(w http.ResponseWriter, a any) { + w.WriteHeader(202) + w.Write([]byte("custom serialization")) + }), + ) + + Get(s, "/", func(c *ContextNoBody) (ans, error) { + return ans{Ans: "Hello World"}, nil + }) + + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + s.Mux.ServeHTTP(w, req) + + require.Equal(t, 202, w.Code) + require.Equal(t, "custom serialization", w.Body.String()) +} diff --git a/serialization.go b/serialization.go index 7fda640d..1880840b 100644 --- a/serialization.go +++ b/serialization.go @@ -5,9 +5,12 @@ import ( "encoding/json" "encoding/xml" "errors" + "fmt" "log/slog" "net/http" "reflect" + + "gopkg.in/yaml.v3" ) // OutTransformer is an interface for entities that can be transformed. @@ -76,8 +79,22 @@ func Send(w http.ResponseWriter, text string) { _, _ = w.Write([]byte(text)) } +// SendYAML sends a YAML response. +// Declared as a variable to be able to override it for clients that need to customize serialization. +var SendYAML = func(w http.ResponseWriter, ans any) { + w.Header().Set("Content-Type", "application/x-yaml") + err := yaml.NewEncoder(w).Encode(ans) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("Cannot serialize YAML", "error", err) + _, _ = w.Write([]byte(`{"error":"Cannot serialize YAML"}`)) + return + } +} + // SendJSON sends a JSON response. -func SendJSON(w http.ResponseWriter, ans any) { +// Declared as a variable to be able to override it for clients that need to customize serialization. +var SendJSON = func(w http.ResponseWriter, ans any) { w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(ans) if err != nil { @@ -88,6 +105,28 @@ func SendJSON(w http.ResponseWriter, ans any) { } } +// SendError sends an error. +// Declared as a variable to be able to override it for clients that need to customize serialization. +var SendError = func(w http.ResponseWriter, r *http.Request, err error) { + accept := parseAcceptHeader(r.Header.Get("Accept"), nil) + if accept == "" { + accept = "application/json" + } + + switch accept { + case "application/xml": + SendXMLError(w, err) + case "text/html": + _ = SendHTMLError(r.Context(), w, err) + case "text/plain": + _ = SendText(w, err) + case "application/json": + SendJSONError(w, err) + default: + SendJSONError(w, err) + } +} + // SendJSONError sends a JSON error response. // If the error implements ErrorWithStatus, the status code will be set. func SendJSONError(w http.ResponseWriter, err error) { @@ -109,7 +148,8 @@ func SendJSONError(w http.ResponseWriter, err error) { } // SendXML sends a XML response. -func SendXML(w http.ResponseWriter, ans any) { +// Declared as a variable to be able to override it for clients that need to customize serialization. +var SendXML = func(w http.ResponseWriter, ans any) { w.Header().Set("Content-Type", "application/xml") err := xml.NewEncoder(w).Encode(ans) if err != nil { @@ -132,3 +172,102 @@ func SendXMLError(w http.ResponseWriter, err error) { w.WriteHeader(status) SendXML(w, err) } + +func SendHTMLError(ctx context.Context, w http.ResponseWriter, err error) error { + status := http.StatusInternalServerError + var errorStatus ErrorWithStatus + if errors.As(err, &errorStatus) { + status = errorStatus.StatusCode() + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(status) + return SendHTML(ctx, w, err.Error()) +} + +// SendHTML sends a HTML response. +// Declared as a variable to be able to override it for clients that need to customize serialization. +var SendHTML = func(ctx context.Context, w http.ResponseWriter, ans any) error { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + + ctxRenderer, ok := any(ans).(CtxRenderer) + if ok { + return ctxRenderer.Render(ctx, w) + } + + renderer, ok := any(ans).(Renderer) + if ok { + return renderer.Render(w) + } + + html, ok := any(ans).(HTML) + if ok { + _, err := w.Write([]byte(html)) + return err + } + + htmlString, ok := any(ans).(string) + if ok { + _, err := w.Write([]byte(htmlString)) + return err + } + + // The type cannot be converted to HTML + return fmt.Errorf("cannot serialize HTML from type %T (not string, fuego.HTML and does not implement fuego.CtxRenderer or fuego.Renderer)", ans) +} + +func SendText(w http.ResponseWriter, ans any) error { + var err error + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + stringToWrite, ok := any(ans).(string) + if !ok { + stringToWritePtr, okPtr := any(ans).(*string) + if okPtr { + stringToWrite = *stringToWritePtr + } else { + stringToWrite = fmt.Sprintf("%v", ans) + } + } + _, err = w.Write([]byte(stringToWrite)) + + return err +} + +func InferAcceptHeaderFromType(ans any) string { + _, ok := any(ans).(string) + if ok { + return "text/plain" + } + + _, ok = any(ans).(*string) + if ok { + return "text/plain" + } + + _, ok = any(ans).(HTML) + if ok { + return "text/html" + } + + _, ok = any(ans).(CtxRenderer) + if ok { + return "text/html" + } + + _, ok = any(&ans).(*CtxRenderer) + if ok { + return "text/html" + } + + _, ok = any(ans).(Renderer) + if ok { + return "text/html" + } + + _, ok = any(&ans).(*Renderer) + if ok { + return "text/html" + } + + return "application/json" +} diff --git a/serialization_test.go b/serialization_test.go index be9f21d7..e894d25c 100644 --- a/serialization_test.go +++ b/serialization_test.go @@ -2,6 +2,7 @@ package fuego import ( "context" + "io" "net/http/httptest" "testing" @@ -239,3 +240,29 @@ func TestSend(t *testing.T) { require.Equal(t, "Hello World", w.Body.String()) } + +type templateMock struct{} + +func (t templateMock) Render(w io.Writer) error { + return nil +} + +var _ Renderer = templateMock{} + +func TestInferAcceptHeaderFromType(t *testing.T) { + t.Run("can infer json", func(t *testing.T) { + accept := InferAcceptHeaderFromType(response{}) + require.Equal(t, "application/json", accept) + }) + + t.Run("can infer that type is a template (implements Renderer)", func(t *testing.T) { + accept := InferAcceptHeaderFromType(templateMock{}) + require.Equal(t, "text/html", accept) + }) + + t.Run("can infer that type is a template (implements CtxRenderer)", func(t *testing.T) { + + accept := InferAcceptHeaderFromType(MockCtxRenderer{}) + require.Equal(t, "text/html", accept) + }) +} diff --git a/serve.go b/serve.go index 219ff12c..275bc2e7 100644 --- a/serve.go +++ b/serve.go @@ -1,11 +1,13 @@ package fuego import ( + "errors" "fmt" "html/template" "log/slog" "net/http" "reflect" + "strings" "time" ) @@ -76,13 +78,7 @@ func initContext[Contextable ctx[Body], Body any](baseContext ContextNoBody) Con // HTTPHandler converts a Fuego controller into a http.HandlerFunc. func HTTPHandler[ReturnType, Body any, Contextable ctx[Body]](s *Server, controller func(c Contextable) (ReturnType, error)) http.HandlerFunc { - returnsHTML := reflect.TypeOf(controller).Out(0).Name() == "HTML" - var r ReturnType - _, returnsString := any(r).(*string) - if !returnsString { - _, returnsString = any(r).(string) - } - + // Just a check, not used at request time baseContext := *new(Contextable) if reflect.TypeOf(baseContext) == nil { slog.Info(fmt.Sprintf("context is nil: %v %T", baseContext, baseContext)) @@ -91,10 +87,10 @@ func HTTPHandler[ReturnType, Body any, Contextable ctx[Body]](s *Server, control return func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Powered-By", "Fuego") - w.Header().Set("Trailer", "Server-Timing") - timeCtxInit := time.Now() + // CONTEXT INITIALIZATION + timeCtxInit := time.Now() var templates *template.Template if s.template != nil { templates = template.Must(s.template.Clone()) @@ -113,10 +109,11 @@ func HTTPHandler[ReturnType, Body any, Contextable ctx[Body]](s *Server, control timeController := time.Now() w.Header().Set("Server-Timing", Timing{"fuegoReqInit", timeController.Sub(timeCtxInit), ""}.String()) + // CONTROLLER ans, err := controller(ctx) if err != nil { err = s.ErrorHandler(err) - s.SerializeError(w, err) + SendError(w, r, err) return } timeAfterController := time.Now() @@ -126,66 +123,58 @@ func HTTPHandler[ReturnType, Body any, Contextable ctx[Body]](s *Server, control return } - ctxRenderer, ok := any(ans).(CtxRenderer) - if ok { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - err = ctxRenderer.Render(r.Context(), w) - if err != nil { - err = s.ErrorHandler(err) - s.SerializeError(w, err) - } - w.Header().Set("Server-Timing", Timing{"render", time.Since(timeAfterController), ""}.String()) + // TRANSFORM OUT + timeTransformOut := time.Now() + ans, err = transformOut(r.Context(), ans) + if err != nil { + err = s.ErrorHandler(err) + SendError(w, r, err) return } + timeAfterTransformOut := time.Now() + w.Header().Add("Server-Timing", Timing{"transformOut", timeAfterTransformOut.Sub(timeTransformOut), "transformOut"}.String()) - renderer, ok := any(ans).(Renderer) - if ok { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - err = renderer.Render(w) - if err != nil { - err = s.ErrorHandler(err) - s.SerializeError(w, err) - } - w.Header().Add("Server-Timing", Timing{"render", time.Since(timeAfterController), ""}.String()) + // SERIALIZATION + // Custom serialization + if s.Serialize != nil { + s.Serialize(w, ans) return } - if returnsHTML { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - _, err = w.Write([]byte(any(ans).(HTML))) - if err != nil { - s.SerializeError(w, err) - } - w.Header().Add("Server-Timing", Timing{"render", time.Since(timeAfterController), ""}.String()) - return + // Default serialization + switch parseAcceptHeader(r.Header.Get("Accept"), ans) { + case "application/xml": + SendXML(w, ans) + case "text/html": + err = SendHTML(r.Context(), w, ans) + case "text/plain": + err = SendText(w, ans) + case "application/json": + SendJSON(w, ans) + case "application/yaml": + SendYAML(w, ans) + default: + SendError(w, r, errors.New("unsupported Accept header")) } + w.Header().Add("Server-Timing", Timing{"serialize", time.Since(timeAfterTransformOut), ""}.String()) - timeTransformOut := time.Now() - ans, err = transformOut(r.Context(), ans) + // FINAL ERROR HANDLING if err != nil { err = s.ErrorHandler(err) - s.SerializeError(w, err) - return + SendError(w, r, err) } + } +} - if returnsString { - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - stringToWrite, ok := any(ans).(string) - if !ok { - stringToWrite = *any(ans).(*string) - } - _, err = w.Write([]byte(stringToWrite)) - if err != nil { - s.SerializeError(w, err) - } - w.Header().Add("Server-Timing", Timing{"write", time.Since(timeTransformOut), "transformOut"}.String()) - return - } - - timeAfterTransformOut := time.Now() - w.Header().Add("Server-Timing", Timing{"transformOut", timeAfterTransformOut.Sub(timeTransformOut), "transformOut"}.String()) - - s.Serialize(w, ans) - w.Header().Add("Server-Timing", Timing{"serialize", time.Since(timeAfterTransformOut), ""}.String()) +func parseAcceptHeader(accept string, ans any) string { + if strings.Index(accept, ",") > 0 { + accept = accept[:strings.Index(accept, ",")] + } + if accept == "*/*" { + accept = "" + } + if accept == "" { + accept = InferAcceptHeaderFromType(ans) } + return accept } diff --git a/serve_test.go b/serve_test.go index 2e147d12..43e7fd1a 100644 --- a/serve_test.go +++ b/serve_test.go @@ -165,6 +165,7 @@ func TestHttpHandler(t *testing.T) { handler := HTTPHandler(s, testControllerReturningPtrToString) req := httptest.NewRequest("GET", "/testing", nil) + req.Header.Set("Accept", "text/plain") w := httptest.NewRecorder() handler(w, req) @@ -266,7 +267,7 @@ func TestServeRenderer(t *testing.T) { s.Mux.ServeHTTP(w, req) require.Equal(t, 500, w.Code) - require.Equal(t, "

error

", w.Body.String()) + require.Equal(t, crlf(`{"title":"Internal Server Error","status":500}`), w.Body.String()) }) t.Run("error in rendering", func(t *testing.T) { @@ -275,7 +276,7 @@ func TestServeRenderer(t *testing.T) { s.Mux.ServeHTTP(w, req) require.Equal(t, 500, w.Code) - require.Equal(t, "

error

", w.Body.String()) + require.Equal(t, crlf(`{"title":"Internal Server Error","status":500}`), w.Body.String()) }) }) @@ -305,7 +306,17 @@ func TestServeRenderer(t *testing.T) { s.Mux.ServeHTTP(w, req) require.Equal(t, 500, w.Code) - require.Equal(t, "

error

", w.Body.String()) + require.Equal(t, crlf(`{"title":"Internal Server Error","status":500}`), w.Body.String()) + }) + + t.Run("error return, asking for HTML", func(t *testing.T) { + req := httptest.NewRequest("GET", "/ctx/error-in-controller", nil) + req.Header.Set("Accept", "text/html") + w := httptest.NewRecorder() + s.Mux.ServeHTTP(w, req) + + require.Equal(t, 500, w.Code) + require.Equal(t, "Internal Server Error (500): ", w.Body.String()) }) t.Run("error in rendering", func(t *testing.T) { @@ -314,7 +325,7 @@ func TestServeRenderer(t *testing.T) { s.Mux.ServeHTTP(w, req) require.Equal(t, 500, w.Code) - require.Equal(t, "

error

", w.Body.String()) + require.Equal(t, crlf(`{"title":"Internal Server Error","status":500}`), w.Body.String()) }) }) } @@ -483,7 +494,13 @@ func TestServer_RunTLS(t *testing.T) { defer conn.Close() client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}} - resp, err := client.Get(fmt.Sprintf("https://%s/test", s.Server.Addr)) + req, err := http.NewRequest("GET", fmt.Sprintf("https://%s/test", s.Server.Addr), nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "text/plain") + + resp, err := client.Do(req) if err != nil { t.Fatal(err) } @@ -563,3 +580,4 @@ func newTLSTestHelper() (*tlsTestHelper, error) { keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: privateKeyBytes}) return &tlsTestHelper{cert: certPEM, key: keyPEM}, nil } +