From 35df0ff1c6ca06002e85fc46d96700bdf5bb61f5 Mon Sep 17 00:00:00 2001 From: Muki Kiboigo Date: Sun, 22 Dec 2024 21:39:40 -0800 Subject: [PATCH] feat(router): Major Rewrite + Middleware This commit changes a lot regarding Routing. We have returned to a runtime Router as it makes things much easier in terms of maintaining it. This also adds layers allowing for Middleware to be applied subjectively. --- build.zig | 1 + build.zig.zon | 4 +- docs/getting_started.md | 105 +- docs/https.md | 89 +- examples/basic/main.zig | 95 +- examples/benchmark/main.zig | 20 +- examples/fs/main.zig | 29 +- examples/middleware/main.zig | 92 ++ examples/minram/main.zig | 20 +- examples/multithread/main.zig | 32 +- examples/sse/main.zig | 40 +- examples/tls/main.zig | 88 +- examples/unix/main.zig | 18 +- examples/valgrind/main.zig | 38 +- flake.nix | 31 +- src/core/case_string_map.zig | 2 +- src/core/zc_buffer.zig | 4 +- src/http/context.zig | 305 +++--- src/http/lib.zig | 12 + src/http/response.zig | 1 + src/http/router.zig | 101 +- src/http/router/bundle.zig | 8 + src/http/router/fs_dir.zig | 468 +++++---- src/http/router/layer.zig | 18 + src/http/router/middleware.zig | 120 +++ src/http/router/route.zig | 388 ++++--- src/http/router/routing_trie.zig | 479 +++++---- src/http/router/token_hash_map.zig | 203 ---- src/http/server.zig | 1579 ++++++++++++++-------------- src/http/sse.zig | 85 +- src/tls/bear.zig | 4 +- src/unit_test.zig | 1 - 32 files changed, 2304 insertions(+), 2176 deletions(-) create mode 100644 examples/middleware/main.zig create mode 100644 src/http/router/bundle.zig create mode 100644 src/http/router/layer.zig create mode 100644 src/http/router/middleware.zig delete mode 100644 src/http/router/token_hash_map.zig diff --git a/build.zig b/build.zig index af5866f..83a3912 100644 --- a/build.zig +++ b/build.zig @@ -33,6 +33,7 @@ pub fn build(b: *std.Build) void { add_example(b, "tls", true, target, optimize, zzz); add_example(b, "minram", false, target, optimize, zzz); add_example(b, "fs", false, target, optimize, zzz); + add_example(b, "middleware", false, target, optimize, zzz); add_example(b, "multithread", false, target, optimize, zzz); add_example(b, "benchmark", false, target, optimize, zzz); add_example(b, "valgrind", true, target, optimize, zzz); diff --git a/build.zig.zon b/build.zig.zon index 1d70f35..a4cdc2f 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -4,8 +4,8 @@ .minimum_zig_version = "0.13.0", .dependencies = .{ .tardy = .{ - .url = "git+https://github.com/mookums/tardy?ref=v0.2.1#e133b423e455a637a068296d0a3d00a4c71b1143", - .hash = "1220c3c3e6dd2fc6e61645dff2d22d46cb3eca3fb5c4c14a0015dab881d0a5af3976", + .url = "git+https://github.com/mookums/tardy#c2851c1e8ec5c66a16a0fb318d6bacea4ae9cc0b", + .hash = "122054b7c88eca71ab699e7ff530cd56303459356ea9ff3c9c78794943b035cadfea", }, .bearssl = .{ .url = "git+https://github.com/mookums/bearssl-zig#37a96eee56fe2543579bbc6da148ca886f3dd32b", diff --git a/docs/getting_started.md b/docs/getting_started.md index 6104881..46444cc 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -1,15 +1,11 @@ # Getting Started zzz is a networking framework that allows for modularity and flexibility in design. For most use cases, this flexibility is not a requirement and so various defaults are provided. -For this guide, we will assume that you are running on a modern Linux platform and looking to design a service that utilizes HTTP. We will need both `zzz` and `tardy` for this to work. 
-You will need to match the version of Tardy that zzz is currently using to the version of Tardy you currently use within your program. This will eventually be standardized. - -These are the current latest releases and are compatible. +For this guide, we will assume that you are running on a modern Linux platform and looking to design a service that utilizes HTTP. +This is the current latest release. `zig fetch --save git+https://github.com/mookums/zzz#v0.2.0` -`zig fetch --save git+https://github.com/mookums/tardy#v0.1.0` - ## Hello, World! We can write a quick example that serves out "Hello, World" responses to any client that connects to the server. This example is the same as the one that is provided within the `examples/basic` directory. @@ -24,10 +20,43 @@ const tardy = zzz.tardy; const Tardy = tardy.Tardy(.auto); const Runtime = tardy.Runtime; -const Server = http.Server(.plain, *const i8); -const Router = Server.Router; -const Context = Server.Context; -const Route = Server.Route; +const Server = http.Server; +const Router = http.Router; +const Context = http.Context; +const Route = http.Route; + +fn root_handler(ctx: *Context, id: i8) !void { + const body_fmt = + \\ + \\ + \\ + \\

+        \\<h1>Hello, World!</h1>
+        \\<p>id: {d}</p>
+ \\ + \\ + ; + const body = try std.fmt.allocPrint(ctx.allocator, body_fmt, .{id}); + // This is the standard response and what you + // will usually be using. This will send to the + // client and then continue to await more requests. + return try ctx.respond(.{ + .status = .OK, + .mime = http.Mime.HTML, + .body = body[0..], + }); +} + +fn echo_handler(ctx: *Context, _: void) !void { + const body = if (ctx.request.body) |b| + try ctx.allocator.dupe(u8, b) + else + ""; + return try ctx.respond(.{ + .status = .OK, + .mime = http.Mime.HTML, + .body = body[0..], + }); +} pub fn main() !void { const host: []const u8 = "0.0.0.0"; @@ -47,57 +76,11 @@ pub fn main() !void { const num: i8 = 12; - var router = Router.init(&num, &[_]Route{ - Route.init("/").get(struct { - fn handler_fn(ctx: *Context) !void { - const body_fmt = - \\ - \\ - \\ - \\

-                    \\<h1>Hello, World!</h1>
-                    \\<p>id: {d}</p>
- \\ - \\ - ; - - const body = try std.fmt.allocPrint(ctx.allocator, body_fmt, .{ctx.state.*}); - - // This is the standard response and what you - // will usually be using. This will send to the - // client and then continue to await more requests. - try ctx.respond(.{ - .status = .OK, - .mime = http.Mime.HTML, - .body = body[0..], - }); - } - }.handler_fn), - - Route.init("/echo").post(struct { - fn handler_fn(ctx: *Context) !void { - const body = if (ctx.request.body) |b| - try ctx.allocator.dupe(u8, b) - else - ""; - - try ctx.respond(.{ - .status = .OK, - .mime = http.Mime.HTML, - .body = body[0..], - }); - } - }.handler_fn), - }, .{ - .not_found_handler = struct { - fn handler_fn(ctx: *Context) !void { - try ctx.respond(.{ - .status = .@"Not Found", - .mime = http.Mime.HTML, - .body = "Not Found Handler!", - }); - } - }.handler_fn, - }); + var router = try Router.init(allocator, &.{ + Route.init("/").get(num, root_handler).layer(), + Route.init("/echo").post({}, echo_handler).layer(), + }, .{}); + defer router.deinit(allocator); // This provides the entry function into the Tardy runtime. This will run // exactly once inside of each runtime (each thread gets a single runtime). diff --git a/docs/https.md b/docs/https.md index 014b855..6fba17d 100644 --- a/docs/https.md +++ b/docs/https.md @@ -12,20 +12,33 @@ const log = std.log.scoped(.@"examples/tls"); const zzz = @import("zzz"); const http = zzz.HTTP; -const tardy = @import("tardy"); +const tardy = zzz.tardy; const Tardy = tardy.Tardy(.auto); const Runtime = tardy.Runtime; -const Server = http.Server(.{ .tls = .{ - .cert = .{ .file = .{ .path = "./examples/tls/certs/cert.pem" } }, - .key = .{ .file = .{ .path = "./examples/tls/certs/key.pem" } }, - .cert_name = "CERTIFICATE", - .key_name = "EC PRIVATE KEY", -} }, void); +const Server = http.Server; +const Context = http.Context; +const Route = http.Route; +const Router = http.Router; -const Context = Server.Context; -const Route = Server.Route; -const Router = Server.Router; +fn root_handler(ctx: *Context, _: void) !void { + const body = + \\ + \\ + \\ + \\ + \\ + \\ + \\

+        \\<h1>Hello, World!</h1>
+ \\ + \\ + ; + return try ctx.respond(.{ + .status = .OK, + .mime = http.Mime.HTML, + .body = body[0..], + }); +} pub fn main() !void { const host: []const u8 = "0.0.0.0"; @@ -37,49 +50,41 @@ pub fn main() !void { const allocator = gpa.allocator(); defer _ = gpa.deinit(); - var router = Router.init({}, &[_]Route{ - Route.init("/embed/pico.min.css").serve_embedded_file(http.Mime.CSS, @embedFile("embed/pico.min.css")), + var t = try Tardy.init(.{ + .allocator = allocator, + .threading = .single, + }); + defer t.deinit(); - Route.init("/").get(struct { - pub fn handler_fn(ctx: *Context) !void { - const body = - \\ - \\ - \\ - \\ - \\ - \\ - \\

-                    \\<h1>Hello, World!</h1>
- \\ - \\ - ; + var router = try Router.init(allocator, &.{ + Route.init("/embed/pico.min.css").serve_embedded_file( + http.Mime.CSS, + @embedFile("embed/pico.min.css"), + ).layer(), - try ctx.respond(.{ - .status = .OK, - .mime = http.Mime.HTML, - .body = body[0..], - }); - } - }.handler_fn), + Route.init("/").get({}, root_handler).layer(), - Route.init("/kill").get(struct { - pub fn handler_fn(ctx: *Context) !void { + Route.init("/kill").get({}, struct { + pub fn handler_fn(ctx: *Context, _: void) !void { ctx.runtime.stop(); } - }.handler_fn), + }.handler_fn).layer(), }, .{}); - - var t = try Tardy.init(.{ - .allocator = allocator, - .threading = .single, - }); - defer t.deinit(); + defer router.deinit(allocator); + router.print_route_tree(); try t.entry( &router, struct { fn entry(rt: *Runtime, r: *const Router) !void { - var server = Server.init(rt.allocator, .{}); + var server = Server.init(rt.allocator, .{ + .security = .{ .tls = .{ + .cert = .{ .file = .{ .path = "./examples/tls/certs/cert.pem" } }, + .key = .{ .file = .{ .path = "./examples/tls/certs/key.pem" } }, + .cert_name = "CERTIFICATE", + .key_name = "EC PRIVATE KEY", + } }, + }); try server.bind(.{ .ip = .{ .host = host, .port = port } }); try server.serve(r, rt); } diff --git a/examples/basic/main.zig b/examples/basic/main.zig index 1e53a66..d189e7b 100644 --- a/examples/basic/main.zig +++ b/examples/basic/main.zig @@ -8,10 +8,43 @@ const tardy = zzz.tardy; const Tardy = tardy.Tardy(.auto); const Runtime = tardy.Runtime; -const Server = http.Server(.plain, *const i8); -const Router = Server.Router; -const Context = Server.Context; -const Route = Server.Route; +const Server = http.Server; +const Router = http.Router; +const Context = http.Context; +const Route = http.Route; + +fn root_handler(ctx: *Context, id: i8) !void { + const body_fmt = + \\ + \\ + \\ + \\

+        \\<h1>Hello, World!</h1>
+        \\<p>id: {d}</p>
+ \\ + \\ + ; + const body = try std.fmt.allocPrint(ctx.allocator, body_fmt, .{id}); + // This is the standard response and what you + // will usually be using. This will send to the + // client and then continue to await more requests. + return try ctx.respond(.{ + .status = .OK, + .mime = http.Mime.HTML, + .body = body[0..], + }); +} + +fn echo_handler(ctx: *Context, _: void) !void { + const body = if (ctx.request.body) |b| + try ctx.allocator.dupe(u8, b) + else + ""; + return try ctx.respond(.{ + .status = .OK, + .mime = http.Mime.HTML, + .body = body[0..], + }); +} pub fn main() !void { const host: []const u8 = "0.0.0.0"; @@ -31,54 +64,12 @@ pub fn main() !void { const num: i8 = 12; - var router = Router.init(&num, &[_]Route{ - Route.init("/").get(struct { - fn handler_fn(ctx: *Context) !void { - const body_fmt = - \\ - \\ - \\ - \\

-                    \\<h1>Hello, World!</h1>
-                    \\<p>id: {d}</p>
- \\ - \\ - ; - const body = try std.fmt.allocPrint(ctx.allocator, body_fmt, .{ctx.state.*}); - // This is the standard response and what you - // will usually be using. This will send to the - // client and then continue to await more requests. - try ctx.respond(.{ - .status = .OK, - .mime = http.Mime.HTML, - .body = body[0..], - }); - } - }.handler_fn), - - Route.init("/echo").post(struct { - fn handler_fn(ctx: *Context) !void { - const body = if (ctx.request.body) |b| - try ctx.allocator.dupe(u8, b) - else - ""; - try ctx.respond(.{ - .status = .OK, - .mime = http.Mime.HTML, - .body = body[0..], - }); - } - }.handler_fn), - }, .{ - .not_found_handler = struct { - fn handler_fn(ctx: *Context) !void { - try ctx.respond(.{ - .status = .@"Not Found", - .mime = http.Mime.HTML, - .body = "Not Found Handler!", - }); - } - }.handler_fn, - }); + var router = try Router.init(allocator, &.{ + Route.init("/").get(num, root_handler).layer(), + Route.init("/echo").post({}, echo_handler).layer(), + }, .{}); + defer router.deinit(allocator); + router.print_route_tree(); // This provides the entry function into the Tardy runtime. This will run // exactly once inside of each runtime (each thread gets a single runtime). diff --git a/examples/benchmark/main.zig b/examples/benchmark/main.zig index d4c0bf3..b6c8cd2 100644 --- a/examples/benchmark/main.zig +++ b/examples/benchmark/main.zig @@ -8,16 +8,16 @@ const tardy = zzz.tardy; const Tardy = tardy.Tardy(.auto); const Runtime = tardy.Runtime; -const Server = http.Server(.plain, void); -const Context = Server.Context; -const Route = Server.Route; -const Router = Server.Router; +const Server = http.Server; +const Context = http.Context; +const Route = http.Route; +const Router = http.Router; pub const std_options = .{ .log_level = .err, }; -fn hi_handler(ctx: *Context) !void { +fn hi_handler(ctx: *Context, _: void) !void { const name = ctx.captures[0].string; const body = try std.fmt.allocPrint(ctx.allocator, @@ -39,7 +39,7 @@ fn hi_handler(ctx: *Context) !void { \\ , .{name}); - try ctx.respond(.{ + return try ctx.respond(.{ .status = .OK, .mime = http.Mime.HTML, .body = body, @@ -60,10 +60,12 @@ pub fn main() !void { }); defer t.deinit(); - var router = Router.init({}, &[_]Route{ - Route.init("/").serve_embedded_file(http.Mime.HTML, @embedFile("index.html")), - Route.init("/hi/%s").get(hi_handler), + var router = try Router.init(allocator, &.{ + Route.init("/").serve_embedded_file(http.Mime.HTML, @embedFile("index.html")).layer(), + Route.init("/hi/%s").get({}, hi_handler).layer(), }, .{}); + defer router.deinit(allocator); + router.print_route_tree(); try t.entry( &router, diff --git a/examples/fs/main.zig b/examples/fs/main.zig index 5d46c6f..167e787 100644 --- a/examples/fs/main.zig +++ b/examples/fs/main.zig @@ -8,10 +8,11 @@ const tardy = zzz.tardy; const Tardy = tardy.Tardy(.auto); const Runtime = tardy.Runtime; -const Server = http.Server(.plain, void); -const Router = Server.Router; -const Context = Server.Context; -const Route = Server.Route; +const Server = http.Server; +const Router = http.Router; +const Context = http.Context; +const Route = http.Route; +const FsDir = http.FsDir; pub fn main() !void { const host: []const u8 = "0.0.0.0"; @@ -29,9 +30,9 @@ pub fn main() !void { }); defer t.deinit(); - var router = Router.init({}, &[_]Route{ - Route.init("/").get(struct { - pub fn handler_fn(ctx: *Context) !void { + var router = try Router.init(allocator, &.{ + Route.init("/").get({}, struct { + pub fn handler_fn(ctx: *Context, _: void) !void { 
const body = \\ \\ @@ -40,22 +41,18 @@ pub fn main() !void { \\ \\ ; - try ctx.respond(.{ + return try ctx.respond(.{ .status = .OK, .mime = http.Mime.HTML, .body = body[0..], }); } - }.handler_fn), + }.handler_fn).layer(), - Route.init("/kill").get(struct { - pub fn handler_fn(ctx: *Context) !void { - ctx.runtime.stop(); - } - }.handler_fn), - - Route.init("/static").serve_fs_dir("./examples/fs/static"), + FsDir.serve("/", "./examples/fs/static"), }, .{}); + defer router.deinit(allocator); + router.print_route_tree(); try t.entry( &router, diff --git a/examples/middleware/main.zig b/examples/middleware/main.zig new file mode 100644 index 0000000..6690f0e --- /dev/null +++ b/examples/middleware/main.zig @@ -0,0 +1,92 @@ +const std = @import("std"); +const log = std.log.scoped(.@"examples/middleware"); + +const zzz = @import("zzz"); +const http = zzz.HTTP; + +const tardy = zzz.tardy; +const Tardy = tardy.Tardy(.auto); +const Runtime = tardy.Runtime; + +const Server = http.Server; +const Router = http.Router; +const Context = http.Context; +const Route = http.Route; +const Next = http.Next; +const Middleware = http.Middleware; + +fn root_handler(ctx: *Context, id: i8) !void { + const body_fmt = + \\ + \\ + \\ + \\

+        \\<h1>Hello, World!</h1>
+        \\<p>id: {d}</p>
+ \\ + \\ + ; + const body = try std.fmt.allocPrint(ctx.allocator, body_fmt, .{id}); + // This is the standard response and what you + // will usually be using. This will send to the + // client and then continue to await more requests. + return try ctx.respond(.{ + .status = .OK, + .mime = http.Mime.HTML, + .body = body[0..], + }); +} + +fn pre_middleware(next: *Next, _: void) !void { + log.info("pre request middleware: {s}", .{next.ctx.request.uri.?}); + return try next.run(); +} + +fn post_middleware(next: *Next, _: void) !void { + log.info("post request middleware: {s}", .{next.ctx.request.uri.?}); + return try next.run(); +} + +pub fn main() !void { + const host: []const u8 = "0.0.0.0"; + const port: u16 = 9862; + + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + const allocator = gpa.allocator(); + defer _ = gpa.deinit(); + + // Creating our Tardy instance that + // will spawn our runtimes. + var t = try Tardy.init(.{ + .allocator = allocator, + .threading = .single, + }); + defer t.deinit(); + + const num: i8 = 12; + + var router = try Router.init(allocator, &.{ + Middleware.init().before({}, pre_middleware).after({}, post_middleware).layer(), + Route.init("/").get(num, root_handler).layer(), + }, .{}); + defer router.deinit(allocator); + router.print_route_tree(); + + // This provides the entry function into the Tardy runtime. This will run + // exactly once inside of each runtime (each thread gets a single runtime). + try t.entry( + &router, + struct { + fn entry(rt: *Runtime, r: *const Router) !void { + var server = Server.init(rt.allocator, .{}); + try server.bind(.{ .ip = .{ .host = host, .port = port } }); + try server.serve(r, rt); + } + }.entry, + {}, + struct { + fn exit(rt: *Runtime, _: void) !void { + try Server.clean(rt); + } + }.exit, + ); +} diff --git a/examples/minram/main.zig b/examples/minram/main.zig index 6a1ef5f..d6de930 100644 --- a/examples/minram/main.zig +++ b/examples/minram/main.zig @@ -8,10 +8,10 @@ const tardy = zzz.tardy; const Tardy = tardy.Tardy(.auto); const Runtime = tardy.Runtime; -const Server = http.Server(.plain, void); -const Router = Server.Router; -const Context = Server.Context; -const Route = Server.Route; +const Server = http.Server; +const Router = http.Router; +const Context = http.Context; +const Route = http.Route; pub fn main() !void { const host: []const u8 = "0.0.0.0"; @@ -34,9 +34,9 @@ pub fn main() !void { }); defer t.deinit(); - var router = Router.init({}, &[_]Route{ - Route.init("/").get(struct { - pub fn handler_fn(ctx: *Context) !void { + var router = try Router.init(allocator, &.{ + Route.init("/").get({}, struct { + pub fn handler_fn(ctx: *Context, _: void) !void { const body = \\ \\ @@ -45,14 +45,16 @@ pub fn main() !void { \\ \\ ; - try ctx.respond(.{ + return try ctx.respond(.{ .status = .OK, .mime = http.Mime.HTML, .body = body[0..], }); } - }.handler_fn), + }.handler_fn).layer(), }, .{}); + defer router.deinit(allocator); + router.print_route_tree(); try t.entry( &router, diff --git a/examples/multithread/main.zig b/examples/multithread/main.zig index 96beba8..bae66e0 100644 --- a/examples/multithread/main.zig +++ b/examples/multithread/main.zig @@ -8,12 +8,12 @@ const tardy = zzz.tardy; const Tardy = tardy.Tardy(.auto); const Runtime = tardy.Runtime; -const Server = http.Server(.plain, void); -const Router = Server.Router; -const Context = Server.Context; -const Route = Server.Route; +const Server = http.Server; +const Router = http.Router; +const Context = http.Context; +const Route = http.Route; -fn 
hi_handler(ctx: *Context) !void { +fn hi_handler(ctx: *Context, _: void) !void { const name = ctx.captures[0].string; const greeting = ctx.queries.get("greeting") orelse "Hi"; @@ -36,27 +36,27 @@ fn hi_handler(ctx: *Context) !void { \\ , .{ greeting, name }); - try ctx.respond(.{ + return try ctx.respond(.{ .status = .OK, .mime = http.Mime.HTML, .body = body, }); } -fn redir_handler(ctx: *Context) !void { +fn redir_handler(ctx: *Context, _: void) !void { ctx.response.headers.put_assume_capacity("Location", "/hi/redirect"); - try ctx.respond(.{ + return try ctx.respond(.{ .status = .@"Permanent Redirect", .mime = http.Mime.HTML, .body = "", }); } -fn post_handler(ctx: *Context) !void { +fn post_handler(ctx: *Context, _: void) !void { log.debug("Body: {s}", .{ctx.request.body orelse ""}); - try ctx.respond(.{ + return try ctx.respond(.{ .status = .OK, .mime = http.Mime.HTML, .body = "", @@ -80,12 +80,14 @@ pub fn main() !void { }); defer t.deinit(); - var router = Router.init({}, &[_]Route{ - Route.init("/").serve_embedded_file(http.Mime.HTML, @embedFile("index.html")), - Route.init("/hi/%s").get(hi_handler), - Route.init("/redirect").get(redir_handler), - Route.init("/post").post(post_handler), + var router = try Router.init(allocator, &.{ + Route.init("/").serve_embedded_file(http.Mime.HTML, @embedFile("index.html")).layer(), + Route.init("/hi/%s").get({}, hi_handler).layer(), + Route.init("/redirect").get({}, redir_handler).layer(), + Route.init("/post").post({}, post_handler).layer(), }, .{}); + defer router.deinit(allocator); + router.print_route_tree(); try t.entry( &router, diff --git a/examples/sse/main.zig b/examples/sse/main.zig index 2b09d40..2c0c186 100644 --- a/examples/sse/main.zig +++ b/examples/sse/main.zig @@ -11,11 +11,11 @@ const Task = tardy.Task; const Broadcast = tardy.Broadcast; const Channel = tardy.Channel; -const Server = http.Server(.plain, *Broadcast(usize)); -const Router = Server.Router; -const Context = Server.Context; -const Route = Server.Route; -const SSE = Server.SSE; +const Server = http.Server; +const Router = http.Router; +const Context = http.Context; +const Route = http.Route; +const SSE = http.SSE; // When using SSE, you end up leaving the various abstractions that zzz has setup for you // and you begin programming more against the tardy runtime. 
@@ -33,11 +33,11 @@ fn sse_send(_: *Runtime, value_opt: ?*const usize, ctx: *SSEBroadcastContext) !v .{value.*}, ); - try ctx.sse.send(.{ .data = data }, ctx, sse_recv); + return try ctx.sse.send(.{ .data = data }, ctx, sse_recv); } else { const broadcast = ctx.sse.runtime.storage.get_ptr("broadcast", Broadcast(usize)); broadcast.unsubscribe(ctx.channel); - try ctx.sse.context.close(); + return try ctx.sse.context.close(); } } @@ -62,25 +62,25 @@ fn sse_init(rt: *Runtime, success: bool, sse: *SSE) !void { const broadcast = sse.runtime.storage.get_ptr("broadcast", Broadcast(usize)); const context = try sse.allocator.create(SSEBroadcastContext); context.* = .{ .sse = sse, .channel = try broadcast.subscribe(rt, 10) }; - try context.channel.recv(context, sse_send); + return try context.channel.recv(context, sse_send); } -fn sse_handler(ctx: *Context) !void { +fn sse_handler(ctx: *Context, _: void) !void { log.debug("going into sse mode", .{}); - try ctx.to_sse(sse_init); + return try ctx.to_sse(sse_init); } -fn msg_handler(ctx: *Context) !void { +fn msg_handler(ctx: *Context, broadcast: *Broadcast(usize)) !void { log.debug("message handler", .{}); - try ctx.state.send(0); - try ctx.respond(.{ + try broadcast.send(0); + return try ctx.respond(.{ .status = .OK, .mime = http.Mime.HTML, .body = "", }); } -fn kill_handler(ctx: *Context) !void { +fn kill_handler(ctx: *Context, _: void) !void { ctx.runtime.stop(); } @@ -105,12 +105,14 @@ pub fn main() !void { var broadcast = try Broadcast(usize).init(allocator, max_conn); defer broadcast.deinit(); - var router = Router.init(&broadcast, &[_]Route{ - Route.init("/").serve_embedded_file(http.Mime.HTML, @embedFile("index.html")), - Route.init("/kill").get(kill_handler), - Route.init("/stream").get(sse_handler), - Route.init("/message").post(msg_handler), + var router = try Router.init(allocator, &.{ + Route.init("/").serve_embedded_file(http.Mime.HTML, @embedFile("index.html")).layer(), + Route.init("/kill").get({}, kill_handler).layer(), + Route.init("/stream").get({}, sse_handler).layer(), + Route.init("/message").post(&broadcast, msg_handler).layer(), }, .{}); + defer router.deinit(allocator); + router.print_route_tree(); const EntryParams = struct { router: *const Router, diff --git a/examples/tls/main.zig b/examples/tls/main.zig index af37ed1..d99e3f2 100644 --- a/examples/tls/main.zig +++ b/examples/tls/main.zig @@ -8,16 +8,29 @@ const tardy = zzz.tardy; const Tardy = tardy.Tardy(.auto); const Runtime = tardy.Runtime; -const Server = http.Server(.{ .tls = .{ - .cert = .{ .file = .{ .path = "./examples/tls/certs/cert.pem" } }, - .key = .{ .file = .{ .path = "./examples/tls/certs/key.pem" } }, - .cert_name = "CERTIFICATE", - .key_name = "EC PRIVATE KEY", -} }, void); +const Server = http.Server; +const Context = http.Context; +const Route = http.Route; +const Router = http.Router; -const Context = Server.Context; -const Route = Server.Route; -const Router = Server.Router; +fn root_handler(ctx: *Context, _: void) !void { + const body = + \\ + \\ + \\ + \\ + \\ + \\ + \\

+        \\<h1>Hello, World!</h1>
+ \\ + \\ + ; + return try ctx.respond(.{ + .status = .OK, + .mime = http.Mime.HTML, + .body = body[0..], + }); +} pub fn main() !void { const host: []const u8 = "0.0.0.0"; @@ -29,48 +42,41 @@ pub fn main() !void { const allocator = gpa.allocator(); defer _ = gpa.deinit(); - var router = Router.init({}, &[_]Route{ - Route.init("/embed/pico.min.css").serve_embedded_file(http.Mime.CSS, @embedFile("embed/pico.min.css")), - - Route.init("/").get(struct { - pub fn handler_fn(ctx: *Context) !void { - const body = - \\ - \\ - \\ - \\ - \\ - \\ - \\

-                    \\<h1>Hello, World!</h1>
- \\ - \\ - ; - try ctx.respond(.{ - .status = .OK, - .mime = http.Mime.HTML, - .body = body[0..], - }); - } - }.handler_fn), - - Route.init("/kill").get(struct { - pub fn handler_fn(ctx: *Context) !void { - ctx.runtime.stop(); - } - }.handler_fn), - }, .{}); - var t = try Tardy.init(.{ .allocator = allocator, .threading = .single, }); defer t.deinit(); + var router = try Router.init(allocator, &.{ + Route.init("/embed/pico.min.css").serve_embedded_file( + http.Mime.CSS, + @embedFile("embed/pico.min.css"), + ).layer(), + + Route.init("/").get({}, root_handler).layer(), + + Route.init("/kill").get({}, struct { + pub fn handler_fn(ctx: *Context, _: void) !void { + ctx.runtime.stop(); + } + }.handler_fn).layer(), + }, .{}); + defer router.deinit(allocator); + router.print_route_tree(); + try t.entry( &router, struct { fn entry(rt: *Runtime, r: *const Router) !void { - var server = Server.init(rt.allocator, .{}); + var server = Server.init(rt.allocator, .{ + .security = .{ .tls = .{ + .cert = .{ .file = .{ .path = "./examples/tls/certs/cert.pem" } }, + .key = .{ .file = .{ .path = "./examples/tls/certs/key.pem" } }, + .cert_name = "CERTIFICATE", + .key_name = "EC PRIVATE KEY", + } }, + }); try server.bind(.{ .ip = .{ .host = host, .port = port } }); try server.serve(r, rt); } diff --git a/examples/unix/main.zig b/examples/unix/main.zig index 1cc6bf8..bd58870 100644 --- a/examples/unix/main.zig +++ b/examples/unix/main.zig @@ -8,17 +8,17 @@ const tardy = zzz.tardy; const Tardy = tardy.Tardy(.auto); const Runtime = tardy.Runtime; -const Server = http.Server(.plain, void); -const Context = Server.Context; -const Route = Server.Route; -const Router = Server.Router; +const Server = http.Server; +const Context = http.Context; +const Route = http.Route; +const Router = http.Router; pub const std_options = .{ .log_level = .err, }; -pub fn root_handler(ctx: *Context) !void { - try ctx.respond(.{ +pub fn root_handler(ctx: *Context, _: void) !void { + return try ctx.respond(.{ .status = .OK, .mime = http.Mime.HTML, .body = "This is an HTTP benchmark", @@ -36,7 +36,11 @@ pub fn main() !void { }); defer t.deinit(); - var router = Router.init({}, &[_]Route{Route.init("/").get(root_handler)}, .{}); + var router = try Router.init(allocator, &.{ + Route.init("/").get({}, root_handler).layer(), + }, .{}); + defer router.deinit(allocator); + router.print_route_tree(); try t.entry( &router, diff --git a/examples/valgrind/main.zig b/examples/valgrind/main.zig index cfa4e27..7c9dc5d 100644 --- a/examples/valgrind/main.zig +++ b/examples/valgrind/main.zig @@ -7,10 +7,10 @@ const tardy = zzz.tardy; const Tardy = tardy.Tardy(.auto); const Runtime = tardy.Runtime; -const Server = http.Server(.plain, void); -const Router = Server.Router; -const Context = Server.Context; -const Route = Server.Route; +const Server = http.Server; +const Router = http.Router; +const Context = http.Context; +const Route = http.Route; pub fn main() !void { const host: []const u8 = "0.0.0.0"; @@ -20,9 +20,15 @@ pub fn main() !void { const allocator = gpa.allocator(); defer _ = gpa.deinit(); - var router = Router.init({}, &[_]Route{ - Route.init("/").get(struct { - pub fn handler_fn(ctx: *Context) !void { + var t = try Tardy.init(.{ + .allocator = allocator, + .threading = .single, + }); + defer t.deinit(); + + var router = try Router.init(allocator, &.{ + Route.init("/").get({}, struct { + pub fn handler_fn(ctx: *Context, _: void) !void { const body = \\ \\ @@ -31,26 +37,22 @@ pub fn main() !void { \\ \\ ; - try ctx.respond(.{ + return try 
ctx.respond(.{ .status = .OK, .mime = http.Mime.HTML, .body = body[0..], }); } - }.handler_fn), + }.handler_fn).layer(), - Route.init("/kill").get(struct { - pub fn handler_fn(ctx: *Context) !void { + Route.init("/kill").get({}, struct { + pub fn handler_fn(ctx: *Context, _: void) !void { ctx.runtime.stop(); } - }.handler_fn), + }.handler_fn).layer(), }, .{}); - - var t = try Tardy.init(.{ - .allocator = allocator, - .threading = .single, - }); - defer t.deinit(); + defer router.deinit(allocator); + router.print_route_tree(); try t.entry( &router, diff --git a/flake.nix b/flake.nix index 4da8abd..dc7a32c 100644 --- a/flake.nix +++ b/flake.nix @@ -1,6 +1,5 @@ { - description = - "a framework for writing performant and reliable networked services"; + description = "a framework for writing performant and reliable networked services"; inputs = { nixpkgs.url = "github:nixos/nixpkgs/release-24.11"; @@ -8,17 +7,21 @@ flake-utils.url = "github:numtide/flake-utils"; }; - outputs = { nixpkgs, iguana, flake-utils, ... }: - flake-utils.lib.eachDefaultSystem (system: - let - pkgs = import nixpkgs { inherit system; }; - iguanaLib = iguana.lib.${system}; - in { - devShells.default = iguanaLib.mkShell { - zigVersion = "0.13.0"; - withZls = true; + outputs = { + nixpkgs, + iguana, + flake-utils, + ... + }: + flake-utils.lib.eachDefaultSystem (system: let + pkgs = import nixpkgs {inherit system;}; + iguanaLib = iguana.lib.${system}; + in { + devShells.default = iguanaLib.mkShell { + zigVersion = "0.13.0"; + withZls = true; - extraPackages = with pkgs; [ openssl wrk ]; - }; - }); + extraPackages = with pkgs; [openssl wrk]; + }; + }); } diff --git a/src/core/case_string_map.zig b/src/core/case_string_map.zig index 59dfeea..78f98c1 100644 --- a/src/core/case_string_map.zig +++ b/src/core/case_string_map.zig @@ -88,7 +88,7 @@ test "CaseStringMap: Add Stuff" { try csm.put("Content-Length", "100"); csm.put_assume_capacity("Host", "localhost:9999"); - const content_length = csm.get("content-length"); + const content_length = csm.get("Content-length"); try testing.expect(content_length != null); const host = csm.get("host"); diff --git a/src/core/zc_buffer.zig b/src/core/zc_buffer.zig index bf16c1b..95449ec 100644 --- a/src/core/zc_buffer.zig +++ b/src/core/zc_buffer.zig @@ -85,9 +85,9 @@ pub const ZeroCopyBuffer = struct { pub fn shrink_clear_and_free(self: *ZeroCopyBuffer, new_size: usize) !void { assert(new_size <= self.len); - if (!self.allocator.resize(self.ptr[0..self.len], new_size)) { + if (!self.allocator.resize(self.ptr[0..self.capacity], new_size)) { const slice = try self.allocator.realloc( - self.ptr[0..self.len], + self.ptr[0..self.capacity], new_size, ); self.ptr = slice.ptr; diff --git a/src/http/context.zig b/src/http/context.zig index e3d81ed..6fcdd16 100644 --- a/src/http/context.zig +++ b/src/http/context.zig @@ -13,161 +13,170 @@ const Request = @import("request.zig").Request; const Response = @import("response.zig").Response; const ResponseSetOptions = Response.ResponseSetOptions; const Mime = @import("mime.zig").Mime; -const _SSE = @import("sse.zig").SSE; +const SSE = @import("sse.zig").SSE; +const MiddlewareWithData = @import("router/middleware.zig").MiddlewareWithData; +const Next = @import("router/middleware.zig").Next; const Runtime = @import("tardy").Runtime; const TaskFn = @import("tardy").TaskFn; +const Server = @import("server.zig").Server; + const raw_respond = @import("server.zig").raw_respond; // Context is dependent on the server that gets created. 
-pub fn Context(comptime Server: type, comptime AppState: type) type { - return struct { - const Self = @This(); - const SSE = _SSE(Server, AppState); - allocator: std.mem.Allocator, - runtime: *Runtime, - /// Custom user-data state. - state: AppState, - /// The matched route instance. - route: *const Route(Server, AppState), - /// The Request that triggered this handler. - request: *const Request, - /// The Response that will be returned. - response: *Response, - captures: []Capture, - queries: *QueryMap, - provision: *Provision, - triggered: bool = false, - - pub fn to_sse(self: *Self, then: TaskFn(bool, *SSE)) !void { - const sse = try self.allocator.create(SSE); - sse.* = .{ - .context = self, - .runtime = self.runtime, - .allocator = self.allocator, - }; - - try self.respond_headers_only( - .{ - .status = .OK, - .mime = Mime.generate( - "text/event-stream", - "sse", - "Server-Sent Events", - ), - }, - null, - sse, - then, - ); - } - - pub fn close(self: *Self) !void { - self.provision.job = .close; - try self.runtime.net.close( - self.provision, - Server.close_task, - self.provision.socket, - ); - } - - pub fn send_then( - self: *Self, - data: []const u8, - ctx: anytype, - then: TaskFn(bool, @TypeOf(ctx)), - ) !void { - const pslice = Pseudoslice.init(data, "", self.provision.buffer); - - const first_chunk = try Server.prepare_send( - self.runtime, - self.provision, - .{ - .other = .{ - .func = then, - .ctx = ctx, - }, +pub const Context = struct { + const Self = @This(); + allocator: std.mem.Allocator, + runtime: *Runtime, + /// The Request that triggered this handler. + request: *const Request, + /// The Response that will be returned. + response: *Response, + captures: []Capture, + queries: *QueryMap, + provision: *Provision, + next: *Next, + triggered: bool = false, + + pub fn to_sse(self: *Self, then: TaskFn(bool, *SSE)) !void { + const sse = try self.allocator.create(SSE); + sse.* = .{ + .context = self, + .runtime = self.runtime, + .allocator = self.allocator, + }; + + try self.respond_headers_only( + .{ + .status = .OK, + .mime = Mime.generate( + "text/event-stream", + "sse", + "Server-Sent Events", + ), + }, + null, + sse, + then, + ); + } + + pub fn close(self: *Self) !void { + self.provision.job = .close; + try self.runtime.net.close( + self.provision, + Server.close_task, + self.provision.socket, + ); + } + + pub fn send_then( + self: *Self, + data: []const u8, + ctx: anytype, + then: TaskFn(bool, @TypeOf(ctx)), + ) !void { + const pslice = Pseudoslice.init(data, "", self.provision.buffer); + + const first_chunk = try Server.prepare_send( + self.runtime, + self.provision, + .{ + .other = .{ + .func = then, + .ctx = ctx, }, - pslice, - ); - - try self.runtime.net.send( - self.provision, - Server.send_then_other_task, - self.provision.socket, - first_chunk, - ); + }, + pslice, + ); + + try self.runtime.net.send( + self.provision, + Server.send_then_other_task, + self.provision.socket, + first_chunk, + ); + } + + pub fn send_then_recv(self: *Self, data: []const u8) !void { + const pslice = Pseudoslice.init(data, "", self.provision.buffer); + + const first_chunk = try Server.prepare_send( + self.runtime, + self.provision, + .recv, + pslice, + ); + + try self.runtime.net.send( + self.provision, + Server.send_then_recv_task, + self.provision.socket, + first_chunk, + ); + } + + // This will respond with the headers only. + // You will be in charge of sending the body. 
+ pub fn respond_headers_only( + self: *Self, + options: ResponseSetOptions, + content_length: ?usize, + ctx: anytype, + then: TaskFn(bool, @TypeOf(ctx)), + ) !void { + assert(!self.triggered); + self.triggered = true; + + // the body should not be set. + assert(options.body == null); + self.response.set(options); + + const headers = try self.provision.response.headers_into_buffer( + self.provision.buffer, + content_length, + ); + + try self.send_then(headers, ctx, then); + } + + pub fn respond_without_middleware(self: *Self) !void { + const body = self.response.body orelse ""; + const headers = try self.provision.response.headers_into_buffer( + self.provision.buffer, + @intCast(body.len), + ); + const pslice = Pseudoslice.init(headers, body, self.provision.buffer); + + const first_chunk = try Server.prepare_send( + self.runtime, + self.provision, + .recv, + pslice, + ); + + try self.runtime.net.send( + self.provision, + Server.send_then_recv_task, + self.provision.socket, + first_chunk, + ); + } + + /// This is your standard response. + pub fn respond(self: *Self, options: ResponseSetOptions) !void { + assert(!self.triggered); + self.triggered = true; + self.response.set(options); + + // If we have a post chain, iterate through it. + if (self.next.post_chain.len > 0) { + self.next.stage = .post; + try self.next.run(); + return; } - pub fn send_then_recv(self: *Self, data: []const u8) !void { - const pslice = Pseudoslice.init(data, "", self.provision.buffer); - - const first_chunk = try Server.prepare_send( - self.runtime, - self.provision, - .recv, - pslice, - ); - - try self.runtime.net.send( - self.provision, - Server.send_then_recv_task, - self.provision.socket, - first_chunk, - ); - } - - // This will respond with the headers only. - // You will be in charge of sending the body. - pub fn respond_headers_only( - self: *Self, - options: ResponseSetOptions, - content_length: ?usize, - ctx: anytype, - then: TaskFn(bool, @TypeOf(ctx)), - ) !void { - assert(!self.triggered); - self.triggered = true; - - // the body should not be set. - assert(options.body == null); - self.response.set(options); - - const headers = try self.provision.response.headers_into_buffer( - self.provision.buffer, - content_length, - ); - - try self.send_then(headers, ctx, then); - } - - /// This is your standard response. 
- pub fn respond(self: *Self, options: ResponseSetOptions) !void { - assert(!self.triggered); - self.triggered = true; - self.response.set(options); - - const body = options.body orelse ""; - const headers = try self.provision.response.headers_into_buffer( - self.provision.buffer, - @intCast(body.len), - ); - const pslice = Pseudoslice.init(headers, body, self.provision.buffer); - - const first_chunk = try Server.prepare_send( - self.runtime, - self.provision, - .recv, - pslice, - ); - - try self.runtime.net.send( - self.provision, - Server.send_then_recv_task, - self.provision.socket, - first_chunk, - ); - } - }; -} + try self.respond_without_middleware(); + } +}; diff --git a/src/http/lib.zig b/src/http/lib.zig index 1e9e40b..42f1223 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -6,6 +6,18 @@ pub const Mime = @import("mime.zig").Mime; pub const Date = @import("date.zig").Date; pub const Headers = @import("../core/case_string_map.zig").CaseStringMap([]const u8); +pub const Router = @import("router.zig").Router; +pub const Route = @import("router/route.zig").Route; +pub const Layer = @import("router/layer.zig").Layer; + +pub const Context = @import("context.zig").Context; +pub const Middleware = @import("router/middleware.zig").Middleware; +pub const MiddlewareFn = @import("router/middleware.zig").MiddlewareFn; +pub const Next = @import("router/middleware.zig").Next; +pub const SSE = @import("sse.zig").SSE; + +pub const FsDir = @import("router/fs_dir.zig").FsDir; + pub const Server = @import("server.zig").Server; pub const HTTPError = error{ diff --git a/src/http/response.zig b/src/http/response.zig index 3759ae2..fac5578 100644 --- a/src/http/response.zig +++ b/src/http/response.zig @@ -30,6 +30,7 @@ pub const Response = struct { self.status = null; self.mime = null; self.body = null; + self.headers.clear(); } pub const ResponseSetOptions = struct { diff --git a/src/http/router.zig b/src/http/router.zig index c2e0a3a..9417db7 100644 --- a/src/http/router.zig +++ b/src/http/router.zig @@ -2,66 +2,77 @@ const std = @import("std"); const log = std.log.scoped(.@"zzz/http/router"); const assert = std.debug.assert; -const _Route = @import("router/route.zig").Route; +const Layer = @import("router/layer.zig").Layer; +const Route = @import("router/route.zig").Route; +const Bundle = @import("router/bundle.zig").Bundle; +const TypedHandlerFn = @import("router/route.zig").TypedHandlerFn; +const FoundBundle = @import("router/routing_trie.zig").FoundBundle; const Capture = @import("router/routing_trie.zig").Capture; const Request = @import("request.zig").Request; const Response = @import("response.zig").Response; const Mime = @import("mime.zig").Mime; -const _Context = @import("context.zig").Context; +const Context = @import("context.zig").Context; -const _RoutingTrie = @import("router/routing_trie.zig").RoutingTrie; +const RoutingTrie = @import("router/routing_trie.zig").RoutingTrie; const QueryMap = @import("router/routing_trie.zig").QueryMap; /// Default not found handler: send a plain text response. 
-pub fn default_not_found_handler(comptime Server: type, comptime AppState: type) _Route(Server, AppState).HandlerFn { - const Context = _Context(Server, AppState); - - return struct { - fn not_found_handler(ctx: *Context) !void { - try ctx.respond(.{ - .status = .@"Not Found", - .mime = Mime.TEXT, - .body = "Not Found", - }); - } - }.not_found_handler; -} +pub const default_not_found_handler = struct { + fn not_found_handler(ctx: *Context, _: void) !void { + return try ctx.respond(.{ + .status = .@"Not Found", + .mime = Mime.TEXT, + .body = "Not Found", + }); + } +}.not_found_handler; /// Initialize a router with the given routes. -pub fn Router(comptime Server: type, comptime AppState: type) type { - return struct { - const Self = @This(); - const RoutingTrie = _RoutingTrie(Server, AppState); - const FoundRoute = RoutingTrie.FoundRoute; - const Route = _Route(Server, AppState); - const Context = _Context(Server, AppState); +pub const Router = struct { + /// Router configuration structure. + pub const Configuration = struct { + not_found: TypedHandlerFn(void) = default_not_found_handler, + }; + + routes: RoutingTrie, + configuration: Configuration, - /// Router configuration structure. - pub const Configuration = struct { - not_found_handler: Route.HandlerFn = default_not_found_handler(Server, AppState), + pub fn init( + allocator: std.mem.Allocator, + layers: []const Layer, + configuration: Configuration, + ) !Router { + const self = Router{ + .routes = try RoutingTrie.init(allocator, layers), + .configuration = configuration, }; - routes: RoutingTrie, - not_found_route: Route, - state: AppState, + return self; + } - pub fn init(state: AppState, comptime _routes: []const Route, comptime configuration: Configuration) Self { - const self = Self{ - // Initialize the routing tree from the given routes. 
- .routes = comptime RoutingTrie.init(_routes), - .not_found_route = comptime Route.init("").all(configuration.not_found_handler), - .state = state, - }; + pub fn deinit(self: *Router, allocator: std.mem.Allocator) void { + self.routes.deinit(allocator); + } - return self; - } + pub fn print_route_tree(self: *const Router) void { + self.routes.print(); + } - pub fn get_route_from_host(self: Self, path: []const u8, captures: []Capture, queries: *QueryMap) !FoundRoute { - return try self.routes.get_route(path, captures, queries) orelse { - queries.clear(); - return FoundRoute{ .route = self.not_found_route, .captures = captures[0..0], .queries = queries }; + pub fn get_bundle_from_host( + self: *const Router, + path: []const u8, + captures: []Capture, + queries: *QueryMap, + ) !FoundBundle { + return try self.routes.get_route(path, captures, queries) orelse { + queries.clear(); + const not_found_bundle: Bundle = .{ + .pre = &.{}, + .route = Route.init("").all({}, self.configuration.not_found), + .post = &.{}, }; - } - }; -} + return .{ .bundle = not_found_bundle, .captures = captures[0..0], .queries = queries }; + }; + } +}; diff --git a/src/http/router/bundle.zig b/src/http/router/bundle.zig new file mode 100644 index 0000000..bf3f8e3 --- /dev/null +++ b/src/http/router/bundle.zig @@ -0,0 +1,8 @@ +const MiddlewareWithData = @import("middleware.zig").MiddlewareWithData; +const Route = @import("route.zig").Route; + +pub const Bundle = struct { + pre: []const MiddlewareWithData, + route: Route, + post: []const MiddlewareWithData, +}; diff --git a/src/http/router/fs_dir.zig b/src/http/router/fs_dir.zig index 0eb1779..987c11a 100644 --- a/src/http/router/fs_dir.zig +++ b/src/http/router/fs_dir.zig @@ -2,10 +2,12 @@ const std = @import("std"); const log = std.log.scoped(.@"zzz/http/router"); const assert = std.debug.assert; +const Route = @import("route.zig").Route; +const Layer = @import("layer.zig").Layer; const Request = @import("../request.zig").Request; const Response = @import("../response.zig").Response; const Mime = @import("../mime.zig").Mime; -const _Context = @import("../context.zig").Context; +const Context = @import("../context.zig").Context; const OpenResult = @import("tardy").OpenResult; const ReadResult = @import("tardy").ReadResult; @@ -16,107 +18,167 @@ const Runtime = @import("tardy").Runtime; const Stat = @import("tardy").Stat; const Cross = @import("tardy").Cross; -pub fn FsDir(Server: type, AppState: type) type { - return struct { - const Context = _Context(Server, AppState); - - const FileProvision = struct { - mime: Mime, - context: *Context, - request: *const Request, - response: *Response, - fd: std.posix.fd_t, - file_size: u64, - rd_offset: usize, - current_length: usize, - buffer: []u8, - }; +pub const FsDir = struct { + const FileProvision = struct { + mime: Mime, + context: *Context, + request: *const Request, + response: *Response, + fd: std.posix.fd_t, + file_size: u64, + rd_offset: usize, + current_length: usize, + buffer: []u8, + }; + + /// Serve a Filesystem Directory as a Layer. 
+ pub fn serve(comptime url_path: []const u8, comptime dir_path: []const u8) Layer { + const url_with_match_all = comptime std.fmt.comptimePrint( + "{s}/%r", + .{std.mem.trimRight(u8, url_path, "/")}, + ); - fn open_file_task(rt: *Runtime, result: OpenResult, provision: *FileProvision) !void { - errdefer provision.context.respond(.{ - .status = .@"Internal Server Error", + return Route.init(url_with_match_all).get({}, struct { + fn fs_dir_handler(ctx: *Context, _: void) !void { + try inner_handler(ctx, dir_path); + } + }.fs_dir_handler).layer(); + } + + fn open_file_task(rt: *Runtime, result: OpenResult, provision: *FileProvision) !void { + errdefer provision.context.respond(.{ + .status = .@"Internal Server Error", + .mime = Mime.HTML, + .body = "", + }) catch unreachable; + + const fd = result.unwrap() catch |e| { + log.warn("file not found | {}", .{e}); + try provision.context.respond(.{ + .status = .@"Not Found", .mime = Mime.HTML, - .body = "", - }) catch unreachable; + .body = "File Not Found", + }); + return; + }; + provision.fd = fd; + + return try rt.fs.stat(provision, stat_file_task, fd); + } + + fn stat_file_task(rt: *Runtime, result: StatResult, provision: *FileProvision) !void { + errdefer provision.context.respond(.{ + .status = .@"Internal Server Error", + .mime = Mime.HTML, + .body = "", + }) catch unreachable; + + const stat = result.unwrap() catch |e| { + log.warn("stat on fd={d} failed | {}", .{ provision.fd, e }); + try provision.context.respond(.{ + .status = .@"Not Found", + .mime = Mime.HTML, + .body = "File Not Found", + }); + return; + }; - const fd = result.unwrap() catch |e| { - log.warn("file not found | {}", .{e}); - try provision.context.respond(.{ - .status = .@"Not Found", - .mime = Mime.HTML, - .body = "File Not Found", - }); - return; - }; - provision.fd = fd; + // Set file size. + provision.file_size = stat.size; + log.debug("file size: {d}", .{provision.file_size}); - try rt.fs.stat(provision, stat_file_task, fd); + // generate the etag and attach it to the response. + var hash = std.hash.Wyhash.init(0); + hash.update(std.mem.asBytes(&stat.size)); + if (stat.modified) |modified| { + hash.update(std.mem.asBytes(&modified.seconds)); + hash.update(std.mem.asBytes(&modified.nanos)); } - - fn stat_file_task(rt: *Runtime, result: StatResult, provision: *FileProvision) !void { - errdefer provision.context.respond(.{ - .status = .@"Internal Server Error", - .mime = Mime.HTML, - .body = "", - }) catch unreachable; - - const stat = result.unwrap() catch |e| { - log.warn("stat on fd={d} failed | {}", .{ provision.fd, e }); - try provision.context.respond(.{ - .status = .@"Not Found", + const etag_hash = hash.final(); + + const calc_etag = try std.fmt.allocPrint( + provision.context.allocator, + "\"{d}\"", + .{etag_hash}, + ); + + provision.response.headers.put_assume_capacity("ETag", calc_etag); + + // If we have an ETag on the request... + if (provision.request.headers.get("If-None-Match")) |etag| { + if (std.mem.eql(u8, etag, calc_etag)) { + // If the ETag matches. + return try provision.context.respond(.{ + .status = .@"Not Modified", .mime = Mime.HTML, - .body = "File Not Found", + .body = "", }); - return; - }; - - // Set file size. - provision.file_size = stat.size; - log.debug("file size: {d}", .{provision.file_size}); - - // generate the etag and attach it to the response. 
- var hash = std.hash.Wyhash.init(0); - hash.update(std.mem.asBytes(&stat.size)); - if (stat.modified) |modified| { - hash.update(std.mem.asBytes(&modified.seconds)); - hash.update(std.mem.asBytes(&modified.nanos)); } - const etag_hash = hash.final(); - - const calc_etag = try std.fmt.allocPrint( - provision.context.allocator, - "\"{d}\"", - .{etag_hash}, - ); + } - provision.response.headers.put_assume_capacity("ETag", calc_etag); + provision.response.set(.{ + .status = .OK, + .mime = provision.mime, + .body = null, + }); + + const headers = try provision.response.headers_into_buffer( + provision.buffer, + @intCast(stat.size), + ); + provision.current_length = headers.len; + + return try rt.fs.read( + provision, + read_file_task, + provision.fd, + provision.buffer[provision.current_length..], + provision.rd_offset, + ); + } + + fn read_file_task(rt: *Runtime, result: ReadResult, provision: *FileProvision) !void { + errdefer { + std.posix.close(provision.fd); + provision.context.close() catch unreachable; + } - // If we have an ETag on the request... - if (provision.request.headers.get("If-None-Match")) |etag| { - if (std.mem.eql(u8, etag, calc_etag)) { - // If the ETag matches. - try provision.context.respond(.{ - .status = .@"Not Modified", - .mime = Mime.HTML, - .body = "", + const length = result.unwrap() catch |e| { + switch (e) { + error.EndOfFile => { + log.debug("done streaming file | rd off: {d} | f size: {d} ", .{ + provision.rd_offset, + provision.file_size, }); - return; - } + + std.posix.close(provision.fd); + return try provision.context.send_then_recv( + provision.buffer[0..provision.current_length], + ); + }, + else => { + log.warn("reading on fd={d} failed | {}", .{ provision.fd, e }); + std.posix.close(provision.fd); + return try provision.context.close(); + }, } + }; - provision.response.set(.{ - .status = .OK, - .mime = provision.mime, - .body = null, - }); + const length_as_usize: usize = @intCast(length); + provision.rd_offset += length_as_usize; + provision.current_length += length_as_usize; + log.debug("current offset: {d} | fd: {}", .{ provision.rd_offset, provision.fd }); - const headers = try provision.response.headers_into_buffer( - provision.buffer, - @intCast(stat.size), + assert(provision.rd_offset <= length_as_usize); + assert(provision.current_length <= provision.buffer.len); + if (provision.current_length == provision.buffer.len) { + return try provision.context.send_then( + provision.buffer[0..provision.current_length], + provision, + send_file_task, ); - provision.current_length = headers.len; - - try rt.fs.read( + } else { + return try rt.fs.read( provision, read_file_task, provision.fd, @@ -124,162 +186,106 @@ pub fn FsDir(Server: type, AppState: type) type { provision.rd_offset, ); } + } - fn read_file_task(rt: *Runtime, result: ReadResult, provision: *FileProvision) !void { - errdefer { - std.posix.close(provision.fd); - provision.context.close() catch unreachable; - } - - const length = result.unwrap() catch |e| { - switch (e) { - error.EndOfFile => { - log.debug("done streaming file | rd off: {d} | f size: {d} ", .{ - provision.rd_offset, - provision.file_size, - }); - - std.posix.close(provision.fd); - try provision.context.send_then_recv( - provision.buffer[0..provision.current_length], - ); - return; - }, - else => { - log.warn("reading on fd={d} failed | {}", .{ provision.fd, e }); - std.posix.close(provision.fd); - try provision.context.close(); - return; - }, - } - }; - - const length_as_usize: usize = @intCast(length); - provision.rd_offset 
+= length_as_usize; - provision.current_length += length_as_usize; - log.debug("current offset: {d} | fd: {}", .{ provision.rd_offset, provision.fd }); - - assert(provision.rd_offset <= length_as_usize); - assert(provision.current_length <= provision.buffer.len); - if (provision.current_length == provision.buffer.len) { - try provision.context.send_then( - provision.buffer[0..provision.current_length], - provision, - send_file_task, - ); - } else { - try rt.fs.read( - provision, - read_file_task, - provision.fd, - provision.buffer[provision.current_length..], - provision.rd_offset, - ); - } + fn send_file_task(rt: *Runtime, success: bool, provision: *FileProvision) !void { + errdefer { + std.posix.close(provision.fd); + provision.context.close() catch unreachable; } - fn send_file_task(rt: *Runtime, success: bool, provision: *FileProvision) !void { - errdefer { - std.posix.close(provision.fd); - provision.context.close() catch unreachable; - } - - if (!success) { - log.warn("send file stream failed!", .{}); - std.posix.close(provision.fd); - return; - } - - // reset current length - provision.current_length = 0; + if (!success) { + log.warn("send file stream failed!", .{}); + std.posix.close(provision.fd); + return; + } - // continue streaming.. - try rt.fs.read( - provision, - read_file_task, - provision.fd, - provision.buffer, - provision.rd_offset, - ); + // reset current length + provision.current_length = 0; + + // continue streaming.. + return try rt.fs.read( + provision, + read_file_task, + provision.fd, + provision.buffer, + provision.rd_offset, + ); + } + + fn inner_handler(ctx: *Context, dir_path: []const u8) !void { + if (ctx.captures.len == 0) { + return try ctx.respond(.{ + .status = .@"Not Found", + .mime = Mime.HTML, + .body = "", + }); } - pub fn handler_fn(ctx: *Context, dir_path: []const u8) !void { - if (ctx.captures.len == 0) { - try ctx.respond(.{ + //TODO Can we do this once and for all at initialization? + // Resolving the base directory. + const resolved_dir = try std.fs.path.resolve(ctx.allocator, &[_][]const u8{dir_path}); + defer ctx.allocator.free(resolved_dir); + + // Resolving the requested file. + const search_path = ctx.captures[0].remaining; + const resolved_file_path = blk: { + // This appears to be leaking BUT the ctx.allocator is an + // arena so it does get cleaned up eventually. + const file_path = std.fs.path.resolve( + ctx.allocator, + &[_][]const u8{ dir_path, search_path }, + ) catch { + return try ctx.respond(.{ .status = .@"Not Found", .mime = Mime.HTML, .body = "", }); - return; - } - - //TODO Can we do this once and for all at initialization? - // Resolving the base directory. - const resolved_dir = try std.fs.path.resolve(ctx.allocator, &[_][]const u8{dir_path}); - defer ctx.allocator.free(resolved_dir); - - // Resolving the requested file. - const search_path = ctx.captures[0].remaining; - const resolved_file_path = blk: { - // This appears to be leaking BUT the ctx.allocator is an - // arena so it does get cleaned up eventually. 
- const file_path = std.fs.path.resolve( - ctx.allocator, - &[_][]const u8{ dir_path, search_path }, - ) catch { - try ctx.respond(.{ - .status = .@"Not Found", - .mime = Mime.HTML, - .body = "", - }); - return; - }; - const file_path_z = try ctx.allocator.dupeZ(u8, file_path); - ctx.allocator.free(file_path); - break :blk file_path_z; }; + const file_path_z = try ctx.allocator.dupeZ(u8, file_path); + ctx.allocator.free(file_path); + break :blk file_path_z; + }; - // The resolved path should always start like the base directory path, - // otherwise it means that the user is trying to access something forbidden. - if (!std.mem.startsWith(u8, resolved_file_path, resolved_dir)) { - defer ctx.allocator.free(resolved_file_path); - try ctx.respond(.{ - .status = .Forbidden, - .mime = Mime.HTML, - .body = "", - }); - return; - } + // The resolved path should always start like the base directory path, + // otherwise it means that the user is trying to access something forbidden. + if (!std.mem.startsWith(u8, resolved_file_path, resolved_dir)) { + defer ctx.allocator.free(resolved_file_path); + return try ctx.respond(.{ + .status = .Forbidden, + .mime = Mime.HTML, + .body = "", + }); + } - const extension_start = std.mem.lastIndexOfScalar(u8, search_path, '.'); - const mime: Mime = blk: { - if (extension_start) |start| { - if (search_path.len - start == 0) break :blk Mime.BIN; - break :blk Mime.from_extension(search_path[start + 1 ..]); - } else { - break :blk Mime.BIN; - } - }; + const extension_start = std.mem.lastIndexOfScalar(u8, search_path, '.'); + const mime: Mime = blk: { + if (extension_start) |start| { + if (search_path.len - start == 0) break :blk Mime.BIN; + break :blk Mime.from_extension(search_path[start + 1 ..]); + } else { + break :blk Mime.BIN; + } + }; - const provision = try ctx.allocator.create(FileProvision); - - provision.* = .{ - .mime = mime, - .context = ctx, - .request = ctx.request, - .response = ctx.response, - .fd = Cross.fd.INVALID_FD, - .file_size = 0, - .rd_offset = 0, - .current_length = 0, - .buffer = ctx.provision.buffer, - }; + const provision = try ctx.allocator.create(FileProvision); + + provision.* = .{ + .mime = mime, + .context = ctx, + .request = ctx.request, + .response = ctx.response, + .fd = Cross.fd.INVALID_FD, + .file_size = 0, + .rd_offset = 0, + .current_length = 0, + .buffer = ctx.provision.buffer, + }; - try ctx.runtime.fs.open( - provision, - open_file_task, - resolved_file_path, - ); - } - }; -} + return try ctx.runtime.fs.open( + provision, + open_file_task, + resolved_file_path, + ); + } +}; diff --git a/src/http/router/layer.zig b/src/http/router/layer.zig new file mode 100644 index 0000000..415d19b --- /dev/null +++ b/src/http/router/layer.zig @@ -0,0 +1,18 @@ +const Route = @import("route.zig").Route; +const MiddlewareWithData = @import("middleware.zig").MiddlewareWithData; + +const MiddlewarePair = struct { + pre: MiddlewareWithData, + post: MiddlewareWithData, +}; + +pub const Layer = union(enum) { + /// Route + route: Route, + /// Pre-Route Middleware + pre: MiddlewareWithData, + /// Post-Route Middleware + post: MiddlewareWithData, + /// Pair of Middleware + pair: MiddlewarePair, +}; diff --git a/src/http/router/middleware.zig b/src/http/router/middleware.zig new file mode 100644 index 0000000..6f16d25 --- /dev/null +++ b/src/http/router/middleware.zig @@ -0,0 +1,120 @@ +const std = @import("std"); +const assert = std.debug.assert; + +const Runtime = @import("tardy").Runtime; + +const wrap = @import("tardy").wrap; +const Task = 
@import("tardy").TaskFn; + +const Pseudoslice = @import("../../core/pseudoslice.zig").Pseudoslice; +const Server = @import("../server.zig").Server; + +const Route = @import("route.zig").Route; +const HandlerWithData = @import("route.zig").HandlerWithData; +const Layer = @import("layer.zig").Layer; +const Context = @import("../context.zig").Context; + +const Stage = enum { pre, post }; + +const PreChain = struct { + chain: []const MiddlewareWithData, + handler: HandlerWithData, +}; + +pub const Next = struct { + const Self = @This(); + stage: Stage, + pre_chain: PreChain, + post_chain: []const MiddlewareWithData, + ctx: *Context, + + fn next_middleware_task(_: *Runtime, _: void, n: *Self) !void { + switch (n.stage) { + .pre => { + assert(n.pre_chain.chain.len > 0); + const next_middleware = n.pre_chain.chain[0]; + n.pre_chain.chain = n.pre_chain.chain[1..]; + try @call(.auto, next_middleware.middleware, .{ n, next_middleware.data }); + }, + .post => { + assert(n.post_chain.len > 0); + const next_middleware = n.post_chain[0]; + n.post_chain = n.post_chain[1..]; + try @call(.auto, next_middleware.middleware, .{ n, next_middleware.data }); + }, + } + } + + pub fn run(self: *Self) !void { + switch (self.stage) { + .pre => { + if (self.pre_chain.chain.len > 0) { + return try self.ctx.runtime.spawn(void, self, next_middleware_task); + } else { + return try @call( + .auto, + self.pre_chain.handler.handler, + .{ self.ctx, self.pre_chain.handler.data }, + ); + } + }, + .post => { + if (self.post_chain.len > 0) { + return try self.ctx.runtime.spawn(void, self, next_middleware_task); + } else { + return try self.ctx.respond_without_middleware(); + } + }, + } + } +}; + +pub const MiddlewareFn = *const fn (*Next, usize) anyerror!void; +pub fn TypedMiddlewareFn(comptime T: type) type { + return *const fn (*Next, T) anyerror!void; +} + +pub const MiddlewareWithData = struct { + middleware: MiddlewareFn, + data: usize, +}; + +pub const Middleware = struct { + const Self = @This(); + + pre: ?MiddlewareWithData = null, + post: ?MiddlewareWithData = null, + + pub fn init() Self { + return .{}; + } + + pub fn before(self: Self, data: anytype, func: TypedMiddlewareFn(@TypeOf(data))) Self { + return .{ + .pre = .{ + .middleware = @ptrCast(func), + .data = wrap(usize, data), + }, + .post = self.post, + }; + } + + pub fn after(self: Self, data: anytype, func: TypedMiddlewareFn(@TypeOf(data))) Self { + return .{ + .pre = self.pre, + .post = .{ + .middleware = @ptrCast(func), + .data = wrap(usize, data), + }, + }; + } + + pub fn layer(self: Self) Layer { + if (self.pre != null and self.post != null) { + return .{ .pair = .{ .pre = self.pre.?, .post = self.post.? } }; + } + if (self.pre) |p| return .{ .pre = p }; + if (self.post) |p| return .{ .post = p }; + @panic("Cannot create a layer from an empty Middleware"); + } +}; diff --git a/src/http/router/route.zig b/src/http/router/route.zig index 6cb966c..c91e11d 100644 --- a/src/http/router/route.zig +++ b/src/http/router/route.zig @@ -3,218 +3,214 @@ const builtin = @import("builtin"); const log = std.log.scoped(.@"zzz/http/route"); const assert = std.debug.assert; +const wrap = @import("tardy").wrap; + const Method = @import("../method.zig").Method; const Request = @import("../request.zig").Request; const Response = @import("../response.zig").Response; const Mime = @import("../mime.zig").Mime; -const _FsDir = @import("fs_dir.zig").FsDir; -const _Context = @import("../context.zig").Context; - -/// Structure of a server route definition. 
-pub fn Route(comptime Server: type, comptime AppState: type) type { - return struct { - const Context = _Context(Server, AppState); - const FsDir = _FsDir(Server, AppState); - - const Self = @This(); - pub const HandlerFn = *const fn (context: *Context) anyerror!void; - - /// Defined route path. - path: []const u8, - - /// Route handlers. - handlers: [9]?HandlerFn = [_]?HandlerFn{null} ** 9, - - fn method_to_index(method: Method) u32 { - return switch (method) { - .GET => 0, - .HEAD => 1, - .POST => 2, - .PUT => 3, - .DELETE => 4, - .CONNECT => 5, - .OPTIONS => 6, - .TRACE => 7, - .PATCH => 8, - }; - } - - /// Initialize a route for the given path. - pub fn init(path: []const u8) Self { - return Self{ .path = path }; - } +const FsDir = @import("fs_dir.zig").FsDir; +const Context = @import("../context.zig").Context; +const Layer = @import("layer.zig").Layer; - /// Returns a comma delinated list of allowed Methods for this route. This - /// is meant to be used as the value for the 'Allow' header in the Response. - pub fn get_allowed(self: Self, allocator: std.mem.Allocator) ![]const u8 { - // This gets allocated within the context of the connection's arena. - const allowed_size = comptime blk: { - var size = 0; - for (std.meta.tags(Method)) |method| { - size += @tagName(method).len + 1; - } - break :blk size; - }; +pub const HandlerFn = *const fn (*Context, usize) anyerror!void; +pub fn TypedHandlerFn(comptime T: type) type { + return *const fn (*Context, T) anyerror!void; +} - const buffer = try allocator.alloc(u8, allowed_size); +pub const HandlerWithData = struct { + handler: HandlerFn, + data: usize, +}; - var current: []u8 = ""; - inline for (std.meta.tags(Method)) |method| { - if (self.handlers[@intFromEnum(method)] != null) { - current = std.fmt.bufPrint(buffer, "{s},{s}", .{ @tagName(method), current }) catch unreachable; - } +/// Structure of a server route definition. +pub const Route = struct { + const Self = @This(); + + /// Defined route path. + path: []const u8, + + /// Route handlers. + handlers: [9]?HandlerWithData = [_]?HandlerWithData{null} ** 9, + + fn method_to_index(method: Method) u32 { + return switch (method) { + .GET => 0, + .HEAD => 1, + .POST => 2, + .PUT => 3, + .DELETE => 4, + .CONNECT => 5, + .OPTIONS => 6, + .TRACE => 7, + .PATCH => 8, + }; + } + + /// Initialize a route for the given path. + pub fn init(path: []const u8) Self { + return Self{ .path = path }; + } + + /// Returns a comma delinated list of allowed Methods for this route. This + /// is meant to be used as the value for the 'Allow' header in the Response. + pub fn get_allowed(self: Self, allocator: std.mem.Allocator) ![]const u8 { + // This gets allocated within the context of the connection's arena. + const allowed_size = comptime blk: { + var size = 0; + for (std.meta.tags(Method)) |method| { + size += @tagName(method).len + 1; } - - if (current.len == 0) { - return current; - } else { - return current[0 .. current.len - 1]; + break :blk size; + }; + + const buffer = try allocator.alloc(u8, allowed_size); + + var current: []u8 = ""; + inline for (std.meta.tags(Method)) |method| { + if (self.handlers[@intFromEnum(method)] != null) { + current = std.fmt.bufPrint( + buffer, + "{s},{s}", + .{ @tagName(method), current }, + ) catch unreachable; } } - /// Get a defined request handler for the provided method. - /// Return NULL if no handler is defined for this method. 
- pub fn get_handler(self: Self, method: Method) ?HandlerFn { - return self.handlers[method_to_index(method)]; - } - - /// Set a new route path. - pub fn set_path(self: Self, path: []const u8) Self { - return Self{ - .path = path, - .handlers = self.handlers, - }; - } - - /// Set a handler function for the provided method. - fn inner_route( - comptime method: Method, - self: Self, - handler_fn: HandlerFn, - ) Self { - var new_handlers = self.handlers; - new_handlers[comptime method_to_index(method)] = handler_fn; - return Self{ - .path = self.path, - .handlers = new_handlers, - }; - } - - /// Set a handler function for all methods. - pub fn all(self: Self, handler_fn: HandlerFn) Self { - var new_handlers = self.handlers; - - for (&new_handlers) |*new_handler| { - new_handler.* = handler_fn; - } - - return Self{ - .path = self.path, - .handlers = new_handlers, + if (current.len == 0) { + return current; + } else { + return current[0 .. current.len - 1]; + } + } + + /// Get a defined request handler for the provided method. + /// Return NULL if no handler is defined for this method. + pub fn get_handler(self: Self, method: Method) ?HandlerWithData { + return self.handlers[method_to_index(method)]; + } + + pub fn layer(self: Self) Layer { + return .{ .route = self }; + } + + /// Set a handler function for the provided method. + inline fn inner_route( + comptime method: Method, + self: Self, + data: anytype, + handler_fn: TypedHandlerFn(@TypeOf(data)), + ) Self { + const wrapped = wrap(usize, data); + var new_handlers = self.handlers; + new_handlers[comptime method_to_index(method)] = .{ + .handler = @ptrCast(handler_fn), + .data = wrapped, + }; + + return Self{ .path = self.path, .handlers = new_handlers }; + } + + /// Set a handler function for all methods. + pub fn all(self: Self, data: anytype, handler_fn: TypedHandlerFn(@TypeOf(data))) Self { + const wrapped = wrap(usize, data); + var new_handlers = self.handlers; + + for (&new_handlers) |*new_handler| { + new_handler.* = .{ + .handler = @ptrCast(handler_fn), + .data = wrapped, }; } - pub fn get(self: Self, handler_fn: HandlerFn) Self { - return inner_route(.GET, self, handler_fn); - } - - pub fn head(self: Self, handler_fn: HandlerFn) Self { - return inner_route(.HEAD, self, handler_fn); - } - - pub fn post(self: Self, handler_fn: HandlerFn) Self { - return inner_route(.POST, self, handler_fn); - } - - pub fn put(self: Self, handler_fn: HandlerFn) Self { - return inner_route(.PUT, self, handler_fn); - } - - pub fn delete(self: Self, handler_fn: HandlerFn) Self { - return inner_route(.DELETE, self, handler_fn); - } - - pub fn connect(self: Self, handler_fn: HandlerFn) Self { - return inner_route(.CONNECT, self, handler_fn); - } - - pub fn options(self: Self, handler_fn: HandlerFn) Self { - return inner_route(.OPTIONS, self, handler_fn); - } - - pub fn trace(self: Self, handler_fn: HandlerFn) Self { - return inner_route(.TRACE, self, handler_fn); - } - - pub fn patch(self: Self, handler_fn: HandlerFn) Self { - return inner_route(.PATCH, self, handler_fn); - } - - /// Define a GET handler to serve an embedded file. 
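Every verb method now takes a user-supplied value alongside the handler; `inner_route` wraps that value into a `usize` and hands it back, untouched, on every request. A short sketch under the same assumptions as above (the `zzz.http` re-exports are assumed, and `Counter` and `item_handler` are illustrative):

const std = @import("std");

const zzz = @import("zzz");
const http = zzz.http;
const Context = http.Context;
const Route = http.Route;
const Mime = http.Mime;

const Counter = struct { hits: usize = 0 };

// `%i` in the path is a signed-integer capture and lands in ctx.captures[0].signed;
// the second parameter is whatever was passed at registration time.
fn item_handler(ctx: *Context, counter: *Counter) !void {
    counter.hits += 1;
    const id = ctx.captures[0].signed;
    const body = try std.fmt.allocPrint(
        ctx.allocator,
        "item {d} (hit #{d})",
        .{ id, counter.hits },
    );
    return try ctx.respond(.{ .status = .OK, .mime = Mime.HTML, .body = body });
}

test "registering a typed handler (illustrative)" {
    var counter: Counter = .{};
    const layer = Route.init("/item/%i").get(&counter, item_handler).layer();
    try std.testing.expect(layer == .route);
}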
-        pub fn serve_embedded_file(
-            self: *const Self,
-            comptime mime: ?Mime,
-            comptime bytes: []const u8,
-        ) Self {
-            return self.get(struct {
-                fn handler_fn(ctx: *Context) !void {
-                    const cache_control: []const u8 = if (comptime builtin.mode == .Debug)
-                        "no-cache"
-                    else
-                        comptime std.fmt.comptimePrint(
-                            "max-age={d}",
-                            .{std.time.s_per_day * 30},
-                        );
-
-                    ctx.response.headers.put_assume_capacity("Cache-Control", cache_control);
-
-                    // If our static item is greater than 1KB,
-                    // it might be more beneficial to using caching.
-                    if (comptime bytes.len > 1024) {
-                        @setEvalBranchQuota(1_000_000);
-                        const etag = comptime std.fmt.comptimePrint(
-                            "\"{d}\"",
-                            .{std.hash.Wyhash.hash(0, bytes)},
-                        );
-                        ctx.response.headers.put_assume_capacity("ETag", etag[0..]);
-
-                        if (ctx.request.headers.get("If-None-Match")) |match| {
-                            if (std.mem.eql(u8, etag, match)) {
-                                try ctx.respond(.{
-                                    .status = .@"Not Modified",
-                                    .mime = Mime.HTML,
-                                    .body = "",
-                                });
-
-                                return;
-                            }
+        return Self{
+            .path = self.path,
+            .handlers = new_handlers,
+        };
+    }
+
+    pub fn get(self: Self, data: anytype, handler_fn: TypedHandlerFn(@TypeOf(data))) Self {
+        return inner_route(.GET, self, data, handler_fn);
+    }
+
+    pub fn head(self: Self, data: anytype, handler_fn: TypedHandlerFn(@TypeOf(data))) Self {
+        return inner_route(.HEAD, self, data, handler_fn);
+    }
+
+    pub fn post(self: Self, data: anytype, handler_fn: TypedHandlerFn(@TypeOf(data))) Self {
+        return inner_route(.POST, self, data, handler_fn);
+    }
+
+    pub fn put(self: Self, data: anytype, handler_fn: TypedHandlerFn(@TypeOf(data))) Self {
+        return inner_route(.PUT, self, data, handler_fn);
+    }
+
+    pub fn delete(self: Self, data: anytype, handler_fn: TypedHandlerFn(@TypeOf(data))) Self {
+        return inner_route(.DELETE, self, data, handler_fn);
+    }
+
+    pub fn connect(self: Self, data: anytype, handler_fn: TypedHandlerFn(@TypeOf(data))) Self {
+        return inner_route(.CONNECT, self, data, handler_fn);
+    }
+
+    pub fn options(self: Self, data: anytype, handler_fn: TypedHandlerFn(@TypeOf(data))) Self {
+        return inner_route(.OPTIONS, self, data, handler_fn);
+    }
+
+    pub fn trace(self: Self, data: anytype, handler_fn: TypedHandlerFn(@TypeOf(data))) Self {
+        return inner_route(.TRACE, self, data, handler_fn);
+    }
+
+    pub fn patch(self: Self, data: anytype, handler_fn: TypedHandlerFn(@TypeOf(data))) Self {
+        return inner_route(.PATCH, self, data, handler_fn);
+    }
+
+    /// Define a GET handler to serve an embedded file.
+    pub fn serve_embedded_file(
+        self: *const Self,
+        comptime mime: ?Mime,
+        comptime bytes: []const u8,
+    ) Self {
+        return self.get({}, struct {
+            fn handler_fn(ctx: *Context, _: void) !void {
+                const cache_control: []const u8 = if (comptime builtin.mode == .Debug)
+                    "no-cache"
+                else
+                    comptime std.fmt.comptimePrint(
+                        "max-age={d}",
+                        .{std.time.s_per_day * 30},
+                    );
+
+                ctx.response.headers.put_assume_capacity("Cache-Control", cache_control);
+
+                // If our static item is greater than 1KB,
+                // it might be more beneficial to use caching.
+ if (comptime bytes.len > 1024) { + @setEvalBranchQuota(1_000_000); + const etag = comptime std.fmt.comptimePrint( + "\"{d}\"", + .{std.hash.Wyhash.hash(0, bytes)}, + ); + ctx.response.headers.put_assume_capacity("ETag", etag[0..]); + + if (ctx.request.headers.get("If-None-Match")) |match| { + if (std.mem.eql(u8, etag, match)) { + return try ctx.respond(.{ + .status = .@"Not Modified", + .mime = Mime.HTML, + .body = "", + }); } } - - try ctx.respond(.{ - .status = .OK, - .mime = mime, - .body = bytes, - }); } - }.handler_fn); - } - /// Define a GET handler to serve an entire directory. - pub fn serve_fs_dir(comptime self: *const Self, comptime dir_path: []const u8) Self { - const url_with_match_all = comptime std.fmt.comptimePrint( - "{s}/%r", - .{std.mem.trimRight(u8, self.path, "/")}, - ); - - return self - .set_path(url_with_match_all) - .get(struct { - fn handler_fn(ctx: *Context) !void { - try FsDir.handler_fn(ctx, dir_path); - } - }.handler_fn); - } - }; -} + return try ctx.respond(.{ + .status = .OK, + .mime = mime, + .body = bytes, + }); + } + }.handler_fn); + } +}; diff --git a/src/http/router/routing_trie.zig b/src/http/router/routing_trie.zig index 5d30a6a..c36a986 100644 --- a/src/http/router/routing_trie.zig +++ b/src/http/router/routing_trie.zig @@ -2,13 +2,52 @@ const std = @import("std"); const assert = std.debug.assert; const log = std.log.scoped(.@"zzz/http/routing_trie"); +const Layer = @import("layer.zig").Layer; +const Route = @import("route.zig").Route; +const Bundle = @import("bundle.zig").Bundle; + +const MiddlewareWithData = @import("middleware.zig").MiddlewareWithData; const CaseStringMap = @import("../../core/case_string_map.zig").CaseStringMap; -const _Route = @import("route.zig").Route; -const TokenHashMap = @import("token_hash_map.zig").TokenHashMap; -// These tokens are for the Routes when assembling the -// Routing Trie. This allows for every sub-path to be -// parsed into a token and assembled later. +fn TokenHashMap(comptime V: type) type { + return std.HashMap(Token, V, struct { + pub fn hash(self: @This(), input: Token) u64 { + _ = self; + + const bytes = blk: { + switch (input) { + .fragment => |inner| break :blk inner, + .match => |inner| break :blk @tagName(inner), + } + }; + + return std.hash.Wyhash.hash(0, bytes); + } + + pub fn eql(self: @This(), first: Token, second: Token) bool { + _ = self; + + const result = blk: { + switch (first) { + .fragment => |f_inner| { + switch (second) { + .fragment => |s_inner| break :blk std.mem.eql(u8, f_inner, s_inner), + else => break :blk false, + } + }, + .match => |f_inner| { + switch (second) { + .match => |s_inner| break :blk f_inner == s_inner, + else => break :blk false, + } + }, + } + }; + + return result; + } + }, 80); +} const TokenEnum = enum(u8) { fragment = 0, @@ -71,220 +110,222 @@ pub const Capture = union(TokenMatch) { remaining: TokenMatch.remaining.as_type(), }; +/// Structure of a matched route. +pub const FoundBundle = struct { + bundle: Bundle, + captures: []Capture, + queries: *QueryMap, +}; + // This RoutingTrie is deleteless. It only can create new routes or update existing ones. -pub fn RoutingTrie(comptime Server: type, comptime AppState: type) type { - return struct { - const Self = @This(); - const Route = _Route(Server, AppState); - - /// Structure of a matched route. - pub const FoundRoute = struct { - route: Route, - captures: []Capture, - queries: *QueryMap, - }; +pub const RoutingTrie = struct { + const Self = @This(); + + /// Structure of a node of the trie. 
+ pub const Node = struct { + pub const ChildrenMap = TokenHashMap(Node); + + token: Token, + bundle: ?Bundle = null, + children: ChildrenMap, + + /// Initialize a new empty node. + pub fn init(allocator: std.mem.Allocator, token: Token, bundle: ?Bundle) Node { + return .{ + .token = token, + .bundle = bundle, + .children = ChildrenMap.init(allocator), + }; + } - /// Structure of a node of the trie. - pub const Node = struct { - pub const ChildrenMap = TokenHashMap(*const Node); - - token: Token, - route: ?Route = null, - children: ChildrenMap, - - /// Initialize a new empty node. - pub fn init(token: Token, route: ?Route) Node { - return Node{ - .token = token, - .route = route, - .children = ChildrenMap.init_comptime(&[0]ChildrenMap.KV{}), - }; - } + pub fn deinit(self: *Node) void { + var iter = self.children.valueIterator(); - /// Initialize a cloned node with a new child for the provided token. - pub fn with_child(self: *const Node, token: Token, node: *const Node) Node { - return Node{ - .token = self.token, - .route = self.route, - .children = self.children.with_kvs(&[_]ChildrenMap.KV{.{ token, node }}), - }; + while (iter.next()) |node| { + node.deinit(); } - }; - root: Node = Node.init(.{ .fragment = "" }, null), - - /// Initialize the routing tree with the given routes. - pub fn init(comptime routes: []const Route) Self { - return (Self{}).with_routes(routes); + self.children.deinit(); } + }; - fn print_node(root: *const Node, depth: usize) void { - for (root.children.values) |node| { - var i: usize = 0; - while (i < depth) : (i += 1) { - std.debug.print(" │ ", .{}); - } + root: Node, + pre_mw: std.ArrayListUnmanaged(MiddlewareWithData), + post_mw: std.ArrayListUnmanaged(MiddlewareWithData), - std.debug.print(" ├ ", .{}); + /// Initialize the routing tree with the given routes. + pub fn init(allocator: std.mem.Allocator, layers: []const Layer) !Self { + var self: Self = .{ + .root = Node.init(allocator, .{ .fragment = "" }, null), + .pre_mw = try std.ArrayListUnmanaged(MiddlewareWithData).initCapacity(allocator, 0), + .post_mw = try std.ArrayListUnmanaged(MiddlewareWithData).initCapacity(allocator, 0), + }; - switch (node.token) { - .fragment => |inner| std.debug.print("Token: {s}", .{inner}), - .match => |match| std.debug.print("Token: Match {s}", .{@tagName(match)}), - } - if (node.route != null) { - std.debug.print(" ⃝", .{}); - } - std.debug.print("\n", .{}); + for (layers) |layer| { + switch (layer) { + .route => |route| { + var current = &self.root; + var iter = std.mem.tokenizeScalar(u8, route.path, '/'); + + while (iter.next()) |chunk| { + const token: Token = Token.parse_chunk(chunk); + if (current.children.getPtr(token)) |child| { + current = child; + } else { + try current.children.put(token, Node.init(allocator, token, null)); + current = current.children.getPtr(token).?; + } + } - print_node(node, depth + 1); + current.bundle = .{ + .pre = self.pre_mw.items, + .route = route, + .post = self.post_mw.items, + }; + }, + .pre => |func| try self.pre_mw.append(allocator, func), + .post => |func| try self.post_mw.append(allocator, func), + .pair => |inner| { + try self.pre_mw.append(allocator, inner.pre); + try self.post_mw.append(allocator, inner.post); + }, } } - pub fn print(self: *const Self) void { - std.debug.print("Root: \n", .{}); - print_node(&(self.root), 0); - } + return self; + } - /// Initialize new trie node for the next token. 
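Because each route's bundle snapshots the middleware lists at the moment it is inserted, the order of the layer slice is significant: a middleware layer only wraps routes that appear after it. A small test in the style of the ones at the bottom of this file, assuming `Middleware` and `Next` are imported from the neighbouring middleware.zig:

const std = @import("std");
const testing = std.testing;

const RoutingTrie = @import("routing_trie.zig").RoutingTrie;
const Capture = @import("routing_trie.zig").Capture;
const QueryMap = @import("routing_trie.zig").QueryMap;
const Route = @import("route.zig").Route;
const Middleware = @import("middleware.zig").Middleware;
const Next = @import("middleware.zig").Next;

// Stub middleware: a real one would inspect the request before continuing.
fn require_auth(next: *Next, _: void) !void {
    return next.run();
}

test "middleware only wraps routes layered after it (illustrative)" {
    var trie = try RoutingTrie.init(testing.allocator, &.{
        Route.init("/public").layer(),
        Middleware.init().before({}, require_auth).layer(),
        Route.init("/admin").layer(),
    });
    defer trie.deinit(testing.allocator);

    var q = try QueryMap.init(testing.allocator, 8);
    defer q.deinit();
    var captures: [8]Capture = undefined;

    // "/public" was inserted while the pre-middleware list was still empty,
    // so its bundle holds an empty pre chain; "/admin" picks up require_auth.
    const public = (try trie.get_route("/public", captures[0..], &q)).?;
    try testing.expectEqual(0, public.bundle.pre.len);

    const admin = (try trie.get_route("/admin", captures[0..], &q)).?;
    try testing.expectEqual(1, admin.bundle.pre.len);
}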
- fn with_route_helper( - comptime node: *const Node, - comptime iterator: *std.mem.TokenIterator(u8, .scalar), - comptime route: Route, - ) Node { - if (iterator.next()) |chunk| { - // Parse the current chunk. - const token: Token = Token.parse_chunk(chunk); - // Alter the child of the current node. - return node.with_child(token, &(with_route_helper( - node.children.get_optional(token) orelse &(Node.init(token, null)), - iterator, - route, - ))); - } else { - // We reached the last node, returning it with the provided route. - return Node{ - .token = node.token, - .route = route, - .children = node.children, - }; - } - } + pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { + self.root.deinit(); + self.pre_mw.deinit(allocator); + self.post_mw.deinit(allocator); + } - /// Copy the current routing trie to add the provided route. - pub fn with_route(comptime self: *const Self, comptime route: Route) Self { - @setEvalBranchQuota(1_000_000); + fn print_node(root: *const Node, depth: usize) void { + var i: usize = 0; + while (i < depth) : (i += 1) { + std.debug.print(" │ ", .{}); + } - // This is where we will parse out the path. - comptime var iterator = std.mem.tokenizeScalar(u8, route.path, '/'); + std.debug.print(" ├ ", .{}); - return Self{ - .root = with_route_helper(&(self.root), &iterator, route), - }; + switch (root.token) { + .fragment => |inner| std.debug.print("Token: \"{s}\"", .{inner}), + .match => |match| std.debug.print("Token: match {s}", .{@tagName(match)}), } - /// Copy the current routing trie to add all the provided routes. - pub fn with_routes(comptime self: *const Self, comptime routes: []const Route) Self { - comptime var current = self.*; - inline for (routes) |route| { - current = current.with_route(route); - } - return current; + if (root.bundle) |bundle| { + std.debug.print(" [x] ({d} | {d})", .{ bundle.pre.len, bundle.post.len }); + } else { + std.debug.print(" [ ]", .{}); } + std.debug.print("\n", .{}); - pub fn get_route( - self: Self, - path: []const u8, - captures: []Capture, - queries: *QueryMap, - ) !?FoundRoute { - var capture_idx: usize = 0; + var iter = root.children.valueIterator(); - queries.clear(); - const query_pos = std.mem.indexOfScalar(u8, path, '?'); - var iter = std.mem.tokenizeScalar(u8, path[0..(query_pos orelse path.len)], '/'); + while (iter.next()) |node| { + print_node(node, depth + 1); + } + } - var current = self.root; + pub fn print(self: *const Self) void { + std.debug.print("Root: \n", .{}); + print_node(&self.root, 0); + } - slash_loop: while (iter.next()) |chunk| { - const fragment = Token{ .fragment = chunk }; + pub fn get_route( + self: Self, + path: []const u8, + captures: []Capture, + queries: *QueryMap, + ) !?FoundBundle { + var capture_idx: usize = 0; - // If it is the fragment, match it here. 
- if (current.children.get_optional(fragment)) |child| { - current = child.*; - continue; - } + queries.clear(); + const query_pos = std.mem.indexOfScalar(u8, path, '?'); + var iter = std.mem.tokenizeScalar(u8, path[0..(query_pos orelse path.len)], '/'); - var matched = false; - for (std.meta.tags(TokenMatch)) |token_type| { - const token = Token{ .match = token_type }; - if (current.children.get_optional(token)) |child| { - matched = true; - switch (token_type) { - .signed => if (std.fmt.parseInt(i64, chunk, 10)) |value| { - captures[capture_idx] = Capture{ .signed = value }; - } else |_| continue, - .unsigned => if (std.fmt.parseInt(u64, chunk, 10)) |value| { - captures[capture_idx] = Capture{ .unsigned = value }; - } else |_| continue, - .float => if (std.fmt.parseFloat(f64, chunk)) |value| { - captures[capture_idx] = Capture{ .float = value }; - } else |_| continue, - .string => captures[capture_idx] = Capture{ .string = chunk }, - // This ends the matching sequence and claims everything. - // Does not claim the query values. - .remaining => { - const rest = iter.buffer[(iter.index - chunk.len)..]; - captures[capture_idx] = Capture{ .remaining = rest }; - - current = child.*; - capture_idx += 1; - - break :slash_loop; - }, - } + var current = self.root; + + slash_loop: while (iter.next()) |chunk| { + const fragment = Token{ .fragment = chunk }; - current = child.*; - capture_idx += 1; + // If it is the fragment, match it here. + if (current.children.get(fragment)) |child| { + current = child; + continue; + } - if (capture_idx > captures.len) return error.TooManyCaptures; - break; + var matched = false; + for (std.meta.tags(TokenMatch)) |token_type| { + const token = Token{ .match = token_type }; + if (current.children.get(token)) |child| { + matched = true; + switch (token_type) { + .signed => if (std.fmt.parseInt(i64, chunk, 10)) |value| { + captures[capture_idx] = Capture{ .signed = value }; + } else |_| continue, + .unsigned => if (std.fmt.parseInt(u64, chunk, 10)) |value| { + captures[capture_idx] = Capture{ .unsigned = value }; + } else |_| continue, + .float => if (std.fmt.parseFloat(f64, chunk)) |value| { + captures[capture_idx] = Capture{ .float = value }; + } else |_| continue, + .string => captures[capture_idx] = Capture{ .string = chunk }, + // This ends the matching sequence and claims everything. + // Does not claim the query values. + .remaining => { + const rest = iter.buffer[(iter.index - chunk.len)..]; + captures[capture_idx] = Capture{ .remaining = rest }; + + current = child; + capture_idx += 1; + + break :slash_loop; + }, } - } - // If we failed to match, this is an invalid route. - if (!matched) { - return null; + current = child; + capture_idx += 1; + + if (capture_idx > captures.len) return error.TooManyCaptures; + break; } } - if (query_pos) |pos| { - if (path.len > pos + 1) { - var query_iter = std.mem.tokenizeScalar(u8, path[pos + 1 ..], '&'); + // If we failed to match, this is an invalid route. 
+ if (!matched) { + return null; + } + } - while (query_iter.next()) |chunk| { - if (queries.pool.clean() == 0) return null; + if (query_pos) |pos| { + if (path.len > pos + 1) { + var query_iter = std.mem.tokenizeScalar(u8, path[pos + 1 ..], '&'); - const field_idx = std.mem.indexOfScalar(u8, chunk, '=') orelse break; - if (chunk.len < field_idx + 1) break; + while (query_iter.next()) |chunk| { + if (queries.pool.clean() == 0) return null; - const key = chunk[0..field_idx]; - const value = chunk[(field_idx + 1)..]; + const field_idx = std.mem.indexOfScalar(u8, chunk, '=') orelse break; + if (chunk.len < field_idx + 1) break; - assert(std.mem.indexOfScalar(u8, key, '=') == null); - assert(std.mem.indexOfScalar(u8, value, '=') == null); - queries.put_assume_capacity(key, value); - } + const key = chunk[0..field_idx]; + const value = chunk[(field_idx + 1)..]; + + assert(std.mem.indexOfScalar(u8, key, '=') == null); + assert(std.mem.indexOfScalar(u8, value, '=') == null); + queries.put_assume_capacity(key, value); } } - - return FoundRoute{ - .route = current.route orelse return null, - .captures = captures[0..capture_idx], - .queries = queries, - }; } - }; -} + + return .{ + .bundle = current.bundle orelse return null, + .captures = captures[0..capture_idx], + .queries = queries, + }; + } +}; const testing = std.testing; @@ -346,31 +387,29 @@ test "Path Parsing (Mixed)" { } test "Constructing Routing from Path" { - const Route = _Route(void, void); - - const s = comptime RoutingTrie(void, void).init(&[_]Route{ - Route.init("/item"), - Route.init("/item/%i/description"), - Route.init("/item/%i/hello"), - Route.init("/item/%f/price_float"), - Route.init("/item/name/%s"), - Route.init("/item/list"), + var s = try RoutingTrie.init(testing.allocator, &.{ + Route.init("/item").layer(), + Route.init("/item/%i/description").layer(), + Route.init("/item/%i/hello").layer(), + Route.init("/item/%f/price_float").layer(), + Route.init("/item/name/%s").layer(), + Route.init("/item/list").layer(), }); + defer s.deinit(testing.allocator); - try testing.expectEqual(1, s.root.children.keys.len); + try testing.expectEqual(1, s.root.children.count()); } test "Routing with Paths" { - const Route = _Route(void, void); - - const s = comptime RoutingTrie(void, void).init(&[_]Route{ - Route.init("/item"), - Route.init("/item/%i/description"), - Route.init("/item/%i/hello"), - Route.init("/item/%f/price_float"), - Route.init("/item/name/%s"), - Route.init("/item/list"), + var s = try RoutingTrie.init(testing.allocator, &.{ + Route.init("/item").layer(), + Route.init("/item/%i/description").layer(), + Route.init("/item/%i/hello").layer(), + Route.init("/item/%f/price_float").layer(), + Route.init("/item/name/%s").layer(), + Route.init("/item/list").layer(), }); + defer s.deinit(testing.allocator); var q = try QueryMap.init(testing.allocator, 8); defer q.deinit(); @@ -382,27 +421,26 @@ test "Routing with Paths" { { const captured = (try s.get_route("/item/name/HELLO", captures[0..], &q)).?; - try testing.expectEqual(Route.init("/item/name/%s"), captured.route); + try testing.expectEqual(Route.init("/item/name/%s"), captured.bundle.route); try testing.expectEqualStrings("HELLO", captured.captures[0].string); } { const captured = (try s.get_route("/item/2112.22121/price_float", captures[0..], &q)).?; - try testing.expectEqual(Route.init("/item/%f/price_float"), captured.route); + try testing.expectEqual(Route.init("/item/%f/price_float"), captured.bundle.route); try testing.expectEqual(2112.22121, 
captured.captures[0].float); } } test "Routing with Remaining" { - const Route = _Route(void, void); - - const s = comptime RoutingTrie(void, void).init(&[_]Route{ - Route.init("/item"), - Route.init("/item/%f/price_float"), - Route.init("/item/name/%r"), - Route.init("/item/%i/price/%f"), + var s = try RoutingTrie.init(testing.allocator, &.{ + Route.init("/item").layer(), + Route.init("/item/%f/price_float").layer(), + Route.init("/item/name/%r").layer(), + Route.init("/item/%i/price/%f").layer(), }); + defer s.deinit(testing.allocator); var q = try QueryMap.init(testing.allocator, 8); defer q.deinit(); @@ -413,38 +451,37 @@ test "Routing with Remaining" { { const captured = (try s.get_route("/item/name/HELLO", captures[0..], &q)).?; - try testing.expectEqual(Route.init("/item/name/%r"), captured.route); + try testing.expectEqual(Route.init("/item/name/%r"), captured.bundle.route); try testing.expectEqualStrings("HELLO", captured.captures[0].remaining); } { const captured = (try s.get_route("/item/name/THIS/IS/A/FILE/SYSTEM/PATH.html", captures[0..], &q)).?; - try testing.expectEqual(Route.init("/item/name/%r"), captured.route); + try testing.expectEqual(Route.init("/item/name/%r"), captured.bundle.route); try testing.expectEqualStrings("THIS/IS/A/FILE/SYSTEM/PATH.html", captured.captures[0].remaining); } { const captured = (try s.get_route("/item/2112.22121/price_float", captures[0..], &q)).?; - try testing.expectEqual(Route.init("/item/%f/price_float"), captured.route); + try testing.expectEqual(Route.init("/item/%f/price_float"), captured.bundle.route); try testing.expectEqual(2112.22121, captured.captures[0].float); } { const captured = (try s.get_route("/item/100/price/283.21", captures[0..], &q)).?; - try testing.expectEqual(Route.init("/item/%i/price/%f"), captured.route); + try testing.expectEqual(Route.init("/item/%i/price/%f"), captured.bundle.route); try testing.expectEqual(100, captured.captures[0].signed); try testing.expectEqual(283.21, captured.captures[1].float); } } test "Routing with Queries" { - const Route = _Route(void, void); - - const s = comptime RoutingTrie(void, void).init(&[_]Route{ - Route.init("/item"), - Route.init("/item/%f/price_float"), - Route.init("/item/name/%r"), - Route.init("/item/%i/price/%f"), + var s = try RoutingTrie.init(testing.allocator, &.{ + Route.init("/item").layer(), + Route.init("/item/%f/price_float").layer(), + Route.init("/item/name/%r").layer(), + Route.init("/item/%i/price/%f").layer(), }); + defer s.deinit(testing.allocator); var q = try QueryMap.init(testing.allocator, 8); defer q.deinit(); @@ -455,7 +492,7 @@ test "Routing with Queries" { { const captured = (try s.get_route("/item/name/HELLO?name=muki&food=waffle", captures[0..], &q)).?; - try testing.expectEqual(Route.init("/item/name/%r"), captured.route); + try testing.expectEqual(Route.init("/item/name/%r"), captured.bundle.route); try testing.expectEqualStrings("HELLO", captured.captures[0].remaining); try testing.expectEqual(2, q.dirty()); try testing.expectEqualStrings("muki", q.get("name").?); @@ -465,7 +502,7 @@ test "Routing with Queries" { { // Purposefully bad format with no keys or values. 
const captured = (try s.get_route("/item/2112.22121/price_float?", captures[0..], &q)).?; - try testing.expectEqual(Route.init("/item/%f/price_float"), captured.route); + try testing.expectEqual(Route.init("/item/%f/price_float"), captured.bundle.route); try testing.expectEqual(2112.22121, captured.captures[0].float); try testing.expectEqual(0, q.dirty()); } @@ -473,7 +510,7 @@ test "Routing with Queries" { { // Purposefully bad format with incomplete key/value pair. const captured = (try s.get_route("/item/100/price/283.21?help", captures[0..], &q)).?; - try testing.expectEqual(Route.init("/item/%i/price/%f"), captured.route); + try testing.expectEqual(Route.init("/item/%i/price/%f"), captured.bundle.route); try testing.expectEqual(100, captured.captures[0].signed); try testing.expectEqual(283.21, captured.captures[1].float); try testing.expectEqual(0, q.dirty()); diff --git a/src/http/router/token_hash_map.zig b/src/http/router/token_hash_map.zig deleted file mode 100644 index 2a89c29..0000000 --- a/src/http/router/token_hash_map.zig +++ /dev/null @@ -1,203 +0,0 @@ -const std = @import("std"); -const Token = @import("routing_trie.zig").Token; - -/// Errors of get function. -pub const MapGetErrors = error{ - NotFound, -}; - -/// Type of a token hash. -pub const Hash = u64; - -/// Type of a hash entry in hashed array. -pub const HashEntry = struct { Hash, usize }; - -/// In-place sort of the given array at compile time. -/// https://github.com/Koura/algorithms/blob/b1dd07147a34554543994b2c033fae64a2202933/sorting/quicksort.zig -fn sort(A: []HashEntry, lo: usize, hi: usize) void { - if (lo < hi) { - const p = partition(A, lo, hi); - sort(A, lo, @min(p, p -% 1)); - sort(A, p + 1, hi); - } -} - -fn partition(A: []HashEntry, lo: usize, hi: usize) usize { - // Pivot can be chosen otherwise, for example try picking the first or random - // and check in which way that affects the performance of the sorting. - const pivot = A[hi][0]; - var i = lo; - var j = lo; - while (j < hi) : (j += 1) { - if (A[j][0] <= pivot) { - std.mem.swap(HashEntry, &A[i], &A[j]); - i += 1; - } - } - std.mem.swap(HashEntry, &A[i], &A[hi]); - return i; -} - -/// Compile-time token hash map. -pub fn TokenHashMap(V: type) type { - return struct { - const Self = @This(); - - /// Type of a key-value tuple. - pub const KV = struct { - Token, - V, - }; - - /// Sorted array of tokens keys hashes. - /// Associate a key hash to an index in keys / values array. - hashes: []const HashEntry, - - /// Keys of the map. - keys: []const Token, - - /// Values of the map. - values: []const V, - - /// Hash the given token key. - fn hash_key(input: Token) Hash { - const bytes = blk: { - break :blk switch (input) { - .fragment => |inner| inner, - .match => |inner| @tagName(inner), - }; - }; - - return std.hash.Wyhash.hash(0, bytes); - } - - /// Initialize a token hash map with the given key-value tuples. - pub fn init_comptime(comptime kvs: []const KV) Self { - const arrays = comptime kvs: { - // Initialize arrays. - var result = struct { - hashes: [kvs.len]HashEntry = undefined, - keys: [kvs.len]Token = undefined, - values: [kvs.len]V = undefined, - }{}; - - // Add each key-value tuple to the internal map arrays. - var index = 0; - for (kvs) |kv| { - // Get the current key hash. - const hash = hash_key(kv[0]); - // Fill keys / values internal arrays. - result.hashes[index] = .{ hash, index }; - result.keys[index] = kv[0]; - result.values[index] = kv[1]; - index += 1; - } - - // Sort the hashes, if there is something to sort. 
- if (kvs.len > 0) sort(&result.hashes, 0, kvs.len - 1); - - break :kvs result; - }; - - // Make an HashMap object from initialized arrays. - return .{ - .hashes = &arrays.hashes, - .keys = &arrays.keys, - .values = &arrays.values, - }; - } - - /// Get raw key-value tuples of the current map. - pub fn get_kvs(self: *const Self) []const KV { - var kvs: [self.keys.len]KV = undefined; - for (&kvs, self.keys, self.values) |*kv, key, value| { - kv.* = .{ key, value }; - } - return &kvs; - } - - /// Initialize a cloned token hash map with the provided new key-value tuples. - pub fn with_kvs(self: *const Self, comptime new_kvs: []const KV) Self { - // Get key-value tuples to remove from clones: they are overridden by new key-value tuples. - const kvs_to_remove = comptime kvs_to_remove: { - var kvs: []const usize = &[0]usize{}; - - for (new_kvs) |kv| { - const index: ?usize = self.get_index(kv[0]) catch null; - if (index) |idx| { - kvs = kvs ++ .{idx}; - } - } - - break :kvs_to_remove kvs; - }; - - // Get key-value tuples to clone: all key-value tuples of the current map, without the overridden ones. - const kvs_to_clone = if (kvs_to_remove.len > 0) comptime kvs: { - var kvs: []const KV = &[0]KV{}; - - var kvs_to_remove_index = 0; - var i = 0; - for (self.get_kvs()) |kv| { - if (kvs_to_remove_index < kvs_to_remove.len and i != kvs_to_remove[kvs_to_remove_index]) { - kvs = kvs ++ .{kv}; - } else { - kvs_to_remove_index += 1; - } - i += 1; - } - - break :kvs kvs; - } else self.get_kvs(); - - return Self.init_comptime(kvs_to_clone ++ new_kvs); - } - - /// Get the index in keys / values array from the provided token key. - pub fn get_index(self: *const Self, key: Token) MapGetErrors!usize { - // Get the current key hash. - const hash = hash_key(key); - - // Search in the sorted hashes array. - const hash_index = std.sort.binarySearch(HashEntry, hash, self.hashes, {}, struct { - fn f(_: void, searched_key: Hash, mid_item: HashEntry) std.math.Order { - if (searched_key < mid_item[0]) return std.math.Order.lt; - if (searched_key > mid_item[0]) return std.math.Order.gt; - if (searched_key == mid_item[0]) return std.math.Order.eq; - - unreachable; - } - }.f); - - // No hash index has been found, return not found. - if (hash_index == null) return MapGetErrors.NotFound; - - // Get the index in keys / values in hashes. - return self.hashes[hash_index.?][1]; - } - - /// Get the value of a given token key. - pub fn get(self: *const Self, key: Token) MapGetErrors!V { - return self.values[try self.get_index(key)]; - } - - /// Try to get the value of a given token key, return NULL if it doesn't exists. 
- pub fn get_optional(self: *const Self, key: Token) ?V { - return self.get(key) catch null; - } - }; -} - -test TokenHashMap { - const map = comptime TokenHashMap([]const u8).init_comptime(&[_]TokenHashMap([]const u8).KV{ - .{ Token{ .fragment = "route-fragment" }, "route" }, - .{ Token{ .match = .unsigned }, "id" }, - .{ Token{ .match = .remaining }, "remaining" }, - }); - - try std.testing.expectEqualStrings("route", try map.get(Token{ .fragment = "route-fragment" })); - try std.testing.expectEqualStrings("id", try map.get(Token{ .match = .unsigned })); - try std.testing.expectEqualStrings("remaining", try map.get(Token{ .match = .remaining })); - try std.testing.expectError(MapGetErrors.NotFound, map.get(Token{ .fragment = "not_found" })); - try std.testing.expectEqual(null, map.get_optional(Token{ .fragment = "not_found" })); -} diff --git a/src/http/server.zig b/src/http/server.zig index 4f026c2..748dc21 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -10,20 +10,24 @@ const TLSFileOptions = @import("../tls/lib.zig").TLSFileOptions; const TLSContext = @import("../tls/lib.zig").TLSContext; const TLS = @import("../tls/lib.zig").TLS; -const _Context = @import("context.zig").Context; +const Context = @import("context.zig").Context; const Request = @import("request.zig").Request; const Response = @import("response.zig").Response; const Capture = @import("router/routing_trie.zig").Capture; const QueryMap = @import("router/routing_trie.zig").QueryMap; const ResponseSetOptions = Response.ResponseSetOptions; -const _SSE = @import("sse.zig").SSE; +const SSE = @import("sse.zig").SSE; const Provision = @import("provision.zig").Provision; const Mime = @import("mime.zig").Mime; -const _Router = @import("router.zig").Router; -const _Route = @import("router/route.zig").Route; +const Router = @import("router.zig").Router; +const Route = @import("router/route.zig").Route; +const Layer = @import("router/layer.zig").Layer; +const Middleware = @import("router/middleware.zig").Middleware; const HTTPError = @import("lib.zig").HTTPError; +const Next = @import("router/middleware.zig").Next; + const AfterType = @import("../core/job.zig").AfterType; const Pool = @import("tardy").Pool; @@ -82,6 +86,7 @@ pub fn raw_respond(p: *Provision) !RecvStatus { /// This includes various different options and limits /// for interacting with the underlying network. pub const ServerConfig = struct { + security: Security = .plain, /// Kernel Backlog Value. backlog_count: u31 = 512, /// Number of Maximum Concurrent Connections. 
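With `security` folded into `ServerConfig`, plain and TLS servers are now configured through the same struct rather than a comptime parameter. A minimal sketch of standing one up; it assumes `Server` and `ServerConfig` are reachable as `zzz.http.Server` / `zzz.http.ServerConfig`, the host and port are example values, and attaching the router and running under tardy follows the examples directory.

const std = @import("std");

const zzz = @import("zzz");
const http = zzz.http;

fn make_server(allocator: std.mem.Allocator) !http.Server {
    // `.security` defaults to `.plain`; swapping in the `.tls` variant (with its
    // certificate and key options) enables TLS, as Server.init below consumes it.
    const config: http.ServerConfig = .{
        .security = .plain,
        .backlog_count = 512,
    };

    var server = http.Server.init(allocator, config);
    errdefer server.deinit();

    // Either a TCP address or (outside of Windows) a unix socket path can be bound.
    try server.bind(.{ .ip = .{ .host = "0.0.0.0", .port = 9862 } });
    return server;
}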
@@ -149,234 +154,226 @@ pub const ServerConfig = struct { request_uri_bytes_max: u32 = 1024 * 2, }; -pub fn Server(comptime security: Security, comptime AppState: type) type { - const TLSContextType = comptime if (security == .tls) TLSContext else void; - const TLSType = comptime if (security == .tls) ?TLS else void; - - return struct { - const Self = @This(); - pub const Context = _Context(Self, AppState); - pub const Router = _Router(Self, AppState); - pub const Route = _Route(Self, AppState); - pub const SSE = _SSE(Self, AppState); - allocator: std.mem.Allocator, - config: ServerConfig, - addr: ?std.net.Address, - tls_ctx: TLSContextType, - router: *const Router, - - pub fn init(allocator: std.mem.Allocator, config: ServerConfig) Self { - const tls_ctx = switch (comptime security) { - .tls => |inner| TLSContext.init(allocator, .{ - .cert = inner.cert, - .cert_name = inner.cert_name, - .key = inner.key, - .key_name = inner.key_name, - .size_tls_buffer_max = config.socket_buffer_bytes * 2, - }) catch unreachable, - .plain => void{}, - }; +pub const Server = struct { + const Self = @This(); + allocator: std.mem.Allocator, + config: ServerConfig, + addr: ?std.net.Address, + tls_ctx: ?TLSContext, + router: *const Router, + + pub fn init(allocator: std.mem.Allocator, config: ServerConfig) Self { + const tls_ctx = switch (config.security) { + .tls => |inner| TLSContext.init(allocator, .{ + .cert = inner.cert, + .cert_name = inner.cert_name, + .key = inner.key, + .key_name = inner.key_name, + .size_tls_buffer_max = config.socket_buffer_bytes * 2, + }) catch unreachable, + .plain => null, + }; - return Self{ - .allocator = allocator, - .config = config, - .addr = null, - .tls_ctx = tls_ctx, - .router = undefined, - }; - } + return Self{ + .allocator = allocator, + .config = config, + .addr = null, + .tls_ctx = tls_ctx, + .router = undefined, + }; + } - pub fn deinit(self: *const Self) void { - if (comptime security == .tls) { - self.tls_ctx.deinit(); - } + pub fn deinit(self: *const Self) void { + if (self.tls_ctx) |tls| { + tls.deinit(); } + } - const BindOptions = switch (builtin.os.tag) { - // Currently, don't support unix sockets - // on Windows. - .windows => union(enum) { - ip: struct { host: []const u8, port: u16 }, - }, - else => union(enum) { - ip: struct { host: []const u8, port: u16 }, - unix: []const u8, - }, - }; + const BindOptions = switch (builtin.os.tag) { + // Currently, don't support unix sockets + // on Windows. 
+ .windows => union(enum) { + ip: struct { host: []const u8, port: u16 }, + }, + else => union(enum) { + ip: struct { host: []const u8, port: u16 }, + unix: []const u8, + }, + }; - pub fn bind(self: *Self, options: BindOptions) !void { - self.addr = blk: { - if (options == .ip) { - const inner = options.ip; - assert(inner.host.len > 0); - assert(inner.port > 0); + pub fn bind(self: *Self, options: BindOptions) !void { + self.addr = blk: { + if (options == .ip) { + const inner = options.ip; + assert(inner.host.len > 0); + assert(inner.port > 0); - if (comptime builtin.os.tag == .linux) { - break :blk try std.net.Address.resolveIp(inner.host, inner.port); - } else { - break :blk try std.net.Address.parseIp(inner.host, inner.port); - } + if (comptime builtin.os.tag == .linux) { + break :blk try std.net.Address.resolveIp(inner.host, inner.port); + } else { + break :blk try std.net.Address.parseIp(inner.host, inner.port); } + } - if (comptime @hasField(BindOptions, "unix")) { - if (options == .unix) { - const path = options.unix; - assert(path.len > 0); + if (comptime @hasField(BindOptions, "unix")) { + if (options == .unix) { + const path = options.unix; + assert(path.len > 0); - // Unlink the existing file if it exists. - _ = std.posix.unlink(path) catch |e| switch (e) { - error.FileNotFound => {}, - else => return e, - }; + // Unlink the existing file if it exists. + _ = std.posix.unlink(path) catch |e| switch (e) { + error.FileNotFound => {}, + else => return e, + }; - break :blk try std.net.Address.initUnix(path); - } + break :blk try std.net.Address.initUnix(path); } + } - unreachable; - }; - } - - pub fn close_task(rt: *Runtime, _: void, provision: *Provision) !void { - assert(provision.job == .close); - const server_socket = rt.storage.get("__zzz_server_socket", std.posix.socket_t); - const pool = rt.storage.get_ptr("__zzz_provision_pool", Pool(Provision)); - const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); - - log.info("{d} - closing connection", .{provision.index}); + unreachable; + }; + } - if (comptime security == .tls) { - const tls_slice = rt.storage.get("__zzz_tls_slice", []TLSType); - const tls_ptr: *TLSType = &tls_slice[provision.index]; - assert(tls_ptr.* != null); - tls_ptr.*.?.deinit(); - tls_ptr.* = null; - } + pub fn close_task(rt: *Runtime, _: void, provision: *Provision) !void { + assert(provision.job == .close); + const server_socket = rt.storage.get("__zzz_server_socket", std.posix.socket_t); + const pool = rt.storage.get_ptr("__zzz_provision_pool", Pool(Provision)); + const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); - provision.socket = Cross.socket.INVALID_SOCKET; - provision.job = .empty; - _ = provision.arena.reset(.{ .retain_with_limit = config.connection_arena_bytes_retain }); + log.info("{d} - closing connection", .{provision.index}); - provision.request.clear(); - provision.response.clear(); + if (config.security == .tls) { + const tls_slice = rt.storage.get("__zzz_tls_slice", []?TLS); + const tls_ptr: *?TLS = &tls_slice[provision.index]; + assert(tls_ptr.* != null); + tls_ptr.*.?.deinit(); + tls_ptr.* = null; + } - if (provision.recv_buffer.len > config.list_recv_bytes_retain) { - try provision.recv_buffer.shrink_clear_and_free(config.list_recv_bytes_retain); - } else { - provision.recv_buffer.clear_retaining_capacity(); - } + provision.socket = Cross.socket.INVALID_SOCKET; + provision.job = .empty; + _ = provision.arena.reset(.{ .retain_with_limit = config.connection_arena_bytes_retain }); - 
pool.release(provision.index); + provision.request.clear(); + provision.response.clear(); - const accept_queued = rt.storage.get_ptr("__zzz_accept_queued", bool); - if (!accept_queued.*) { - accept_queued.* = true; - try rt.net.accept( - server_socket, - accept_task, - server_socket, - ); - } + if (provision.recv_buffer.len > config.list_recv_bytes_retain) { + try provision.recv_buffer.shrink_clear_and_free(config.list_recv_bytes_retain); + } else { + provision.recv_buffer.clear_retaining_capacity(); } - fn accept_task(rt: *Runtime, result: AcceptResult, socket: std.posix.socket_t) !void { - const accept_queued = rt.storage.get_ptr("__zzz_accept_queued", bool); + pool.release(provision.index); - const child_socket = result.unwrap() catch |e| { - log.err("socket accept failed | {}", .{e}); - accept_queued.* = true; - try rt.net.accept(socket, accept_task, socket); - return; - }; + const accept_queued = rt.storage.get_ptr("__zzz_accept_queued", bool); + if (!accept_queued.*) { + accept_queued.* = true; + try rt.net.accept( + server_socket, + accept_task, + server_socket, + ); + } + } - const pool = rt.storage.get_ptr("__zzz_provision_pool", Pool(Provision)); - accept_queued.* = false; + fn accept_task(rt: *Runtime, result: AcceptResult, socket: std.posix.socket_t) !void { + const accept_queued = rt.storage.get_ptr("__zzz_accept_queued", bool); - if (rt.scheduler.tasks.clean() >= 2) { - accept_queued.* = true; - try rt.net.accept(socket, accept_task, socket); - } + const child_socket = result.unwrap() catch |e| { + log.err("socket accept failed | {}", .{e}); + accept_queued.* = true; + try rt.net.accept(socket, accept_task, socket); + return; + }; - // This should never fail. It means that we have a dangling item. - assert(pool.clean() > 0); - const borrowed = pool.borrow() catch unreachable; + const pool = rt.storage.get_ptr("__zzz_provision_pool", Pool(Provision)); + accept_queued.* = false; - log.info("{d} - accepting connection", .{borrowed.index}); - log.debug( - "empty provision slots: {d}", - .{pool.items.len - pool.dirty.count()}, - ); - assert(borrowed.item.job == .empty); + if (rt.scheduler.tasks.clean() >= 2) { + accept_queued.* = true; + try rt.net.accept(socket, accept_task, socket); + } - if (!rt.storage.get("__zzz_is_unix", bool)) - try Cross.socket.disable_nagle(child_socket); + // This should never fail. It means that we have a dangling item. + assert(pool.clean() > 0); + const borrowed = pool.borrow() catch unreachable; - try Cross.socket.to_nonblock(child_socket); + log.info("{d} - accepting connection", .{borrowed.index}); + log.debug( + "empty provision slots: {d}", + .{pool.items.len - pool.dirty.count()}, + ); + assert(borrowed.item.job == .empty); - const provision = borrowed.item; + if (!rt.storage.get("__zzz_is_unix", bool)) + try Cross.socket.disable_nagle(child_socket); - // Store the index of this item. - provision.index = @intCast(borrowed.index); - provision.socket = child_socket; - log.debug("provision buffer size: {d}", .{provision.buffer.len}); + try Cross.socket.to_nonblock(child_socket); - switch (comptime security) { - .tls => |_| { - const tls_ctx = rt.storage.get_const_ptr("__zzz_tls_ctx", TLSContextType); - const tls_slice = rt.storage.get("__zzz_tls_slice", []TLSType); + const provision = borrowed.item; - const tls_ptr: *TLSType = &tls_slice[provision.index]; - assert(tls_ptr.* == null); + // Store the index of this item. 
+ provision.index = @intCast(borrowed.index); + provision.socket = child_socket; + log.debug("provision buffer size: {d}", .{provision.buffer.len}); - tls_ptr.* = tls_ctx.create(child_socket) catch |e| { - log.err("{d} - tls creation failed={any}", .{ provision.index, e }); - provision.job = .close; - try rt.net.close(provision, close_task, provision.socket); - return error.TLSCreationFailed; - }; + const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); - const recv_buf = tls_ptr.*.?.start_handshake() catch |e| { - log.err("{d} - tls start handshake failed={any}", .{ provision.index, e }); - provision.job = .close; - try rt.net.close(provision, close_task, provision.socket); - return error.TLSStartHandshakeFailed; - }; + if (config.security == .tls) { + const tls_ctx = rt.storage.get_const_ptr("__zzz_tls_ctx", ?TLSContext); + const tls_slice = rt.storage.get("__zzz_tls_slice", []?TLS); - provision.job = .{ .handshake = .{ .state = .recv, .count = 0 } }; - try rt.net.recv(borrowed.item, handshake_recv_task, child_socket, recv_buf); - }, - .plain => { - provision.job = .{ .recv = .{ .count = 0 } }; - try rt.net.recv(provision, recv_task, child_socket, provision.buffer); - }, - } - } + const tls_ptr: *?TLS = &tls_slice[provision.index]; + assert(tls_ptr.* == null); - fn recv_task(rt: *Runtime, result: RecvResult, provision: *Provision) !void { - assert(provision.job == .recv); + tls_ptr.* = tls_ctx.*.?.create(child_socket) catch |e| { + log.err("{d} - tls creation failed={any}", .{ provision.index, e }); + provision.job = .close; + try rt.net.close(provision, close_task, provision.socket); + return error.TLSCreationFailed; + }; - // recv_count is how many bytes we have read off the socket - const recv_count = result.unwrap() catch |e| { - if (e != error.Closed) { - log.warn("socket recv failed | {}", .{e}); - } + const recv_buf = tls_ptr.*.?.start_handshake() catch |e| { + log.err("{d} - tls start handshake failed={any}", .{ provision.index, e }); provision.job = .close; try rt.net.close(provision, close_task, provision.socket); - return; + return error.TLSStartHandshakeFailed; }; - const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); - const router = rt.storage.get_const_ptr("__zzz_router", Router); + provision.job = .{ .handshake = .{ .state = .recv, .count = 0 } }; + try rt.net.recv(borrowed.item, handshake_recv_task, child_socket, recv_buf); + } else { + provision.job = .{ .recv = .{ .count = 0 } }; + try rt.net.recv(provision, recv_task, child_socket, provision.buffer); + } + } - const recv_job = &provision.job.recv; + fn recv_task(rt: *Runtime, result: RecvResult, provision: *Provision) !void { + assert(provision.job == .recv); - log.debug("{d} - recv triggered", .{provision.index}); + // recv_count is how many bytes we have read off the socket + const recv_count = result.unwrap() catch |e| { + if (e != error.Closed) { + log.warn("socket recv failed | {}", .{e}); + } + provision.job = .close; + try rt.net.close(provision, close_task, provision.socket); + return; + }; + + const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); + const router = rt.storage.get_const_ptr("__zzz_router", Router); - // this is how many http bytes we have received - const http_bytes_count: usize = blk: { - if (comptime security == .tls) { - const tls_slice = rt.storage.get("__zzz_tls_slice", []TLSType); - const tls_ptr: *TLSType = &tls_slice[provision.index]; + const recv_job = &provision.job.recv; + + log.debug("{d} - recv triggered", .{provision.index}); + + 
// this is how many http bytes we have received + const http_bytes_count: usize = blk: { + switch (config.security) { + .tls => { + const tls_slice = rt.storage.get("__zzz_tls_slice", []?TLS); + const tls_ptr: *?TLS = &tls_slice[provision.index]; assert(tls_ptr.* != null); const decrypted = tls_ptr.*.?.decrypt(provision.buffer[0..recv_count]) catch |e| { @@ -391,463 +388,491 @@ pub fn Server(comptime security: Security, comptime AppState: type) type { const area = try provision.recv_buffer.get_write_area(decrypted.len); std.mem.copyForwards(u8, area, decrypted); break :blk decrypted.len; - } else { - break :blk recv_count; - } - }; + }, + .plain => break :blk recv_count, + } + }; - provision.recv_buffer.mark_written(http_bytes_count); - provision.buffer = try provision.recv_buffer.get_write_area(config.socket_buffer_bytes); - recv_job.count += http_bytes_count; + provision.recv_buffer.mark_written(http_bytes_count); + provision.buffer = try provision.recv_buffer.get_write_area(config.socket_buffer_bytes); + recv_job.count += http_bytes_count; - const status = try on_recv(http_bytes_count, rt, provision, router, config); - assert(provision.buffer.len == config.socket_buffer_bytes); + const status = try on_recv(http_bytes_count, rt, provision, router, config); + assert(provision.buffer.len == config.socket_buffer_bytes); - switch (status) { - .spawned => return, - .kill => { - rt.stop(); - return error.Killed; - }, - .recv => { - try rt.net.recv( - provision, - recv_task, - provision.socket, - provision.buffer, - ); - }, - .send => |pslice| { - const first_buffer = try prepare_send(rt, provision, .recv, pslice); - try rt.net.send( - provision, - send_then_recv_task, - provision.socket, - first_buffer, - ); - }, - } + switch (status) { + .spawned => return, + .kill => { + rt.stop(); + return error.Killed; + }, + .recv => { + try rt.net.recv( + provision, + recv_task, + provision.socket, + provision.buffer, + ); + }, + .send => |pslice| { + const first_buffer = try prepare_send(rt, provision, .recv, pslice); + try rt.net.send( + provision, + send_then_recv_task, + provision.socket, + first_buffer, + ); + }, } + } - fn handshake_recv_task(rt: *Runtime, result: RecvResult, provision: *Provision) !void { - assert(security == .tls); + fn handshake_recv_task(rt: *Runtime, result: RecvResult, provision: *Provision) !void { + const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); + assert(config.security == .tls); + assert(provision.job == .handshake); + assert(provision.job.handshake.state == .recv); - const length = result.unwrap() catch |e| { - if (e != error.Closed) { - log.warn("socket recv failed | {}", .{e}); - } - provision.job = .close; - try rt.net.close(provision, close_task, provision.socket); - return error.TLSHandshakeClosed; - }; + const length = result.unwrap() catch |e| { + if (e != error.Closed) { + log.warn("socket recv failed | {}", .{e}); + } + provision.job = .close; + try rt.net.close(provision, close_task, provision.socket); + return error.TLSHandshakeClosed; + }; - try handshake_inner_task(rt, length, provision); - } + log.debug("handshake recv length: {d}", .{length}); - fn handshake_send_task(rt: *Runtime, result: SendResult, provision: *Provision) !void { - assert(security == .tls); + try handshake_inner_task(rt, length, provision); + } - const length = result.unwrap() catch |e| { - if (e != error.ConnectionReset) { - log.warn("socket send failed | {}", .{e}); - } - provision.job = .close; - try rt.net.close(provision, close_task, provision.socket); - 
return error.TLSHandshakeClosed; - }; + fn handshake_send_task(rt: *Runtime, result: SendResult, provision: *Provision) !void { + const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); + assert(config.security == .tls); + assert(provision.job == .handshake); + assert(provision.job.handshake.state == .send); - try handshake_inner_task(rt, length, provision); + const length = result.unwrap() catch |e| { + if (e != error.ConnectionReset) { + log.warn("socket send failed | {}", .{e}); + } + provision.job = .close; + try rt.net.close(provision, close_task, provision.socket); + return error.TLSHandshakeClosed; + }; + + log.debug("handshake send length: {d}", .{length}); + + try handshake_inner_task(rt, length, provision); + } + + fn handshake_inner_task(rt: *Runtime, length: usize, provision: *Provision) !void { + const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); + assert(config.security == .tls); + + const tls_slice = rt.storage.get("__zzz_tls_slice", []?TLS); + + assert(provision.job == .handshake); + const handshake_job = &provision.job.handshake; + + const tls_ptr: *?TLS = &tls_slice[provision.index]; + assert(tls_ptr.* != null); + log.debug("processing handshake", .{}); + handshake_job.count += 1; + + if (handshake_job.count >= 50) { + log.debug("handshake taken too many cycles", .{}); + provision.job = .close; + try rt.net.close(provision, close_task, provision.socket); + return error.TLSHandshakeTooManyCycles; } - fn handshake_inner_task(rt: *Runtime, length: usize, provision: *Provision) !void { - assert(security == .tls); - if (comptime security == .tls) { - const tls_slice = rt.storage.get("__zzz_tls_slice", []TLSType); + const hstate = switch (handshake_job.state) { + .recv => tls_ptr.*.?.continue_handshake(.{ .recv = length }), + .send => tls_ptr.*.?.continue_handshake(.{ .send = length }), + } catch |e| { + log.err("{d} - tls handshake failed={any}", .{ provision.index, e }); + provision.job = .close; + try rt.net.close(provision, close_task, provision.socket); + return error.TLSHandshakeRecvFailed; + }; + + switch (hstate) { + .recv => |buf| { + log.debug("queueing recv in handshake", .{}); + handshake_job.state = .recv; + try rt.net.recv(provision, handshake_recv_task, provision.socket, buf); + }, + .send => |buf| { + log.debug("queueing send in handshake", .{}); + handshake_job.state = .send; + try rt.net.send(provision, handshake_send_task, provision.socket, buf); + }, + .complete => { + log.debug("handshake complete", .{}); + provision.job = .{ .recv = .{ .count = 0 } }; + try rt.net.recv(provision, recv_task, provision.socket, provision.buffer); + }, + } + } - assert(provision.job == .handshake); - const handshake_job = &provision.job.handshake; + /// Prepares the provision send_job and returns the first send chunk + pub fn prepare_send(rt: *Runtime, provision: *Provision, after: AfterType, pslice: Pseudoslice) ![]const u8 { + const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); + const plain_buffer = pslice.get(0, config.socket_buffer_bytes); - const tls_ptr: *TLSType = &tls_slice[provision.index]; + switch (config.security) { + .tls => { + const tls_slice = rt.storage.get("__zzz_tls_slice", []?TLS); + const tls_ptr: *?TLS = &tls_slice[provision.index]; assert(tls_ptr.* != null); - log.debug("processing handshake", .{}); - handshake_job.count += 1; - if (handshake_job.count >= 50) { - log.debug("handshake taken too many cycles", .{}); + const encrypted_buffer = tls_ptr.*.?.encrypt(plain_buffer) catch |e| { + log.err("{d} 
- encrypt failed: {any}", .{ provision.index, e }); provision.job = .close; try rt.net.close(provision, close_task, provision.socket); - return error.TLSHandshakeTooManyCycles; - } - - const hstate = switch (handshake_job.state) { - .recv => tls_ptr.*.?.continue_handshake(.{ .recv = length }), - .send => tls_ptr.*.?.continue_handshake(.{ .send = length }), - } catch |e| { - log.err("{d} - tls handshake failed={any}", .{ provision.index, e }); - provision.job = .close; - try rt.net.close(provision, close_task, provision.socket); - return error.TLSHandshakeRecvFailed; + return error.TLSEncryptFailed; }; - switch (hstate) { - .recv => |buf| { - log.debug("queueing recv in handshake", .{}); - handshake_job.state = .recv; - try rt.net.recv(provision, handshake_recv_task, provision.socket, buf); - }, - .send => |buf| { - log.debug("queueing send in handshake", .{}); - handshake_job.state = .send; - try rt.net.send(provision, handshake_send_task, provision.socket, buf); + provision.job = .{ + .send = .{ + .after = after, + .slice = pslice, + .count = @intCast(plain_buffer.len), + .security = .{ + .tls = .{ + .encrypted = encrypted_buffer, + .encrypted_count = 0, + }, + }, }, - .complete => { - log.debug("handshake complete", .{}); - provision.job = .{ .recv = .{ .count = 0 } }; - try rt.net.recv(provision, recv_task, provision.socket, provision.buffer); + }; + + return encrypted_buffer; + }, + .plain => { + provision.job = .{ + .send = .{ + .after = after, + .slice = pslice, + .count = 0, + .security = .plain, }, - } - } + }; + + return plain_buffer; + }, } + } - /// Prepares the provision send_job and returns the first send chunk - pub fn prepare_send(rt: *Runtime, provision: *Provision, after: AfterType, pslice: Pseudoslice) ![]const u8 { - const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); - const plain_buffer = pslice.get(0, config.socket_buffer_bytes); + pub const send_then_other_task = send_then(struct { + fn inner(rt: *Runtime, success: bool, provision: *Provision) !void { + const send_job = provision.job.send; + assert(send_job.after == .other); + const func: TaskFn(bool, *anyopaque) = @ptrCast(@alignCast(send_job.after.other.func)); + const ctx: *anyopaque = @ptrCast(@alignCast(send_job.after.other.ctx)); + try @call(.auto, func, .{ rt, success, ctx }); - switch (comptime security) { - .tls => { - const tls_slice = rt.storage.get("__zzz_tls_slice", []TLSType); - const tls_ptr: *TLSType = &tls_slice[provision.index]; - assert(tls_ptr.* != null); + if (!success) { + provision.job = .close; + try rt.net.close(provision, close_task, provision.socket); + } + } + }.inner); - const encrypted_buffer = tls_ptr.*.?.encrypt(plain_buffer) catch |e| { - log.err("{d} - encrypt failed: {any}", .{ provision.index, e }); - provision.job = .close; - try rt.net.close(provision, close_task, provision.socket); - return error.TLSEncryptFailed; - }; + pub const send_then_recv_task = send_then(struct { + fn inner(rt: *Runtime, success: bool, provision: *Provision) !void { + if (!success) { + provision.job = .close; + try rt.net.close(provision, close_task, provision.socket); + return; + } - provision.job = .{ - .send = .{ - .after = after, - .slice = pslice, - .count = @intCast(plain_buffer.len), - .security = .{ - .tls = .{ - .encrypted = encrypted_buffer, - .encrypted_count = 0, - }, - }, - }, - }; + const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); - return encrypted_buffer; - }, - .plain => { - provision.job = .{ - .send = .{ - .after = after, - .slice = pslice, - 
.count = 0, - .security = .plain, - }, - }; + log.debug("{d} - queueing a new recv", .{provision.index}); + _ = provision.arena.reset(.{ + .retain_with_limit = config.connection_arena_bytes_retain, + }); - return plain_buffer; - }, - } + provision.response.clear(); + provision.recv_buffer.clear_retaining_capacity(); + provision.job = .{ .recv = .{ .count = 0 } }; + provision.buffer = try provision.recv_buffer.get_write_area(config.socket_buffer_bytes); + + try rt.net.recv( + provision, + recv_task, + provision.socket, + provision.buffer, + ); } + }.inner); - pub const send_then_other_task = send_then(struct { - fn inner(rt: *Runtime, success: bool, provision: *Provision) !void { - const send_job = provision.job.send; - assert(send_job.after == .other); - const func: TaskFn(bool, *anyopaque) = @ptrCast(@alignCast(send_job.after.other.func)); - const ctx: *anyopaque = @ptrCast(@alignCast(send_job.after.other.ctx)); - try @call(.auto, func, .{ rt, success, ctx }); + pub fn send_then(comptime func: TaskFn(bool, *Provision)) TaskFn(SendResult, *Provision) { + return struct { + fn send_then_inner(rt: *Runtime, result: SendResult, provision: *Provision) !void { + assert(provision.job == .send); + const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); - if (!success) { - provision.job = .close; - try rt.net.close(provision, close_task, provision.socket); - } - } - }.inner); + const send_count = result.unwrap() catch |e| { + // If the socket is closed. + if (e != error.ConnectionReset) { + log.warn("socket send failed: {}", .{e}); + } - pub const send_then_recv_task = send_then(struct { - fn inner(rt: *Runtime, success: bool, provision: *Provision) !void { - if (!success) { - provision.job = .close; - try rt.net.close(provision, close_task, provision.socket); + try @call(.auto, func, .{ rt, false, provision }); return; - } + }; - const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); + const send_job = &provision.job.send; - log.debug("{d} - queueing a new recv", .{provision.index}); - _ = provision.arena.reset(.{ - .retain_with_limit = config.connection_arena_bytes_retain, - }); + log.debug("{d} - send triggered", .{provision.index}); + log.debug("{d} - sent length: {d}", .{ provision.index, send_count }); - provision.recv_buffer.clear_retaining_capacity(); - provision.job = .{ .recv = .{ .count = 0 } }; - provision.buffer = try provision.recv_buffer.get_write_area(config.socket_buffer_bytes); + switch (config.security) { + .tls => { + assert(send_job.security == .tls); - try rt.net.recv( - provision, - recv_task, - provision.socket, - provision.buffer, - ); - } - }.inner); - - pub fn send_then(comptime func: TaskFn(bool, *Provision)) TaskFn(SendResult, *Provision) { - return struct { - fn send_then_inner(rt: *Runtime, result: SendResult, provision: *Provision) !void { - assert(provision.job == .send); - const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); - - const send_count = result.unwrap() catch |e| { - // If the socket is closed. 
- if (e != error.ConnectionReset) { - log.warn("socket send failed: {}", .{e}); - } + const tls_slice = rt.storage.get("__zzz_tls_slice", []?TLS); - try @call(.auto, func, .{ rt, false, provision }); - return; - }; - - const send_job = &provision.job.send; - - log.debug("{d} - send triggered", .{provision.index}); - log.debug("{d} - sent length: {d}", .{ provision.index, send_count }); - - switch (comptime security) { - .tls => { - assert(send_job.security == .tls); - - const tls_slice = rt.storage.get("__zzz_tls_slice", []TLSType); - - const job_tls = &send_job.security.tls; - job_tls.encrypted_count += send_count; - - if (job_tls.encrypted_count >= job_tls.encrypted.len) { - if (send_job.count >= send_job.slice.len) { - try @call(.auto, func, .{ rt, true, provision }); - } else { - // Queue a new chunk up for sending. - log.debug( - "{d} - sending next chunk starting at index {d}", - .{ provision.index, send_job.count }, - ); - - const inner_slice = send_job.slice.get( - send_job.count, - send_job.count + config.socket_buffer_bytes, - ); - - send_job.count += @intCast(inner_slice.len); - - const tls_ptr: *TLSType = &tls_slice[provision.index]; - assert(tls_ptr.* != null); - - const encrypted = tls_ptr.*.?.encrypt(inner_slice) catch |e| { - log.err("{d} - encrypt failed: {any}", .{ provision.index, e }); - provision.job = .close; - try rt.net.close(provision, close_task, provision.socket); - return error.TLSEncryptFailed; - }; - - job_tls.encrypted = encrypted; - job_tls.encrypted_count = 0; - - try rt.net.send( - provision, - send_then_inner, - provision.socket, - job_tls.encrypted, - ); - } - } else { - log.debug( - "{d} - sending next encrypted chunk starting at index {d}", - .{ provision.index, job_tls.encrypted_count }, - ); - - const remainder = job_tls.encrypted[job_tls.encrypted_count..]; - try rt.net.send( - provision, - send_then_inner, - provision.socket, - remainder, - ); - } - }, - .plain => { - assert(send_job.security == .plain); - send_job.count += send_count; + const job_tls = &send_job.security.tls; + job_tls.encrypted_count += send_count; + if (job_tls.encrypted_count >= job_tls.encrypted.len) { if (send_job.count >= send_job.slice.len) { try @call(.auto, func, .{ rt, true, provision }); } else { + // Queue a new chunk up for sending. log.debug( "{d} - sending next chunk starting at index {d}", .{ provision.index, send_job.count }, ); - const plain_buffer = send_job.slice.get( + const inner_slice = send_job.slice.get( send_job.count, send_job.count + config.socket_buffer_bytes, ); - log.debug("socket buffer size: {d}", .{config.socket_buffer_bytes}); + send_job.count += @intCast(inner_slice.len); - log.debug("{d} - chunk ends at: {d}", .{ - provision.index, - plain_buffer.len + send_job.count, - }); + const tls_ptr: *?TLS = &tls_slice[provision.index]; + assert(tls_ptr.* != null); + + const encrypted = tls_ptr.*.?.encrypt(inner_slice) catch |e| { + log.err("{d} - encrypt failed: {any}", .{ provision.index, e }); + provision.job = .close; + try rt.net.close(provision, close_task, provision.socket); + return error.TLSEncryptFailed; + }; + + job_tls.encrypted = encrypted; + job_tls.encrypted_count = 0; - // this is the problem. - // we are doing send then recv which is wrong!! - // - // we should be calling ourselves... 
try rt.net.send( provision, send_then_inner, provision.socket, - plain_buffer, + job_tls.encrypted, ); } - }, - } - } - }.send_then_inner; - } - - pub fn serve(self: *Self, router: *const Router, rt: *Runtime) !void { - if (self.addr == null) return error.ServerNotBinded; - const addr = self.addr.?; - try rt.storage.store_alloc("__zzz_is_unix", addr.any.family == std.posix.AF.UNIX); - - self.router = router; + } else { + log.debug( + "{d} - sending next encrypted chunk starting at index {d}", + .{ provision.index, job_tls.encrypted_count }, + ); + + const remainder = job_tls.encrypted[job_tls.encrypted_count..]; + try rt.net.send( + provision, + send_then_inner, + provision.socket, + remainder, + ); + } + }, + .plain => { + assert(send_job.security == .plain); + send_job.count += send_count; - log.info("server listening...", .{}); - log.info("security mode: {s}", .{@tagName(security)}); + if (send_job.count >= send_job.slice.len) { + try @call(.auto, func, .{ rt, true, provision }); + } else { + log.debug( + "{d} - sending next chunk starting at index {d}", + .{ provision.index, send_job.count }, + ); - const socket = try create_socket(addr); - try std.posix.bind(socket, &addr.any, addr.getOsSockLen()); - try std.posix.listen(socket, self.config.backlog_count); + const plain_buffer = send_job.slice.get( + send_job.count, + send_job.count + config.socket_buffer_bytes, + ); - const provision_pool = try rt.allocator.create(Pool(Provision)); - provision_pool.* = try Pool(Provision).init( - rt.allocator, - self.config.connection_count_max, - Provision.InitContext{ .allocator = self.allocator, .config = self.config }, - Provision.init_hook, - ); + log.debug("socket buffer size: {d}", .{config.socket_buffer_bytes}); - try rt.storage.store_ptr("__zzz_router", @constCast(router)); - try rt.storage.store_ptr("__zzz_provision_pool", provision_pool); - try rt.storage.store_alloc("__zzz_config", self.config); + log.debug("{d} - chunk ends at: {d}", .{ + provision.index, + plain_buffer.len + send_job.count, + }); - if (comptime security == .tls) { - const tls_slice = try rt.allocator.alloc( - TLSType, - self.config.connection_count_max, - ); - for (tls_slice) |*tls| { - tls.* = null; + // this is the problem. + // we are doing send then recv which is wrong!! + // + // we should be calling ourselves... + try rt.net.send( + provision, + send_then_inner, + provision.socket, + plain_buffer, + ); + } + }, } - - // since slices are fat pointers... - try rt.storage.store_alloc("__zzz_tls_slice", tls_slice); - try rt.storage.store_alloc("__zzz_tls_ctx", self.tls_ctx); } + }.send_then_inner; + } - try rt.storage.store_alloc("__zzz_server_socket", socket); - try rt.storage.store_alloc("__zzz_accept_queued", true); + pub fn serve(self: *Self, router: *const Router, rt: *Runtime) !void { + if (self.addr == null) return error.ServerNotBinded; + const addr = self.addr.?; + try rt.storage.store_alloc("__zzz_is_unix", addr.any.family == std.posix.AF.UNIX); - try rt.net.accept(socket, accept_task, socket); - } + self.router = router; - pub fn clean(rt: *Runtime) !void { - // clean up socket. - const server_socket = rt.storage.get("__zzz_server_socket", std.posix.socket_t); - std.posix.close(server_socket); + log.info("server listening...", .{}); + log.info("security mode: {s}", .{@tagName(self.config.security)}); - // clean up provision pool. 
- const provision_pool = rt.storage.get_ptr("__zzz_provision_pool", Pool(Provision)); - provision_pool.deinit(rt.allocator, Provision.deinit_hook); - rt.allocator.destroy(provision_pool); + const socket = try create_socket(addr); + try std.posix.bind(socket, &addr.any, addr.getOsSockLen()); + try std.posix.listen(socket, self.config.backlog_count); - // clean up TLS. - if (comptime security == .tls) { - const tls_slice = rt.storage.get("__zzz_tls_slice", []TLSType); - rt.allocator.free(tls_slice); - } - } + const provision_pool = try rt.allocator.create(Pool(Provision)); + provision_pool.* = try Pool(Provision).init( + rt.allocator, + self.config.connection_count_max, + Provision.InitContext{ .allocator = self.allocator, .config = self.config }, + Provision.init_hook, + ); - fn create_socket(addr: std.net.Address) !std.posix.socket_t { - const protocol: u32 = if (addr.any.family == std.posix.AF.UNIX) - 0 - else - std.posix.IPPROTO.TCP; + try rt.storage.store_ptr("__zzz_router", @constCast(router)); + try rt.storage.store_ptr("__zzz_provision_pool", provision_pool); + try rt.storage.store_alloc("__zzz_config", self.config); - const socket = try std.posix.socket( - addr.any.family, - std.posix.SOCK.STREAM | std.posix.SOCK.CLOEXEC | std.posix.SOCK.NONBLOCK, - protocol, + if (self.config.security == .tls) { + const tls_slice = try rt.allocator.alloc( + ?TLS, + self.config.connection_count_max, ); + for (tls_slice) |*tls| { + tls.* = null; + } - log.debug("socket | t: {s} v: {any}", .{ @typeName(std.posix.socket_t), socket }); + // since slices are fat pointers... + try rt.storage.store_alloc("__zzz_tls_slice", tls_slice); + try rt.storage.store_alloc("__zzz_tls_ctx", self.tls_ctx); + } - if (@hasDecl(std.posix.SO, "REUSEPORT_LB")) { - try std.posix.setsockopt( - socket, - std.posix.SOL.SOCKET, - std.posix.SO.REUSEPORT_LB, - &std.mem.toBytes(@as(c_int, 1)), - ); - } else if (@hasDecl(std.posix.SO, "REUSEPORT")) { - try std.posix.setsockopt( - socket, - std.posix.SOL.SOCKET, - std.posix.SO.REUSEPORT, - &std.mem.toBytes(@as(c_int, 1)), - ); - } else { - try std.posix.setsockopt( - socket, - std.posix.SOL.SOCKET, - std.posix.SO.REUSEADDR, - &std.mem.toBytes(@as(c_int, 1)), - ); - } + try rt.storage.store_alloc("__zzz_server_socket", socket); + try rt.storage.store_alloc("__zzz_accept_queued", true); + + try rt.net.accept(socket, accept_task, socket); + } - return socket; + pub fn clean(rt: *Runtime) !void { + // clean up socket. + const server_socket = rt.storage.get("__zzz_server_socket", std.posix.socket_t); + std.posix.close(server_socket); + + // clean up provision pool. + const provision_pool = rt.storage.get_ptr("__zzz_provision_pool", Pool(Provision)); + provision_pool.deinit(rt.allocator, Provision.deinit_hook); + rt.allocator.destroy(provision_pool); + + // clean up TLS. 
+ const config = rt.storage.get_const_ptr("__zzz_config", ServerConfig); + if (config.security == .tls) { + const tls_slice = rt.storage.get("__zzz_tls_slice", []?TLS); + rt.allocator.free(tls_slice); } + } - fn route_and_respond(runtime: *Runtime, p: *Provision, router: *const Router) !RecvStatus { - route: { - const found = try router.get_route_from_host(p.request.uri.?, p.captures, &p.queries); - const optional_handler = found.route.get_handler(p.request.method.?); - - if (optional_handler) |handler| { - const context: *Context = try p.arena.allocator().create(Context); - context.* = .{ - .allocator = p.arena.allocator(), - .runtime = runtime, - .state = router.state, - .route = &found.route, - .request = &p.request, - .response = &p.response, - .captures = found.captures, - .queries = found.queries, - .provision = p, - }; + fn create_socket(addr: std.net.Address) !std.posix.socket_t { + const protocol: u32 = if (addr.any.family == std.posix.AF.UNIX) + 0 + else + std.posix.IPPROTO.TCP; + + const socket = try std.posix.socket( + addr.any.family, + std.posix.SOCK.STREAM | std.posix.SOCK.CLOEXEC | std.posix.SOCK.NONBLOCK, + protocol, + ); + + log.debug("socket | t: {s} v: {any}", .{ @typeName(std.posix.socket_t), socket }); + + if (@hasDecl(std.posix.SO, "REUSEPORT_LB")) { + try std.posix.setsockopt( + socket, + std.posix.SOL.SOCKET, + std.posix.SO.REUSEPORT_LB, + &std.mem.toBytes(@as(c_int, 1)), + ); + } else if (@hasDecl(std.posix.SO, "REUSEPORT")) { + try std.posix.setsockopt( + socket, + std.posix.SOL.SOCKET, + std.posix.SO.REUSEPORT, + &std.mem.toBytes(@as(c_int, 1)), + ); + } else { + try std.posix.setsockopt( + socket, + std.posix.SOL.SOCKET, + std.posix.SO.REUSEADDR, + &std.mem.toBytes(@as(c_int, 1)), + ); + } + + return socket; + } - @call(.auto, handler, .{ + fn route_and_respond(runtime: *Runtime, p: *Provision, router: *const Router) !RecvStatus { + route: { + const found = try router.get_bundle_from_host(p.request.uri.?, p.captures, &p.queries); + const optional_handler = found.bundle.route.get_handler(p.request.method.?); + + if (optional_handler) |h_with_data| { + const next: *Next = try p.arena.allocator().create(Next); + + const context: *Context = try p.arena.allocator().create(Context); + context.* = .{ + .allocator = p.arena.allocator(), + .runtime = runtime, + .request = &p.request, + .response = &p.response, + .captures = found.captures, + .queries = found.queries, + .provision = p, + .next = next, + }; + + next.* = .{ + .ctx = context, + .stage = .pre, + .pre_chain = .{ + .chain = found.bundle.pre, + .handler = h_with_data, + }, + .post_chain = found.bundle.post, + }; + + if (found.bundle.pre.len > 0) { + try next.run(); + return .spawned; + } else { + @call(.auto, h_with_data.handler, .{ context, + h_with_data.data, }) catch |e| { log.err("\"{s}\" handler failed with error: {}", .{ p.request.uri.?, e }); p.response.set(.{ @@ -860,251 +885,251 @@ pub fn Server(comptime security: Security, comptime AppState: type) type { }; return .spawned; - } else { - // If we match the route but not the method. + } + } else { + // If we match the route but not the method. + p.response.set(.{ + .status = .@"Method Not Allowed", + .mime = Mime.HTML, + .body = "405 Method Not Allowed", + }); + + // We also need to add to Allow header. + // This uses the connection's arena to allocate 64 bytes. 
+ const allowed = found.bundle.route.get_allowed(p.arena.allocator()) catch { p.response.set(.{ - .status = .@"Method Not Allowed", + .status = .@"Internal Server Error", .mime = Mime.HTML, - .body = "405 Method Not Allowed", + .body = "", }); - // We also need to add to Allow header. - // This uses the connection's arena to allocate 64 bytes. - const allowed = found.route.get_allowed(p.arena.allocator()) catch { - p.response.set(.{ - .status = .@"Internal Server Error", - .mime = Mime.HTML, - .body = "", - }); - - break :route; - }; - - p.response.headers.put_assume_capacity("Allow", allowed); break :route; - } - } + }; - if (p.response.status == .Kill) { - return .kill; + p.response.headers.put_assume_capacity("Allow", allowed); + break :route; } - - return try raw_respond(p); } - fn on_recv( - // How much we just received - recv_count: usize, - rt: *Runtime, - provision: *Provision, - router: *const Router, - config: *const ServerConfig, - ) !RecvStatus { - var stage = provision.stage; - const job = provision.job.recv; - - if (job.count >= config.request_bytes_max) { - provision.response.set(.{ - .status = .@"Content Too Large", - .mime = Mime.HTML, - .body = "Request was too large", - }); - - return try raw_respond(provision); - } - - switch (stage) { - .header => { - // this should never underflow if things are working correctly. - const starting_length = provision.recv_buffer.len - recv_count; - const start = starting_length -| 4; - - const header_ends = std.mem.lastIndexOf( - u8, - provision.recv_buffer.subslice(.{ .start = start }), - "\r\n\r\n", - ); - - // Basically, this means we haven't finished processing the header. - if (header_ends == null) { - log.debug("{d} - header doesn't end in this chunk, continue", .{provision.index}); - return .recv; - } + if (p.response.status == .Kill) { + return .kill; + } - log.debug("{d} - parsing header", .{provision.index}); - // We add start to account for the fact that we are searching - // starting at the index of start. - // The +4 is to account for the slice we match. - const header_end: usize = header_ends.? 
+ start + 4; - provision.request.parse_headers( - provision.recv_buffer.subslice(.{ .end = header_end }), - .{ - .request_bytes_max = config.request_bytes_max, - .request_uri_bytes_max = config.request_uri_bytes_max, - }, - ) catch |e| { - switch (e) { - HTTPError.ContentTooLarge => { - provision.response.set(.{ - .status = .@"Content Too Large", - .mime = Mime.HTML, - .body = "Request was too large", - }); - }, - HTTPError.TooManyHeaders => { - provision.response.set(.{ - .status = .@"Request Header Fields Too Large", - .mime = Mime.HTML, - .body = "Too Many Headers", - }); - }, - HTTPError.MalformedRequest => { - provision.response.set(.{ - .status = .@"Bad Request", - .mime = Mime.HTML, - .body = "Malformed Request", - }); - }, - HTTPError.URITooLong => { - provision.response.set(.{ - .status = .@"URI Too Long", - .mime = Mime.HTML, - .body = "URI Too Long", - }); - }, - HTTPError.InvalidMethod => { - provision.response.set(.{ - .status = .@"Not Implemented", - .mime = Mime.HTML, - .body = "Not Implemented", - }); - }, - HTTPError.HTTPVersionNotSupported => { - provision.response.set(.{ - .status = .@"HTTP Version Not Supported", - .mime = Mime.HTML, - .body = "HTTP Version Not Supported", - }); - }, - } + return try raw_respond(p); + } - return raw_respond(provision) catch unreachable; - }; + fn on_recv( + // How much we just received + recv_count: usize, + rt: *Runtime, + provision: *Provision, + router: *const Router, + config: *const ServerConfig, + ) !RecvStatus { + var stage = provision.stage; + const job = provision.job.recv; + + if (job.count >= config.request_bytes_max) { + provision.response.set(.{ + .status = .@"Content Too Large", + .mime = Mime.HTML, + .body = "Request was too large", + }); + + return try raw_respond(provision); + } - // Logging information about Request. - log.info("{d} - \"{s} {s}\" {s}", .{ - provision.index, - @tagName(provision.request.method.?), - provision.request.uri.?, - provision.request.headers.get("User-Agent") orelse "N/A", - }); + switch (stage) { + .header => { + // this should never underflow if things are working correctly. + const starting_length = provision.recv_buffer.len - recv_count; + const start = starting_length -| 4; - // HTTP/1.1 REQUIRES a Host header to be present. - const is_http_1_1 = provision.request.version == .@"HTTP/1.1"; - const is_host_present = provision.request.headers.get("Host") != null; - if (is_http_1_1 and !is_host_present) { - provision.response.set(.{ - .status = .@"Bad Request", - .mime = Mime.HTML, - .body = "Missing \"Host\" Header", - }); + const header_ends = std.mem.lastIndexOf( + u8, + provision.recv_buffer.subslice(.{ .start = start }), + "\r\n\r\n", + ); - return try raw_respond(provision); - } + // Basically, this means we haven't finished processing the header. + if (header_ends == null) { + log.debug("{d} - header doesn't end in this chunk, continue", .{provision.index}); + return .recv; + } - if (!provision.request.expect_body()) { - return try route_and_respond(rt, provision, router); + log.debug("{d} - parsing header", .{provision.index}); + // We add start to account for the fact that we are searching + // starting at the index of start. + // The +4 is to account for the slice we match. + const header_end: usize = header_ends.? 
+ start + 4; + provision.request.parse_headers( + provision.recv_buffer.subslice(.{ .end = header_end }), + .{ + .request_bytes_max = config.request_bytes_max, + .request_uri_bytes_max = config.request_uri_bytes_max, + }, + ) catch |e| { + switch (e) { + HTTPError.ContentTooLarge => { + provision.response.set(.{ + .status = .@"Content Too Large", + .mime = Mime.HTML, + .body = "Request was too large", + }); + }, + HTTPError.TooManyHeaders => { + provision.response.set(.{ + .status = .@"Request Header Fields Too Large", + .mime = Mime.HTML, + .body = "Too Many Headers", + }); + }, + HTTPError.MalformedRequest => { + provision.response.set(.{ + .status = .@"Bad Request", + .mime = Mime.HTML, + .body = "Malformed Request", + }); + }, + HTTPError.URITooLong => { + provision.response.set(.{ + .status = .@"URI Too Long", + .mime = Mime.HTML, + .body = "URI Too Long", + }); + }, + HTTPError.InvalidMethod => { + provision.response.set(.{ + .status = .@"Not Implemented", + .mime = Mime.HTML, + .body = "Not Implemented", + }); + }, + HTTPError.HTTPVersionNotSupported => { + provision.response.set(.{ + .status = .@"HTTP Version Not Supported", + .mime = Mime.HTML, + .body = "HTTP Version Not Supported", + }); + }, } - // Everything after here is a Request that is expecting a body. - const content_length = blk: { - const length_string = provision.request.headers.get("Content-Length") orelse { - break :blk 0; - }; - - break :blk try std.fmt.parseInt(u32, length_string, 10); - }; + return raw_respond(provision) catch unreachable; + }; - if (header_end < provision.recv_buffer.len) { - const difference = provision.recv_buffer.len - header_end; - if (difference == content_length) { - // Whole Body - log.debug("{d} - got whole body with header", .{provision.index}); - const body_end = header_end + difference; - provision.request.set(.{ - .body = provision.recv_buffer.subslice(.{ - .start = header_end, - .end = body_end, - }), - }); - return try route_and_respond(rt, provision, router); - } else { - // Partial Body - log.debug("{d} - got partial body with header", .{provision.index}); - stage = .{ .body = header_end }; - return .recv; - } - } else if (header_end == provision.recv_buffer.len) { - // Body of length 0 probably or only got header. - if (content_length == 0) { - log.debug("{d} - got body of length 0", .{provision.index}); - // Body of Length 0. - provision.request.set(.{ .body = "" }); - return try route_and_respond(rt, provision, router); - } else { - // Got only header. - log.debug("{d} - got all header aka no body", .{provision.index}); - stage = .{ .body = header_end }; - return .recv; - } - } else unreachable; - }, + // Logging information about Request. + log.info("{d} - \"{s} {s}\" {s}", .{ + provision.index, + @tagName(provision.request.method.?), + provision.request.uri.?, + provision.request.headers.get("User-Agent") orelse "N/A", + }); - .body => |header_end| { - // We should ONLY be here if we expect there to be a body. - assert(provision.request.expect_body()); - log.debug("{d} - body matching", .{provision.index}); + // HTTP/1.1 REQUIRES a Host header to be present. 
+ const is_http_1_1 = provision.request.version == .@"HTTP/1.1"; + const is_host_present = provision.request.headers.get("Host") != null; + if (is_http_1_1 and !is_host_present) { + provision.response.set(.{ + .status = .@"Bad Request", + .mime = Mime.HTML, + .body = "Missing \"Host\" Header", + }); - const content_length = blk: { - const length_string = provision.request.headers.get("Content-Length") orelse { - provision.response.set(.{ - .status = .@"Length Required", - .mime = Mime.HTML, - .body = "", - }); + return try raw_respond(provision); + } - return try raw_respond(provision); - }; + if (!provision.request.expect_body()) { + return try route_and_respond(rt, provision, router); + } - break :blk try std.fmt.parseInt(u32, length_string, 10); + // Everything after here is a Request that is expecting a body. + const content_length = blk: { + const length_string = provision.request.headers.get("Content-Length") orelse { + break :blk 0; }; - // We factor in the length of the headers. - const request_length = header_end + content_length; - - // If this body will be too long, abort early. - if (request_length > config.request_bytes_max) { - provision.response.set(.{ - .status = .@"Content Too Large", - .mime = Mime.HTML, - .body = "", - }); - return try raw_respond(provision); - } + break :blk try std.fmt.parseInt(u32, length_string, 10); + }; - if (job.count >= request_length) { + if (header_end < provision.recv_buffer.len) { + const difference = provision.recv_buffer.len - header_end; + if (difference == content_length) { + // Whole Body + log.debug("{d} - got whole body with header", .{provision.index}); + const body_end = header_end + difference; provision.request.set(.{ .body = provision.recv_buffer.subslice(.{ .start = header_end, - .end = request_length, + .end = body_end, }), }); return try route_and_respond(rt, provision, router); } else { + // Partial Body + log.debug("{d} - got partial body with header", .{provision.index}); + stage = .{ .body = header_end }; return .recv; } - }, - } + } else if (header_end == provision.recv_buffer.len) { + // Body of length 0 probably or only got header. + if (content_length == 0) { + log.debug("{d} - got body of length 0", .{provision.index}); + // Body of Length 0. + provision.request.set(.{ .body = "" }); + return try route_and_respond(rt, provision, router); + } else { + // Got only header. + log.debug("{d} - got all header aka no body", .{provision.index}); + stage = .{ .body = header_end }; + return .recv; + } + } else unreachable; + }, + + .body => |header_end| { + // We should ONLY be here if we expect there to be a body. + assert(provision.request.expect_body()); + log.debug("{d} - body matching", .{provision.index}); + + const content_length = blk: { + const length_string = provision.request.headers.get("Content-Length") orelse { + provision.response.set(.{ + .status = .@"Length Required", + .mime = Mime.HTML, + .body = "", + }); + + return try raw_respond(provision); + }; + + break :blk try std.fmt.parseInt(u32, length_string, 10); + }; + + // We factor in the length of the headers. + const request_length = header_end + content_length; + + // If this body will be too long, abort early. 
+ if (request_length > config.request_bytes_max) { + provision.response.set(.{ + .status = .@"Content Too Large", + .mime = Mime.HTML, + .body = "", + }); + return try raw_respond(provision); + } + + if (job.count >= request_length) { + provision.request.set(.{ + .body = provision.recv_buffer.subslice(.{ + .start = header_end, + .end = request_length, + }), + }); + return try route_and_respond(rt, provision, router); + } else { + return .recv; + } + }, } - }; -} + } +}; diff --git a/src/http/sse.zig b/src/http/sse.zig index 58e4195..c15cf37 100644 --- a/src/http/sse.zig +++ b/src/http/sse.zig @@ -3,7 +3,7 @@ const std = @import("std"); const Pseudoslice = @import("../core/pseudoslice.zig").Pseudoslice; const Provision = @import("provision.zig").Provision; -const _Context = @import("context.zig").Context; +const Context = @import("context.zig").Context; const TaskFn = @import("tardy").TaskFn; const Runtime = @import("tardy").Runtime; @@ -15,47 +15,44 @@ const SSEMessage = struct { retry: ?u64 = null, }; -pub fn SSE(comptime Server: type, comptime AppState: type) type { - const Context = _Context(Server, AppState); - return struct { - const Self = @This(); - context: *Context, - allocator: std.mem.Allocator, - runtime: *Runtime, - - pub fn send( - self: *Self, - options: SSEMessage, - then_context: anytype, - then: TaskFn(bool, @TypeOf(then_context)), - ) !void { - var index: usize = 0; - const buffer = self.context.provision.buffer; - - if (options.id) |id| { - const buf = try std.fmt.bufPrint(buffer[index..], "id: {s}\n", .{id}); - index += buf.len; - } - - if (options.event) |event| { - const buf = try std.fmt.bufPrint(buffer[index..], "event: {s}\n", .{event}); - index += buf.len; - } - - if (options.data) |data| { - const buf = try std.fmt.bufPrint(buffer[index..], "data: {s}\n", .{data}); - index += buf.len; - } - - if (options.retry) |retry| { - const buf = try std.fmt.bufPrint(buffer[index..], "retry: {d}\n", .{retry}); - index += buf.len; - } - - buffer[index] = '\n'; - index += 1; - - try self.context.send_then(buffer[0..index], then_context, then); +pub const SSE = struct { + const Self = @This(); + context: *Context, + allocator: std.mem.Allocator, + runtime: *Runtime, + + pub fn send( + self: *Self, + options: SSEMessage, + then_context: anytype, + then: TaskFn(bool, @TypeOf(then_context)), + ) !void { + var index: usize = 0; + const buffer = self.context.provision.buffer; + + if (options.id) |id| { + const buf = try std.fmt.bufPrint(buffer[index..], "id: {s}\n", .{id}); + index += buf.len; } - }; -} + + if (options.event) |event| { + const buf = try std.fmt.bufPrint(buffer[index..], "event: {s}\n", .{event}); + index += buf.len; + } + + if (options.data) |data| { + const buf = try std.fmt.bufPrint(buffer[index..], "data: {s}\n", .{data}); + index += buf.len; + } + + if (options.retry) |retry| { + const buf = try std.fmt.bufPrint(buffer[index..], "retry: {d}\n", .{retry}); + index += buf.len; + } + + buffer[index] = '\n'; + index += 1; + + try self.context.send_then(buffer[0..index], then_context, then); + } +}; diff --git a/src/tls/bear.zig b/src/tls/bear.zig index 6912534..c450de0 100644 --- a/src/tls/bear.zig +++ b/src/tls/bear.zig @@ -616,14 +616,14 @@ pub const TLS = struct { if ((after_state & bearssl.BR_SSL_SENDAPP) != 0) break :blk .complete; if ((after_state & bearssl.BR_SSL_SENDREC) != 0) { - var length: usize = 0; + var length: usize = undefined; const buf = bearssl.br_ssl_engine_sendrec_buf(engine, &length); log.debug("send rec buffer: address={*}, length={d}", .{ 
buf, length }); break :blk .{ .send = buf[0..length] }; } if ((after_state & bearssl.BR_SSL_RECVREC) != 0) { - var length: usize = 0; + var length: usize = undefined; const buf = bearssl.br_ssl_engine_recvrec_buf(engine, &length); log.debug("recv rec buffer: address={*}, length={d}", .{ buf, length }); break :blk .{ .recv = buf[0..length] }; diff --git a/src/unit_test.zig b/src/unit_test.zig index 44b1c06..96fdadb 100644 --- a/src/unit_test.zig +++ b/src/unit_test.zig @@ -27,7 +27,6 @@ test "zzz unit tests" { testing.refAllDecls(@import("./http/router.zig")); testing.refAllDecls(@import("./http/router/route.zig")); testing.refAllDecls(@import("./http/router/routing_trie.zig")); - testing.refAllDecls(@import("./http/router/token_hash_map.zig")); // TLS testing.refAllDecls(@import("./tls/bear.zig"));
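
The server rewrite above drops the comptime `Security`/`AppState` parameters in favor of per-runtime state: `serve()` stores the config, router, provision pool, and TLS slice in the tardy runtime's storage under `__zzz_*` keys, and every task fetches them back (`rt.storage.get_const_ptr("__zzz_config", ServerConfig)`, `rt.storage.get("__zzz_tls_slice", []?TLS)`, ...) and branches on `config.security` at runtime instead of `comptime security`. The sketch below shows only that store/fetch/branch shape with a plain `std.StringHashMap`; `ServerConfig` and `Security` here are simplified stand-ins for the real types in `src/http/server.zig`, and tardy's actual storage API is richer than this.

```zig
const std = @import("std");

// Simplified stand-ins; the real ServerConfig lives in src/http/server.zig.
const Security = enum { plain, tls };
const ServerConfig = struct { security: Security };

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();

    var storage = std.StringHashMap(*const anyopaque).init(gpa.allocator());
    defer storage.deinit();

    // Roughly what serve() does once per runtime.
    const config = ServerConfig{ .security = .tls };
    try storage.put("__zzz_config", &config);

    // Roughly what rt.storage.get_const_ptr("__zzz_config", ServerConfig) does in each task.
    const fetched: *const ServerConfig = @ptrCast(@alignCast(storage.get("__zzz_config").?));
    switch (fetched.security) {
        .tls => std.debug.print("would run the TLS handshake path\n", .{}),
        .plain => std.debug.print("would recv straight into the buffer\n", .{}),
    }
}
```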
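
In `on_recv`'s `.header` stage (src/http/server.zig above), the terminator scan only looks at the newly received tail of `recv_buffer`: it backs up four bytes (`starting_length -| 4`) so a `\r\n\r\n` split across two reads is still matched, then adds `start + 4` back to get the absolute end of the header block. A standalone sketch of that arithmetic, assuming a plain slice in place of zzz's recv buffer (`find_header_end` is an illustrative name, not a function from the patch):

```zig
const std = @import("std");

/// Returns the absolute index one past the final "\r\n\r\n", or null if the
/// header has not fully arrived yet. `buffered` is everything received so far;
/// `recv_count` is how many bytes the latest read appended to it.
fn find_header_end(buffered: []const u8, recv_count: usize) ?usize {
    const starting_length = buffered.len - recv_count;
    // Back up 4 bytes so a "\r\n\r\n" split across two reads is still matched.
    const start = starting_length -| 4;
    const rel = std.mem.lastIndexOf(u8, buffered[start..], "\r\n\r\n") orelse return null;
    // `rel` is relative to `start`; +4 skips over the matched terminator itself.
    return rel + start + 4;
}

test "terminator split across two reads is still found" {
    const full = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n";
    // Pretend the last read delivered only the final two bytes.
    try std.testing.expectEqual(@as(?usize, full.len), find_header_end(full, 2));
    // No terminator anywhere yet: keep receiving.
    try std.testing.expectEqual(@as(?usize, null), find_header_end("GET / HT", 8));
}
```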
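
The rewritten `SSE.send` in src/http/sse.zig serializes an `SSEMessage` as optional `id:`, `event:`, `data:` and `retry:` lines followed by a blank line, then hands the result to `context.send_then`. The sketch below reproduces just that formatting against a caller-supplied buffer so it can be exercised in isolation; `format_sse` and the local `Message` struct are hypothetical stand-ins, since the real `SSE` type also carries a `Context` and `Runtime`.

```zig
const std = @import("std");

// Mirror of the optional fields on zzz's SSEMessage.
const Message = struct {
    id: ?[]const u8 = null,
    event: ?[]const u8 = null,
    data: ?[]const u8 = null,
    retry: ?u64 = null,
};

/// Serializes `msg` the way SSE.send does before queueing the send:
/// one "field: value\n" line per populated field, terminated by a blank line.
fn format_sse(buffer: []u8, msg: Message) ![]const u8 {
    var index: usize = 0;
    if (msg.id) |id| index += (try std.fmt.bufPrint(buffer[index..], "id: {s}\n", .{id})).len;
    if (msg.event) |event| index += (try std.fmt.bufPrint(buffer[index..], "event: {s}\n", .{event})).len;
    if (msg.data) |data| index += (try std.fmt.bufPrint(buffer[index..], "data: {s}\n", .{data})).len;
    if (msg.retry) |retry| index += (try std.fmt.bufPrint(buffer[index..], "retry: {d}\n", .{retry})).len;
    if (index >= buffer.len) return error.NoSpaceLeft;
    buffer[index] = '\n';
    return buffer[0 .. index + 1];
}

test "event and data lines end with a blank line" {
    var buf: [128]u8 = undefined;
    const out = try format_sse(&buf, .{ .event = "tick", .data = "42" });
    try std.testing.expectEqualStrings("event: tick\ndata: 42\n\n", out);
}
```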