Skip to content

Commit

Permalink
feat: cache resolved DNS queries
Browse files Browse the repository at this point in the history
  • Loading branch information
mookums committed Nov 26, 2024
1 parent d2ee562 commit fbae73b
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 61 deletions.
2 changes: 0 additions & 2 deletions build.zig.zon
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
.minimum_zig_version = "0.13.0",
.dependencies = .{
.tardy = .{
//.url = "git+https://github.com/mookums/tardy?ref=main#feae50e9bf60ac13f1d5d14a7d3346fcfe442fa8",
//.hash = "122093168263d66adc14bbee5de6aa0d4a2600e7299cad2b66175feeb6ce8aaef173",
.path = "../tardy",
},
.bearssl = .{
Expand Down
11 changes: 6 additions & 5 deletions examples/proxy/main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ fn fetch_task(_: *Runtime, response: ?*const http.Response, ctx: *Context) !void

pub fn main() !void {
const host: []const u8 = "0.0.0.0";
const port: u16 = 9862;
const port: u16 = 9863;
//const proxy_path = "http://httpforever.com";
const proxy_path = "http://http.badssl.com";
//const proxy_path = "http://http.badssl.com";
const proxy_path = "http://localhost:9862";

var gpa = std.heap.GeneralPurposeAllocator(.{}){};
const allocator = gpa.allocator();
Expand All @@ -50,7 +51,7 @@ pub fn main() !void {
// will spawn our runtimes.
var t = try Tardy.init(.{
.allocator = allocator,
.threading = .single,
.threading = .auto,
});
defer t.deinit();

Expand Down Expand Up @@ -82,8 +83,8 @@ pub fn main() !void {
&router,
struct {
fn entry(rt: *Runtime, r: *const Router) !void {
var server = Server.init(.{ .allocator = rt.allocator });
try server.bind(host, port);
var server = Server.init(rt.allocator, .{});
try server.bind(.{ .ip = .{ .host = host, .port = port } });
try server.serve(r, rt);

// this kills it after a set delay, allowing the GPA to report any leaks.
Expand Down
169 changes: 118 additions & 51 deletions src/http/client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,15 @@ const Stage = union(enum) {
connect: struct {
ip: []const u8,
port: u16,
address_list: *std.net.AddressList,
address_index: usize,
status: union(enum) {
// If uncached, we must find it.
uncached: struct {
address_list: *std.net.AddressList,
address_index: usize,
},
// If cached, assume the ip/port is correct.
cached,
},
},
send: usize,
recv: union(enum) {
Expand All @@ -121,6 +128,7 @@ const Stage = union(enum) {
const RequestContext = struct {
allocator: std.mem.Allocator,
info: URLInfo,
client: *Client,
request: *Request,
response: *Response,
socket: std.posix.socket_t,
Expand All @@ -144,35 +152,52 @@ fn connect_task(rt: *Runtime, socket: std.posix.socket_t, ctx: *RequestContext)
// Attempts to connect to all of the returned addresses.
if (!Cross.socket.is_valid(socket)) {
std.posix.close(ctx.socket);
stage.address_index += 1;

log.debug(
"address idx: {d} | address list len: {d}",
.{ stage.address_index, stage.address_list.addrs.len },
);
if (stage.address_index >= stage.address_list.addrs.len) {
return;
}

const next = stage.address_list.addrs[stage.address_index];
ctx.socket = try create_socket(next);
ctx.allocator.free(stage.ip);
stage.ip = try get_ip_from_address(ctx.allocator, next);
log.debug("next ip: {s}", .{stage.ip});

try rt.net.connect(
ctx,
connect_task,
ctx.socket,
stage.ip,
stage.port,
);
switch (stage.status) {
.cached => {
// should remove from the cache on failure...
// should we fall back to uncached if we have bad data in the cache?
_ = ctx.client.cache.remove(ctx.info.host);
@call(.auto, ctx.then, .{ rt, null, ctx.then_ctx }) catch unreachable;
return;
},
.uncached => |*inner| {
inner.address_index += 1;

log.debug(
"address idx: {d} | address list len: {d}",
.{ inner.address_index, inner.address_list.addrs.len },
);
if (inner.address_index >= inner.address_list.addrs.len) {
return;
}

return;
const next = inner.address_list.addrs[inner.address_index];
ctx.socket = try create_socket(next);
ctx.allocator.free(stage.ip);
stage.ip = try get_ip_from_address(ctx.allocator, next);
log.debug("next ip: {s}", .{stage.ip});

try rt.net.connect(
ctx,
connect_task,
ctx.socket,
stage.ip,
stage.port,
);
return;
},
}
}

stage.address_list.deinit();
ctx.allocator.free(stage.ip);
switch (stage.status) {
.cached => {},
.uncached => |inner| {
inner.address_list.deinit();
try ctx.client.cache.put(ctx.info.host, try std.net.Address.parseIp(stage.ip, stage.port));
ctx.allocator.free(stage.ip);
},
}

ctx.stage = .{ .send = 0 };
log.debug("sending from {d} to {d}", .{ 0, ctx.buffer.len });
Expand Down Expand Up @@ -377,11 +402,12 @@ fn get_ip_from_address(allocator: std.mem.Allocator, address: std.net.Address) !
const ClientRequest = struct {
allocator: std.mem.Allocator,
url: []const u8,
client: *Client,
runtime: *Runtime,
request: *Request,
response: *Response,

pub fn init(allocator: std.mem.Allocator, runtime: *Runtime, url: []const u8) !ClientRequest {
pub fn init(allocator: std.mem.Allocator, client: *Client, runtime: *Runtime, url: []const u8) !ClientRequest {
const request = try allocator.create(Request);
request.* = try Request.init(allocator, 32);

Expand All @@ -390,6 +416,7 @@ const ClientRequest = struct {

return .{
.allocator = allocator,
.client = client,
.runtime = runtime,
.request = request,
.response = response,
Expand All @@ -413,19 +440,50 @@ const ClientRequest = struct {

const info = try parse_url(self.url);
self.request.path = info.path;
try self.request.headers.add("Host", info.host_with_port);

const list = if (info.host[0] == '[')
// If ipv6, strip the brackets.
try std.net.getAddressList(self.allocator, info.host[1 .. info.host.len - 1], info.port)
else
try std.net.getAddressList(self.allocator, info.host, info.port);

const first: std.net.Address = list.addrs[0];
log.debug("First IP: {}", .{first});

const socket = try create_socket(first);
const ip = try get_ip_from_address(self.allocator, first);
self.request.headers.putAssumeCapacity("Host", info.host_with_port);

var socket: std.posix.fd_t = undefined;

const stage: Stage = blk: {
const cached = self.client.cache.get(info.host);
if (cached) |address| {
log.debug("cache hit on {s}", .{info.host});
const ip = try get_ip_from_address(self.allocator, address);
socket = try create_socket(address);
break :blk Stage{
.connect = .{
.ip = ip,
.port = info.port,
.status = .cached,
},
};
} else {
const list = if (info.host[0] == '[')
// If ipv6, strip the brackets.
try std.net.getAddressList(self.allocator, info.host[1 .. info.host.len - 1], info.port)
else
try std.net.getAddressList(self.allocator, info.host, info.port);

const first: std.net.Address = list.addrs[0];
log.debug("First IP: {}", .{first});

const ip = try get_ip_from_address(self.allocator, first);
socket = try create_socket(first);

break :blk Stage{
.connect = .{
.ip = ip,
.port = info.port,
.status = .{
.uncached = .{
.address_list = list,
.address_index = 0,
},
},
},
};
}
};

const buffer = try self.allocator.alloc(u8, 2048);
const headers = try self.request.headers_into_buffer(buffer, self.request.body.len);
Expand All @@ -435,30 +493,26 @@ const ClientRequest = struct {
context.* = RequestContext{
.allocator = self.allocator,
.info = info,
.client = self.client,
.request = self.request,
.response = self.response,
.pseudo = Pseudoslice.init(headers, self.request.body, buffer),
.stage = .{ .connect = .{
.ip = ip,
.port = info.port,
.address_list = list,
.address_index = 0,
} },
.stage = stage,
.socket = socket,
.buffer = buffer,
.recv_buffer = try std.ArrayListUnmanaged(u8).initCapacity(self.allocator, 0),
.then = @ptrCast(then),
.then_ctx = wrap(usize, then_ctx),
};

log.debug("ip: {s}", .{ip});
log.debug("ip: {s}", .{stage.connect.ip});

// queue connect.
try self.runtime.net.connect(
context,
connect_task,
socket,
ip,
stage.connect.ip,
info.port,
);
}
Expand All @@ -472,16 +526,29 @@ pub const Client = struct {
allocator: std.mem.Allocator,
options: ClientOptions,

// Client needs to have an internal pool of ClientRequests to pass around.

// We should also have a cache of connections, meaning that multiple connections
// to the same host should use the same connection, instead of creating a ton?

// This should match the host of the URL to the resolved IP.
// Client also needs to have a cache of DNS resolved hostnames...
cache: std.StringHashMap(std.net.Address),

pub fn init(allocator: std.mem.Allocator, options: ClientOptions) Client {
return .{ .allocator = allocator, .options = options };
return .{
.allocator = allocator,
.options = options,
.cache = std.StringHashMap(std.net.Address).init(allocator),
};
}

pub fn deinit(self: *Client) void {
_ = self;
}

pub fn get(self: *Client, runtime: *Runtime, url: []const u8) !ClientRequest {
var builder = try ClientRequest.init(self.allocator, runtime, url);
var builder = try ClientRequest.init(self.allocator, self, runtime, url);
builder.request.method = .GET;
return builder;
}
Expand Down
2 changes: 1 addition & 1 deletion src/http/request.zig
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ pub const Request = struct {
index += 24;

// Headers
var iter = self.headers.map.iterator();
var iter = self.headers.iterator();
while (iter.next()) |entry| {
std.mem.copyForwards(u8, buffer[index..], entry.key_ptr.*);
index += entry.key_ptr.len;
Expand Down
4 changes: 2 additions & 2 deletions src/http/response.zig
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub const Response = struct {
const ResponseParseOptions = struct { size_response_max: u32 };

pub fn parse_headers(self: *Response, bytes: []const u8, options: ResponseParseOptions) !void {
self.headers.clear();
self.headers.clearRetainingCapacity();
var total_size: u32 = 0;
var lines = std.mem.tokenizeAny(u8, bytes, "\r\n");

Expand Down Expand Up @@ -69,7 +69,7 @@ pub const Response = struct {
const key = header_iter.next() orelse return error.MalformedResponse;
const value = std.mem.trimLeft(u8, header_iter.rest(), &.{' '});
if (value.len == 0) return error.MalformedResponse;
try self.headers.add(key, value);
self.headers.putAssumeCapacity(key, value);
}
}

Expand Down
1 change: 1 addition & 0 deletions src/http/server.zig
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,7 @@ pub fn Server(comptime security: Security) type {
// HTTP/1.1 REQUIRES a Host header to be present.
const is_http_1_1 = provision.request.version == .@"HTTP/1.1";
const is_host_present = provision.request.headers.get("Host") != null;

if (is_http_1_1 and !is_host_present) {
provision.response.set(.{
.status = .@"Bad Request",
Expand Down

0 comments on commit fbae73b

Please sign in to comment.