From d8a593a2c1c2552e47501e4e46532694af020ef0 Mon Sep 17 00:00:00 2001 From: Madeorsk Date: Sat, 21 Dec 2024 13:44:59 +0100 Subject: [PATCH] Allow to set a custom error handler in router configuration. + Add an option for a custom error handler and set a complete default one. * Change handler context to allow its initialization without a matched route. * Alphabetical order of HTTPError errors. --- docs/getting_started.md | 9 +++ examples/basic/main.zig | 9 +++ src/http/context.zig | 4 +- src/http/lib.zig | 8 +- src/http/router.zig | 94 ++++++++++++++++++++++- src/http/server.zig | 166 +++++++++++++--------------------------- 6 files changed, 169 insertions(+), 121 deletions(-) diff --git a/docs/getting_started.md b/docs/getting_started.md index 6104881..00e7263 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -97,6 +97,15 @@ pub fn main() !void { }); } }.handler_fn, + .error_handler = struct { + fn handler_fn(ctx: *Context, _: anyerror) !void { + try ctx.respond(.{ + .status = .@"Internal Server Error", + .mime = http.Mime.HTML, + .body = "Oh no, Internal Server Error!", + }); + } + }.handler_fn, }); // This provides the entry function into the Tardy runtime. This will run diff --git a/examples/basic/main.zig b/examples/basic/main.zig index 1e53a66..c07074a 100644 --- a/examples/basic/main.zig +++ b/examples/basic/main.zig @@ -78,6 +78,15 @@ pub fn main() !void { }); } }.handler_fn, + .error_handler = struct { + fn handler_fn(ctx: *Context, _: anyerror) !void { + try ctx.respond(.{ + .status = .@"Internal Server Error", + .mime = http.Mime.HTML, + .body = "Oh no, Internal Server Error!", + }); + } + }.handler_fn, }); // This provides the entry function into the Tardy runtime. This will run diff --git a/src/http/context.zig b/src/http/context.zig index e3d81ed..d208658 100644 --- a/src/http/context.zig +++ b/src/http/context.zig @@ -18,8 +18,6 @@ const _SSE = @import("sse.zig").SSE; const Runtime = @import("tardy").Runtime; const TaskFn = @import("tardy").TaskFn; -const raw_respond = @import("server.zig").raw_respond; - // Context is dependent on the server that gets created. pub fn Context(comptime Server: type, comptime AppState: type) type { return struct { @@ -30,7 +28,7 @@ pub fn Context(comptime Server: type, comptime AppState: type) type { /// Custom user-data state. state: AppState, /// The matched route instance. - route: *const Route(Server, AppState), + route: ?*const Route(Server, AppState), /// The Request that triggered this handler. request: *const Request, /// The Response that will be returned. diff --git a/src/http/lib.zig b/src/http/lib.zig index 1e9e40b..5d7e3d3 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -9,10 +9,12 @@ pub const Headers = @import("../core/case_string_map.zig").CaseStringMap([]const pub const Server = @import("server.zig").Server; pub const HTTPError = error{ - TooManyHeaders, ContentTooLarge, - MalformedRequest, + HTTPVersionNotSupported, InvalidMethod, + LengthRequired, + MalformedRequest, + MethodNotAllowed, + TooManyHeaders, URITooLong, - HTTPVersionNotSupported, }; diff --git a/src/http/router.zig b/src/http/router.zig index c2e0a3a..90b9046 100644 --- a/src/http/router.zig +++ b/src/http/router.zig @@ -2,6 +2,8 @@ const std = @import("std"); const log = std.log.scoped(.@"zzz/http/router"); const assert = std.debug.assert; +const HTTPError = @import("lib.zig").HTTPError; + const _Route = @import("router/route.zig").Route; const Capture = @import("router/routing_trie.zig").Capture; @@ -13,7 +15,13 @@ const _Context = @import("context.zig").Context; const _RoutingTrie = @import("router/routing_trie.zig").RoutingTrie; const QueryMap = @import("router/routing_trie.zig").QueryMap; -/// Default not found handler: send a plain text response. +/// Error handler type. +pub fn ErrorHandlerFn(comptime Server: type, comptime AppState: type) type { + const Context = _Context(Server, AppState); + return *const fn (context: *Context, err: anyerror) anyerror!void; +} + +/// Create a default not found handler: send a plain text response. pub fn default_not_found_handler(comptime Server: type, comptime AppState: type) _Route(Server, AppState).HandlerFn { const Context = _Context(Server, AppState); @@ -28,6 +36,87 @@ pub fn default_not_found_handler(comptime Server: type, comptime AppState: type) }.not_found_handler; } +/// Create a default error handler: send a plain text response with the error, if known, internal server error otherwise. +pub fn default_error_handler(comptime Server: type, comptime AppState: type) ErrorHandlerFn(Server, AppState) { + const Context = _Context(Server, AppState); + return struct { fn f(ctx: *Context, err: anyerror) !void { + // Handle all default HTTP errors. + switch (err) { + HTTPError.ContentTooLarge => { + try ctx.respond(.{ + .status = .@"Content Too Large", + .mime = Mime.TEXT, + .body = "Request was too large.", + }); + }, + HTTPError.HTTPVersionNotSupported => { + try ctx.respond(.{ + .status = .@"HTTP Version Not Supported", + .mime = Mime.HTML, + .body = "HTTP version not supported.", + }); + }, + HTTPError.InvalidMethod => { + try ctx.respond(.{ + .status = .@"Not Implemented", + .mime = Mime.TEXT, + .body = "Not implemented.", + }); + }, + HTTPError.LengthRequired => { + try ctx.respond(.{ + .status = .@"Length Required", + .mime = Mime.TEXT, + .body = "Length required.", + }); + }, + HTTPError.MalformedRequest => { + try ctx.respond(.{ + .status = .@"Bad Request", + .mime = Mime.TEXT, + .body = "Malformed request.", + }); + }, + HTTPError.MethodNotAllowed => { + if (ctx.route) |route| { + add_allow_header: { + // We also need to add to Allow header. + // This uses the connection's arena to allocate 64 bytes. + const allowed = route.get_allowed(ctx.provision.arena.allocator()) catch break :add_allow_header; + ctx.provision.response.headers.put_assume_capacity("Allow", allowed); + } + } + try ctx.respond(.{ + .status = .@"Method Not Allowed", + .mime = Mime.TEXT, + .body = "Method not allowed.", + }); + }, + HTTPError.TooManyHeaders => { + try ctx.respond(.{ + .status = .@"Request Header Fields Too Large", + .mime = Mime.TEXT, + .body = "Too many headers.", + }); + }, + HTTPError.URITooLong => { + try ctx.respond(.{ + .status = .@"URI Too Long", + .mime = Mime.TEXT, + .body = "URI too long.", + }); + }, + else => { + try ctx.respond(.{ + .status = .@"Internal Server Error", + .mime = Mime.TEXT, + .body = "Internal server error.", + }); + }, + } + } }.f; +} + /// Initialize a router with the given routes. pub fn Router(comptime Server: type, comptime AppState: type) type { return struct { @@ -40,10 +129,12 @@ pub fn Router(comptime Server: type, comptime AppState: type) type { /// Router configuration structure. pub const Configuration = struct { not_found_handler: Route.HandlerFn = default_not_found_handler(Server, AppState), + error_handler: ErrorHandlerFn(Server, AppState) = default_error_handler(Server, AppState), }; routes: RoutingTrie, not_found_route: Route, + error_handler: ErrorHandlerFn(Server, AppState), state: AppState, pub fn init(state: AppState, comptime _routes: []const Route, comptime configuration: Configuration) Self { @@ -51,6 +142,7 @@ pub fn Router(comptime Server: type, comptime AppState: type) type { // Initialize the routing tree from the given routes. .routes = comptime RoutingTrie.init(_routes), .not_found_route = comptime Route.init("").all(configuration.not_found_handler), + .error_handler = configuration.error_handler, .state = state, }; diff --git a/src/http/server.zig b/src/http/server.zig index 4f026c2..b607662 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -827,61 +827,52 @@ pub fn Server(comptime security: Security, comptime AppState: type) type { return socket; } + /// Try to pass the given error to the router error handler. + fn handle_error(router: *const Router, provision: *Provision, context: *Context, err: anyerror) RecvStatus { + // Call router error handler. + router.error_handler(context, err) catch { + provision.response.set(.{ + .status = .@"Internal Server Error", + .mime = Mime.TEXT, + .body = "Internal server error.", + }); + return raw_respond(provision) catch unreachable; + }; + return .spawned; + } + fn route_and_respond(runtime: *Runtime, p: *Provision, router: *const Router) !RecvStatus { - route: { + { const found = try router.get_route_from_host(p.request.uri.?, p.captures, &p.queries); const optional_handler = found.route.get_handler(p.request.method.?); - if (optional_handler) |handler| { - const context: *Context = try p.arena.allocator().create(Context); - context.* = .{ - .allocator = p.arena.allocator(), - .runtime = runtime, - .state = router.state, - .route = &found.route, - .request = &p.request, - .response = &p.response, - .captures = found.captures, - .queries = found.queries, - .provision = p, - }; + const context: *Context = try p.arena.allocator().create(Context); + context.* = .{ + .allocator = p.arena.allocator(), + .runtime = runtime, + .state = router.state, + .route = &found.route, + .request = &p.request, + .response = &p.response, + .captures = found.captures, + .queries = found.queries, + .provision = p, + }; + if (optional_handler) |handler| { @call(.auto, handler, .{ context, }) catch |e| { log.err("\"{s}\" handler failed with error: {}", .{ p.request.uri.?, e }); - p.response.set(.{ - .status = .@"Internal Server Error", - .mime = Mime.HTML, - .body = "", - }); - - return try raw_respond(p); + // Call router error handler. + return handle_error(router, p, context, e); }; return .spawned; } else { // If we match the route but not the method. - p.response.set(.{ - .status = .@"Method Not Allowed", - .mime = Mime.HTML, - .body = "405 Method Not Allowed", - }); - - // We also need to add to Allow header. - // This uses the connection's arena to allocate 64 bytes. - const allowed = found.route.get_allowed(p.arena.allocator()) catch { - p.response.set(.{ - .status = .@"Internal Server Error", - .mime = Mime.HTML, - .body = "", - }); - - break :route; - }; - - p.response.headers.put_assume_capacity("Allow", allowed); - break :route; + // Call router error handler with Method Not Allowed error. + return handle_error(router, p, context, HTTPError.MethodNotAllowed); } } @@ -903,14 +894,22 @@ pub fn Server(comptime security: Security, comptime AppState: type) type { var stage = provision.stage; const job = provision.job.recv; - if (job.count >= config.request_bytes_max) { - provision.response.set(.{ - .status = .@"Content Too Large", - .mime = Mime.HTML, - .body = "Request was too large", - }); + // Initialize a context for the error handler. + const context: *Context = try provision.arena.allocator().create(Context); + context.* = .{ + .allocator = provision.arena.allocator(), + .runtime = rt, + .state = router.state, + .route = null, + .request = &provision.request, + .response = &provision.response, + .captures = provision.captures, + .queries = &provision.queries, + .provision = provision, + }; - return try raw_respond(provision); + if (job.count >= config.request_bytes_max) { + return handle_error(router, provision, context, HTTPError.ContentTooLarge); } switch (stage) { @@ -943,52 +942,8 @@ pub fn Server(comptime security: Security, comptime AppState: type) type { .request_uri_bytes_max = config.request_uri_bytes_max, }, ) catch |e| { - switch (e) { - HTTPError.ContentTooLarge => { - provision.response.set(.{ - .status = .@"Content Too Large", - .mime = Mime.HTML, - .body = "Request was too large", - }); - }, - HTTPError.TooManyHeaders => { - provision.response.set(.{ - .status = .@"Request Header Fields Too Large", - .mime = Mime.HTML, - .body = "Too Many Headers", - }); - }, - HTTPError.MalformedRequest => { - provision.response.set(.{ - .status = .@"Bad Request", - .mime = Mime.HTML, - .body = "Malformed Request", - }); - }, - HTTPError.URITooLong => { - provision.response.set(.{ - .status = .@"URI Too Long", - .mime = Mime.HTML, - .body = "URI Too Long", - }); - }, - HTTPError.InvalidMethod => { - provision.response.set(.{ - .status = .@"Not Implemented", - .mime = Mime.HTML, - .body = "Not Implemented", - }); - }, - HTTPError.HTTPVersionNotSupported => { - provision.response.set(.{ - .status = .@"HTTP Version Not Supported", - .mime = Mime.HTML, - .body = "HTTP Version Not Supported", - }); - }, - } - - return raw_respond(provision) catch unreachable; + // Call router error handler. + return handle_error(router, provision, context, e); }; // Logging information about Request. @@ -1003,13 +958,7 @@ pub fn Server(comptime security: Security, comptime AppState: type) type { const is_http_1_1 = provision.request.version == .@"HTTP/1.1"; const is_host_present = provision.request.headers.get("Host") != null; if (is_http_1_1 and !is_host_present) { - provision.response.set(.{ - .status = .@"Bad Request", - .mime = Mime.HTML, - .body = "Missing \"Host\" Header", - }); - - return try raw_respond(provision); + return handle_error(router, provision, context, HTTPError.MalformedRequest); } if (!provision.request.expect_body()) { @@ -1067,13 +1016,7 @@ pub fn Server(comptime security: Security, comptime AppState: type) type { const content_length = blk: { const length_string = provision.request.headers.get("Content-Length") orelse { - provision.response.set(.{ - .status = .@"Length Required", - .mime = Mime.HTML, - .body = "", - }); - - return try raw_respond(provision); + return handle_error(router, provision, context, HTTPError.LengthRequired); }; break :blk try std.fmt.parseInt(u32, length_string, 10); @@ -1084,12 +1027,7 @@ pub fn Server(comptime security: Security, comptime AppState: type) type { // If this body will be too long, abort early. if (request_length > config.request_bytes_max) { - provision.response.set(.{ - .status = .@"Content Too Large", - .mime = Mime.HTML, - .body = "", - }); - return try raw_respond(provision); + return handle_error(router, provision, context, HTTPError.ContentTooLarge); } if (job.count >= request_length) {