diff --git a/src/core/database/database.zig b/src/core/database/database.zig index fb2389c..b7fda35 100644 --- a/src/core/database/database.zig +++ b/src/core/database/database.zig @@ -32,14 +32,14 @@ pub const MintDatabase = struct { getMintQuoteFn: *const fn (ptr: *anyopaque, gpa: std.mem.Allocator, quote_id: zul.UUID) anyerror!?MintQuote, updateMintQuoteStateFn: *const fn (ptr: *anyopaque, quote_id: zul.UUID, state: nuts.nut04.QuoteState) anyerror!nuts.nut04.QuoteState, getMintQuotesFn: *const fn (ptr: *anyopaque, allocator: std.mem.Allocator) anyerror!std.ArrayList(MintQuote), - getMintQuoteByRequestLookupIdFn: *const fn (ptr: *anyopaque, gpa: std.mem.Allocator, request_lookup_id: zul.UUID) anyerror!?MintQuote, + getMintQuoteByRequestLookupIdFn: *const fn (ptr: *anyopaque, gpa: std.mem.Allocator, request_lookup_id: []const u8) anyerror!?MintQuote, getMintQuoteByRequestFn: *const fn (ptr: *anyopaque, gpa: std.mem.Allocator, request: []const u8) anyerror!?MintQuote, removeMintQuoteStateFn: *const fn (ptr: *anyopaque, quote_id: zul.UUID) anyerror!void, addMeltQuoteFn: *const fn (ptr: *anyopaque, quote: MeltQuote) anyerror!void, getMeltQuoteFn: *const fn (ptr: *anyopaque, gpa: std.mem.Allocator, quote_id: zul.UUID) anyerror!?MeltQuote, updateMeltQuoteStateFn: *const fn (ptr: *anyopaque, quote_id: zul.UUID, state: nuts.nut05.QuoteState) anyerror!nuts.nut05.QuoteState, getMeltQuotesFn: *const fn (ptr: *anyopaque, gpa: std.mem.Allocator) anyerror!std.ArrayList(MeltQuote), - getMeltQuoteByRequestLookupIdFn: *const fn (ptr: *anyopaque, gpa: std.mem.Allocator, request_lookup_id: zul.UUID) anyerror!?MeltQuote, + getMeltQuoteByRequestLookupIdFn: *const fn (ptr: *anyopaque, gpa: std.mem.Allocator, request_lookup_id: []const u8) anyerror!?MeltQuote, getMeltQuoteByRequestFn: *const fn (ptr: *anyopaque, gpa: std.mem.Allocator, request: []const u8) anyerror!?MeltQuote, removeMeltQuoteStateFn: *const fn (ptr: *anyopaque, quote_id: zul.UUID) anyerror!void, addProofsFn: *const fn (ptr: *anyopaque, proofs: []const nuts.Proof) anyerror!void, @@ -129,7 +129,7 @@ pub const MintDatabase = struct { const self: *T = @ptrCast(@alignCast(pointer)); return self.getMintQuotes(gpa); } - pub fn getMintQuoteByRequestLookupId(pointer: *anyopaque, gpa: std.mem.Allocator, request_lookup_id: zul.UUID) anyerror!?MintQuote { + pub fn getMintQuoteByRequestLookupId(pointer: *anyopaque, gpa: std.mem.Allocator, request_lookup_id: []const u8) anyerror!?MintQuote { const self: *T = @ptrCast(@alignCast(pointer)); return self.getMintQuoteByRequestLookupId(gpa, request_lookup_id); } @@ -157,7 +157,7 @@ pub const MintDatabase = struct { const self: *T = @ptrCast(@alignCast(pointer)); return self.getMeltQuotes(gpa); } - pub fn getMeltQuoteByRequestLookupId(pointer: *anyopaque, gpa: std.mem.Allocator, request_lookup_id: zul.UUID) anyerror!?MeltQuote { + pub fn getMeltQuoteByRequestLookupId(pointer: *anyopaque, gpa: std.mem.Allocator, request_lookup_id: []const u8) anyerror!?MeltQuote { const self: *T = @ptrCast(@alignCast(pointer)); return self.getMeltQuoteByRequestLookupId(gpa, request_lookup_id); } @@ -332,7 +332,7 @@ pub const MintDatabase = struct { pub fn getMintQuotes(self: Self, allocator: std.mem.Allocator) anyerror!std.ArrayList(MintQuote) { return self.getMintQuotesFn(self.ptr, allocator); } - pub fn getMintQuoteByRequestLookupId(self: Self, gpa: std.mem.Allocator, request_lookup_id: zul.UUID) anyerror!?MintQuote { + pub fn getMintQuoteByRequestLookupId(self: Self, gpa: std.mem.Allocator, request_lookup_id: []const u8) anyerror!?MintQuote { return self.getMintQuoteByRequestLookupIdFn(self.ptr, gpa, request_lookup_id); } pub fn getMintQuoteByRequest(self: Self, gpa: std.mem.Allocator, request: []const u8) anyerror!?MintQuote { @@ -353,7 +353,7 @@ pub const MintDatabase = struct { pub fn getMeltQuotes(self: Self, gpa: std.mem.Allocator) anyerror!std.ArrayList(MeltQuote) { return self.getMeltQuotesFn(self.ptr, gpa); } - pub fn getMeltQuoteByRequestLookupId(self: Self, gpa: std.mem.Allocator, request_lookup_id: zul.UUID) anyerror!?MeltQuote { + pub fn getMeltQuoteByRequestLookupId(self: Self, gpa: std.mem.Allocator, request_lookup_id: []const u8) anyerror!?MeltQuote { return self.getMeltQuoteByRequestLookupIdFn(self.ptr, gpa, request_lookup_id); } pub fn getMeltQuoteByRequest(self: Self, gpa: std.mem.Allocator, request: []const u8) anyerror!?MeltQuote { diff --git a/src/core/database/mint_memory.zig b/src/core/database/mint_memory.zig index a1db251..e1b48a4 100644 --- a/src/core/database/mint_memory.zig +++ b/src/core/database/mint_memory.zig @@ -285,7 +285,7 @@ pub const MintMemoryDatabase = struct { pub fn getMintQuoteByRequestLookupId( self: *Self, allocator: std.mem.Allocator, - request: zul.UUID, + request: []const u8, ) !?MintQuote { var arena = std.heap.ArenaAllocator.init(allocator); defer arena.deinit(); @@ -294,7 +294,7 @@ pub const MintMemoryDatabase = struct { const quotes = try self.getMintQuotes(arena.allocator()); for (quotes.items) |q| { // if we found, cloning with allocator, so caller responsible on free resources - if (q.request_lookup_id.eql(request)) return try q.clone(allocator); + if (std.mem.eql(u8, q.request_lookup_id, request)) return try q.clone(allocator); } return null; @@ -386,7 +386,7 @@ pub const MintMemoryDatabase = struct { pub fn getMeltQuoteByRequestLookupId( self: *Self, allocator: std.mem.Allocator, - request: zul.UUID, + request: []const u8, ) !?MeltQuote { var arena = std.heap.ArenaAllocator.init(allocator); defer arena.deinit(); @@ -395,7 +395,7 @@ pub const MintMemoryDatabase = struct { const quotes = try self.getMeltQuotes(arena.allocator()); for (quotes.items) |q| { // if we found, cloning with allocator, so caller responsible on free resources - if (std.mem.eql(u8, q.request_lookup_id, &request.bin)) return try q.clone(allocator); + if (std.mem.eql(u8, q.request_lookup_id, request)) return try q.clone(allocator); } return null; diff --git a/src/core/lightning/lightning.zig b/src/core/lightning/lightning.zig index 66aa5e9..1d236f4 100644 --- a/src/core/lightning/lightning.zig +++ b/src/core/lightning/lightning.zig @@ -30,6 +30,8 @@ pub const PayInvoiceResponse = struct { /// Totoal Amount Spent total_spent: Amount, + unit: CurrencyUnit, + pub fn deinit(self: PayInvoiceResponse, allocator: std.mem.Allocator) void { allocator.free(self.payment_hash); @@ -45,6 +47,8 @@ pub const PaymentQuoteResponse = struct { amount: Amount, /// Fee required for melt fee: u64, + /// Status + state: MeltQuoteState, pub fn deinit(self: PaymentQuoteResponse, allocator: std.mem.Allocator) void { allocator.free(self.request_lookup_id); diff --git a/src/core/lightning/mint.zig b/src/core/lightning/mint.zig index 98e2abd..d5578da 100644 --- a/src/core/lightning/mint.zig +++ b/src/core/lightning/mint.zig @@ -3,6 +3,8 @@ const Self = @This(); const std = @import("std"); const core = @import("../lib.zig"); +const ref = @import("../../sync/ref.zig"); +const mpmc = @import("../../sync/mpmc.zig"); const Channel = @import("../../channels/channels.zig").Channel; const Amount = core.amount.Amount; @@ -21,7 +23,7 @@ ptr: *anyopaque, deinitFn: *const fn (ptr: *anyopaque) void, getSettingsFn: *const fn (ptr: *anyopaque) Settings, -waitAnyInvoiceFn: *const fn (ptr: *anyopaque) anyerror!Channel(std.ArrayList(u8)).Rx, +waitAnyInvoiceFn: *const fn (ptr: *anyopaque) ref.Arc(mpmc.UnboundedChannel(std.ArrayList(u8))), getPaymentQuoteFn: *const fn (ptr: *anyopaque, alloc: std.mem.Allocator, melt_quote_request: MeltQuoteBolt11Request) anyerror!PaymentQuoteResponse, payInvoiceFn: *const fn (ptr: *anyopaque, alloc: std.mem.Allocator, melt_quote: core.mint.MeltQuote, partial_msats: ?Amount, max_fee_msats: ?Amount) anyerror!PayInvoiceResponse, checkInvoiceStatusFn: *const fn (ptr: *anyopaque, request_lookup_id: []const u8) anyerror!MintQuoteState, @@ -34,7 +36,7 @@ pub fn initFrom(comptime T: type, allocator: std.mem.Allocator, value: T) !Self return self.getSettings(); } - pub fn waitAnyInvoice(pointer: *anyopaque) anyerror!Channel(std.ArrayList(u8)).Rx { + pub fn waitAnyInvoice(pointer: *anyopaque) ref.Arc(mpmc.UnboundedChannel(std.ArrayList(u8))) { const self: *T = @ptrCast(@alignCast(pointer)); return self.waitAnyInvoice(); } @@ -94,7 +96,7 @@ pub fn getSettings(self: Self) Settings { return self.getSettingsFn(self.ptr); } -pub fn waitAnyInvoice(self: Self) !Channel(std.ArrayList(u8)).Rx { +pub fn waitAnyInvoice(self: Self) ref.Arc(mpmc.UnboundedChannel(std.ArrayList(u8))) { return self.waitAnyInvoiceFn(self.ptr); } diff --git a/src/core/mint/lightning/lib.zig b/src/core/mint/lightning/lib.zig index 72b4513..dbd08f5 100644 --- a/src/core/mint/lightning/lib.zig +++ b/src/core/mint/lightning/lib.zig @@ -1,4 +1,2 @@ pub const lnbits = @import("lnbits.zig"); pub const invoice = @import("invoices/lib.zig"); - -pub const Lightning = @import("lightning.zig"); diff --git a/src/core/mint/lightning/lnbits.zig b/src/core/mint/lightning/lnbits.zig index d5a254e..cd4f0f0 100644 --- a/src/core/mint/lightning/lnbits.zig +++ b/src/core/mint/lightning/lnbits.zig @@ -1,6 +1,22 @@ const std = @import("std"); -const model = @import("../model.zig"); -const Lightning = @import("lightning.zig"); +const core = @import("../../lib.zig"); +const lightning_invoice = @import("../../../lightning_invoices/invoice.zig"); +const httpz = @import("httpz"); +const zul = @import("zul"); +const ref = @import("../../../sync/ref.zig"); +const mpmc = @import("../../../sync/mpmc.zig"); +const http_router = @import("../../../misc/http_router/http_router.zig"); + +const Amount = core.amount.Amount; +const PaymentQuoteResponse = core.lightning.PaymentQuoteResponse; +// const CreateInvoiceResponse = core.lightning.CreateInvoiceResponse; +const MeltQuoteBolt11Request = core.nuts.nut05.MeltQuoteBolt11Request; +const MintMeltSettings = core.lightning.MintMeltSettings; +const MeltQuoteState = core.nuts.nut05.QuoteState; +const MintQuoteState = core.nuts.nut04.QuoteState; +const FeeReserve = core.mint.FeeReserve; +const Channel = @import("../../../channels/channels.zig").Channel; +const MintLightning = core.lightning.MintLightning; pub const HttpError = std.http.Client.RequestError || std.http.Client.Request.FinishError || std.http.Client.Request.WaitError || error{ ReadBodyError, WrongJson }; @@ -10,26 +26,107 @@ pub const LightningError = HttpError || std.Uri.ParseError || std.mem.Allocator. PaymentFailed, }; -pub const Settings = struct { - admin_key: ?[]const u8, - url: ?[]const u8, -}; +pub const LnBits = struct { + const Self = @This(); -pub const LnBitsLightning = struct { client: LNBitsClient, - pub fn init(allocator: std.mem.Allocator, admin_key: []const u8, lnbits_url: []const u8) !@This() { + chan: ref.Arc(mpmc.UnboundedChannel(std.ArrayList(u8))), // we using signle channel for sending invoices + allocator: std.mem.Allocator, + fee_reserve: FeeReserve, + + webhook_url: ?[]const u8, + + mint_settings: MintMeltSettings = .{}, + melt_settings: MintMeltSettings = .{}, + + pub fn init( + allocator: std.mem.Allocator, + admin_key: []const u8, + invoice_api_key: []const u8, + lnbits_url: []const u8, + mint_settings: MintMeltSettings, + melt_settings: MintMeltSettings, + fee_reserve: FeeReserve, + chan: ref.Arc(mpmc.UnboundedChannel(std.ArrayList(u8))), + webhook_url: ?[]const u8, + ) !@This() { return .{ - .client = try LNBitsClient.init(allocator, admin_key, lnbits_url), + .allocator = allocator, + .client = try LNBitsClient.init( + allocator, + admin_key, + invoice_api_key, + lnbits_url, + ), + .mint_settings = mint_settings, + .melt_settings = melt_settings, + .fee_reserve = fee_reserve, + .chan = chan, + + .webhook_url = webhook_url, }; } - pub fn deinit(self: *@This()) void { - self.client.deinit(); + pub fn toMintLightning(self: *const Self, gpa: std.mem.Allocator) error{OutOfMemory}!MintLightning { + return MintLightning.initFrom(Self, gpa, self.*); } - pub fn lightning(self: *@This()) Lightning { - return Lightning.init(self); + pub fn getSettings(self: *const Self) core.lightning.Settings { + return .{ + .mpp = false, + .unit = .sat, + .melt_settings = self.melt_settings, + .mint_settings = self.mint_settings, + }; + } + + /// caller responsible to deallocate result + pub fn getPaymentQuote( + self: *const Self, + allocator: std.mem.Allocator, + melt_quote_request: MeltQuoteBolt11Request, + ) !PaymentQuoteResponse { + if (melt_quote_request.unit != .sat) return error.UnsupportedUnit; + + const invoice_amount_msat = melt_quote_request + .request + .amountMilliSatoshis() orelse return error.UnknownInvoiceAmount; + + const amount = try core.lightning.toUnit( + invoice_amount_msat, + .msat, + melt_quote_request.unit, + ); + + const relative_fee_reserve: u64 = + @intFromFloat(self.fee_reserve.percent_fee_reserve * @as(f32, @floatFromInt(amount))); + + const absolute_fee_reserve: u64 = self.fee_reserve.min_fee_reserve; + + const fee = if (relative_fee_reserve > absolute_fee_reserve) + relative_fee_reserve + else + absolute_fee_reserve; + + const req_lookup_id = try allocator.dupe(u8, &melt_quote_request.request.paymentHash().inner); + errdefer allocator.free(req_lookup_id); + + return .{ + .request_lookup_id = req_lookup_id, + .amount = amount, + .fee = fee, + .state = .unpaid, + }; + } + + pub fn deinit(self: *@This()) void { + self.client.deinit(); + self.chan.releaseWithFn((struct { + fn deinit(_self: mpmc.UnboundedChannel(std.ArrayList(u8))) void { + _self.deinit(); + } + }).deinit); } pub fn isInvoicePaid(self: *@This(), allocator: std.mem.Allocator, invoice: []const u8) !bool { @@ -39,34 +136,102 @@ pub const LnBitsLightning = struct { return self.client.isInvoicePaid(allocator, &decoded_invoice.paymentHash()); } - pub fn createInvoice(self: *@This(), allocator: std.mem.Allocator, amount: u64) !model.CreateInvoiceResult { - return try self.client.createInvoice(allocator, .{ - .amount = amount, - .unit = "sat", - .memo = null, - .expiry = 10000, - .webhook = null, + pub fn checkInvoiceStatus( + self: *Self, + _request_lookup_id: []const u8, + ) !MintQuoteState { + return if (try self.client.isInvoicePaid(self.allocator, _request_lookup_id)) .paid else .unpaid; + } + + pub fn createInvoice( + self: *Self, + gpa: std.mem.Allocator, + amount: Amount, + unit: core.nuts.CurrencyUnit, + description: []const u8, + unix_expiry: u64, + ) !core.lightning.CreateInvoiceResponse { + if (unit != .sat) return error.UnsupportedUnit; + + const time_now = std.time.timestamp(); + std.debug.assert(unix_expiry > time_now); + + const amnt = try core.lightning.toUnit(amount, unit, .sat); + + const expiry = unix_expiry - @abs(time_now); + + const create_invoice_response = try self.client.createInvoice(gpa, .{ + .amount = amnt, + .unit = unit.toString(), + .memo = description, + .expiry = expiry, + .webhook = self.webhook_url, .internal = null, + .out = false, }); + errdefer create_invoice_response.deinit(gpa); + defer gpa.free(create_invoice_response.payment_request); + + var request = try lightning_invoice.Bolt11Invoice.fromStr(gpa, create_invoice_response.payment_request); + errdefer request.deinit(); + + const res_expiry = request.expiresAtSecs(); + + return .{ + .request_lookup_id = create_invoice_response.payment_hash, + .request = request, + .expiry = res_expiry, + }; + } + + pub fn payInvoice( + self: *Self, + arena: std.mem.Allocator, + melt_quote: core.mint.MeltQuote, + _: ?Amount, + _: ?Amount, // max_fee_msats + ) !core.lightning.PayInvoiceResponse { + const pay_response = try self.client.payInvoice(arena, melt_quote.request); + + const invoices_info = try self.client.findInvoice(arena, pay_response.payment_hash); + + if (invoices_info.value.len == 0) return error.InvoiceNotFound; + + const invoice_info = invoices_info.value[0]; + + const status: MeltQuoteState = if (invoice_info.pending) .unpaid else .paid; + + const total_spent = @abs(invoice_info.amount + invoice_info.fee); + + return .{ + .payment_hash = pay_response.payment_hash, + .payment_preimage = invoice_info.payment_hash, + .status = status, + .total_spent = total_spent, + .unit = .sat, + }; } - pub fn payInvoice(self: *@This(), allocator: std.mem.Allocator, payment_request: []const u8) !model.PayInvoiceResult { - return try self.client.payInvoice(allocator, payment_request); + // Result is channel with invoices, caller must free result + pub fn waitAnyInvoice( + self: *Self, + ) ref.Arc(mpmc.UnboundedChannel(std.ArrayList(u8))) { + return self.chan.retain(); } }; pub const LNBitsClient = struct { admin_key: []const u8, - lnbits_url: std.Uri, + invoice_api_key: []const u8, + lnbits_url: []const u8, client: std.http.Client, pub fn init( allocator: std.mem.Allocator, admin_key: []const u8, + invoice_api_key: []const u8, lnbits_url: []const u8, ) !LNBitsClient { - const url = try std.Uri.parse(lnbits_url); - var client = std.http.Client{ .allocator = allocator, }; @@ -74,7 +239,8 @@ pub const LNBitsClient = struct { return .{ .admin_key = admin_key, - .lnbits_url = url, + .lnbits_url = lnbits_url, + .invoice_api_key = invoice_api_key, .client = client, }; } @@ -89,15 +255,13 @@ pub const LNBitsClient = struct { allocator: std.mem.Allocator, endpoint: []const u8, ) LightningError![]const u8 { - var buf: [100]u8 = undefined; - var b: []u8 = buf[0..]; - - const uri = self.lnbits_url.resolve_inplace(endpoint, &b) catch return std.Uri.ParseError.UnexpectedCharacter; + const uri = try std.fmt.allocPrint(allocator, "{s}/{s}", .{ self.lnbits_url, endpoint }); + defer allocator.free(uri); const header_buf = try allocator.alloc(u8, 1024 * 1024 * 4); defer allocator.free(header_buf); - var req = try self.client.open(.GET, uri, .{ + var req = try self.client.open(.GET, try std.Uri.parse(uri), .{ .server_header_buffer = header_buf, .extra_headers = &.{ .{ @@ -132,15 +296,15 @@ pub const LNBitsClient = struct { endpoint: []const u8, req_body: []const u8, ) LightningError![]const u8 { - var buf: [100]u8 = undefined; - var b: []u8 = buf[0..]; - - const uri = self.lnbits_url.resolve_inplace(endpoint, &b) catch return std.Uri.ParseError.UnexpectedCharacter; + const uri = try std.fmt.allocPrint(allocator, "{s}/{s}", .{ self.lnbits_url, endpoint }); + defer allocator.free(uri); const header_buf = try allocator.alloc(u8, 1024 * 1024 * 4); defer allocator.free(header_buf); - var req = try self.client.open(.POST, uri, .{ + std.log.debug("uri: {s}", .{uri}); + + var req = try self.client.open(.POST, try std.Uri.parse(uri), .{ .server_header_buffer = header_buf, .extra_headers = &.{ .{ @@ -180,37 +344,31 @@ pub const LNBitsClient = struct { /// createInvoice - creating invoice /// note: after success call u need to call deinit on result using alloactor that u pass as argument to this func. - pub fn createInvoice(self: *@This(), allocator: std.mem.Allocator, params: model.CreateInvoiceParams) !model.CreateInvoiceResult { - const req_body = try std.json.stringifyAlloc(allocator, ¶ms, .{}); + pub fn createInvoice(self: *@This(), allocator: std.mem.Allocator, params: CreateInvoiceRequest) !CreateInvoiceResponse { + const req_body = try std.json.stringifyAlloc(allocator, ¶ms, .{ + .emit_null_optional_fields = false, + }); + + std.log.debug("request {s}", .{req_body}); const res = try self.post(allocator, "api/v1/payments", req_body); + std.log.debug("create invoice, response : {s}", .{res}); + const parsed = std.json.parseFromSlice(std.json.Value, allocator, res, .{ .allocate = .alloc_always }) catch return error.WrongJson; const payment_request = parsed.value.object.get("payment_request") orelse unreachable; const payment_hash = parsed.value.object.get("payment_hash") orelse unreachable; const pr = switch (payment_request) { - .string => |v| val: { - const result = try allocator.alloc(u8, v.len); - @memcpy(result, v); - break :val result; - }, - else => { - unreachable; - }, + .string => |v| try allocator.dupe(u8, v), + else => unreachable, }; errdefer allocator.free(pr); const ph = switch (payment_hash) { - .string => |v| val: { - const result = try allocator.alloc(u8, v.len); - @memcpy(result, v); - break :val result; - }, - else => { - unreachable; - }, + .string => |v| try allocator.dupe(u8, v), + else => unreachable, }; errdefer allocator.free(ph); @@ -222,7 +380,7 @@ pub const LNBitsClient = struct { /// payInvoice - paying invoice /// note: after success call u need to call deinit on result using alloactor that u pass as argument to this func. - pub fn payInvoice(self: *@This(), allocator: std.mem.Allocator, bolt11: []const u8) !model.PayInvoiceResult { + pub fn payInvoice(self: *@This(), allocator: std.mem.Allocator, bolt11: []const u8) !PayInvoiceResponse { const req_body = try std.json.stringifyAlloc(allocator, &.{ .out = true, .bolt11 = bolt11 }, .{}); const res = try self.post(allocator, "api/v1/payments", req_body); @@ -245,13 +403,17 @@ pub const LNBitsClient = struct { return .{ .payment_hash = ph, - .total_fees = 0, + // .total_fees = 0, }; } /// isInvoicePaid - paying invoice /// note: after success call u need to call deinit on result using alloactor that u pass as argument to this func. - pub fn isInvoicePaid(self: *@This(), allocator: std.mem.Allocator, payment_hash: []const u8) !bool { + pub fn isInvoicePaid( + self: *@This(), + allocator: std.mem.Allocator, + payment_hash: []const u8, + ) !bool { const endpoint = try std.fmt.allocPrint( allocator, "api/v1/payments/{s}", @@ -269,10 +431,144 @@ pub const LNBitsClient = struct { else => false, }; } + + /// findInvoice - finding invoice + pub fn findInvoice( + self: *@This(), + allocator: std.mem.Allocator, + checking_id: []const u8, + ) !std.json.Parsed([]const FindInvoiceResponse) { + const endpoint = try std.fmt.allocPrint( + allocator, + "api/v1/payments?checking_id=internal_{s}", + .{checking_id}, + ); + defer allocator.free(endpoint); + + const res = try self.get(allocator, endpoint); + + return std.json.parseFromSlice([]const FindInvoiceResponse, allocator, res, .{ .allocate = .alloc_always }) catch return error.WrongJson; + } + + /// Create invoice webhook + pub fn createInvoiceWebhookRouter( + _: *@This(), + allocator: std.mem.Allocator, + webhook_endpoint: []const u8, + chan: ref.Arc(mpmc.UnboundedChannel(std.ArrayList(u8))), + ) !http_router.Router { + const state = WebhookState{ + .chan = chan, + .allocator = allocator, + }; + var router = try httpz.Router(WebhookState, httpz.Action(WebhookState)).init(allocator, WebhookState.dispatcher, state); + + router.post(webhook_endpoint, handleInvoice, .{}); + + return try http_router.Router.initFrom(WebhookState, allocator, router); + } +}; + +pub fn handleInvoice( + state: WebhookState, + req: *httpz.Request, + res: *httpz.Response, +) !void { + std.log.debug("incoming webhook, body : {s}", .{req.body().?}); + const webhook_response = if (try req.json(FindInvoiceResponse)) |resp| resp else { + res.status = 422; + return; + }; + + std.log.debug("Received webhook update for: {s}", .{webhook_response.checking_id}); + + var sender = try state.chan.value.sender(); + + try sender.send(std.ArrayList(u8).fromOwnedSlice(state.allocator, try state.allocator.dupe(u8, webhook_response.checking_id))); +} + +/// Webhook state +pub const WebhookState = struct { + /// allocator to allocate webhook messages + allocator: std.mem.Allocator, + /// chan, where we took sender + chan: ref.Arc(mpmc.UnboundedChannel(std.ArrayList(u8))), + + pub usingnamespace http_router.DefaultDispatcher(@This()); +}; + +/// Create invoice request +pub const CreateInvoiceRequest = struct { + /// Amount (sat) + amount: u64, + /// Unit + unit: []const u8, + /// Memo + memo: ?[]const u8, + /// Expiry is in seconds + expiry: ?u64, + /// Webhook url + webhook: ?[]const u8, + /// Internal payment + internal: ?bool, + /// Incoming or outgoing payment + out: bool, +}; + +/// Pay invoice response +pub const PayInvoiceResponse = struct { + /// Payment hash + payment_hash: []const u8, +}; + +/// Find invoice response +pub const FindInvoiceResponse = struct { + /// status + status: []const u8, + /// Checking id + checking_id: []const u8, + /// Pending (paid) + pending: bool, + /// Amount (sat) + amount: i64, + /// Fee (msat) + fee: i64, + /// Memo + memo: []const u8, + /// Time + time: u64, + /// Bolt11 + bolt11: []const u8, + /// Preimage + preimage: ?[]const u8, + /// Payment hash + payment_hash: []const u8, + /// Expiry + expiry: f64, + /// Extra + extra: std.json.Value, // should be object map + /// Wallet id + wallet_id: []const u8, + /// Webhook url + webhook: ?[]const u8, + /// Webhook status + webhook_status: ?[]const u8, +}; +/// Create invoice response +pub const CreateInvoiceResponse = struct { + /// Payment hash + payment_hash: []const u8, + /// Payment request (bolt11) + payment_request: []const u8, + + pub fn deinit(self: CreateInvoiceResponse, alloc: std.mem.Allocator) void { + alloc.free(self.payment_hash); + alloc.free(self.payment_request); + } }; test "test_decode_invoice" { - var client = try LnBitsLightning.init(std.testing.allocator, "admin_key", "http://localhost:5000"); + var client = try LnBits.init(std.testing.allocator, "admin_key", "http://localhost:5000"); defer client.deinit(); const lightning = client.lightning(); @@ -286,7 +582,7 @@ test "test_decode_invoice" { } test "test_decode_invoice_invalid" { - var client = try LnBitsLightning.init(std.testing.allocator, "admin_key", "http://localhost:5000"); + var client = try LnBits.init(std.testing.allocator, "admin_key", "http://localhost:5000"); defer client.deinit(); const lightning = client.lightning(); diff --git a/src/core/mint/mint.zig b/src/core/mint/mint.zig index 2ebdfcc..e1d864f 100644 --- a/src/core/mint/mint.zig +++ b/src/core/mint/mint.zig @@ -288,7 +288,7 @@ pub const Mint = struct { unit: nuts.CurrencyUnit, amount: core.amount.Amount, expiry: u64, - ln_lookup: zul.UUID, + ln_lookup: []const u8, ) !MintQuote { const nut04 = self.mint_info.nuts.nut04; if (nut04.disabled) return error.MintingDisabled; @@ -579,7 +579,7 @@ pub const Mint = struct { /// Flag mint quote as paid pub fn payMintQuoteForRequestId( self: *Mint, - request_lookup_id: zul.UUID, + request_lookup_id: []const u8, ) !void { const mint_quote = (try self.localstore.value.getMintQuoteByRequestLookupId(self.allocator, request_lookup_id)) orelse return; diff --git a/src/core/mint/types.zig b/src/core/mint/types.zig index 76c40ed..93c438c 100644 --- a/src/core/mint/types.zig +++ b/src/core/mint/types.zig @@ -24,7 +24,7 @@ pub const MintQuote = struct { /// Expiration time of quote expiry: u64, /// Value used by ln backend to look up state of request - request_lookup_id: zul.UUID, + request_lookup_id: []const u8, /// formatting mint quote pub fn format( @@ -45,7 +45,7 @@ pub const MintQuote = struct { unit: CurrencyUnit, amount: amount_lib.Amount, expiry: u64, - request_lookup_id: zul.UUID, + request_lookup_id: []const u8, ) !MintQuote { const id = zul.UUID.v4(); @@ -65,6 +65,8 @@ pub const MintQuote = struct { pub fn deinit(self: *const MintQuote, allocator: std.mem.Allocator) void { allocator.free(self.request); + allocator.free(self.mint_url); + allocator.free(self.request_lookup_id); } pub fn clone(self: *const MintQuote, allocator: std.mem.Allocator) !MintQuote { @@ -73,11 +75,14 @@ pub const MintQuote = struct { const mint_url = try allocator.dupe(u8, self.mint_url); errdefer allocator.free(mint_url); + const request_lookup_id = try allocator.dupe(u8, self.request_lookup_id); + errdefer allocator.free(request_lookup_id); var cloned = self.*; cloned.request = request; cloned.mint_url = mint_url; + cloned.request_lookup_id = request_lookup_id; return cloned; } diff --git a/src/fake_wallet/fake_wallet.zig b/src/fake_wallet/fake_wallet.zig index 82c9d26..f787746 100644 --- a/src/fake_wallet/fake_wallet.zig +++ b/src/fake_wallet/fake_wallet.zig @@ -7,6 +7,8 @@ const lightning_invoice = @import("../lightning_invoices/invoice.zig"); const helper = @import("../helper/helper.zig"); const zul = @import("zul"); const secp256k1 = @import("bitcoin-primitives").secp256k1; +const ref = @import("../sync/ref.zig"); +const mpmc = @import("../sync/mpmc.zig"); const Amount = core.amount.Amount; const PaymentQuoteResponse = core.lightning.PaymentQuoteResponse; @@ -23,12 +25,19 @@ const MintLightning = core.lightning.MintLightning; // TODO: wait any invoices, here we need create a new listener, that will receive // message like pub sub channel -fn sendLabelFn(label: std.ArrayList(u8), ch: Channel(std.ArrayList(u8)).Tx, duration: u64) void { +fn sendLabelFn(label: std.ArrayList(u8), ch: ref.Arc(mpmc.UnboundedChannel(std.ArrayList(u8))), duration: u64) void { errdefer label.deinit(); + defer ch.releaseWithFn((struct { + fn deinit(_self: mpmc.UnboundedChannel(std.ArrayList(u8))) void { + _self.deinit(); + } + }).deinit); std.time.sleep(duration * @as(u64, 1e9)); - ch.send(label) catch |err| { + var sender = ch.value.sender() catch return std.log.err("channel closed, cannot take sender", .{}); + + sender.send(label) catch |err| { std.log.err("send label {s}, failed: {any}", .{ label.items, err }); return; }; @@ -41,7 +50,7 @@ pub const FakeWallet = struct { const Self = @This(); fee_reserve: core.mint.FeeReserve = .{}, - chan: *Channel(std.ArrayList(u8)) = undefined, // we using signle channel for sending invoices + chan: ref.Arc(mpmc.UnboundedChannel(std.ArrayList(u8))), // we using signle channel for sending invoices mint_settings: MintMeltSettings = .{}, melt_settings: MintMeltSettings = .{}, @@ -56,8 +65,12 @@ pub const FakeWallet = struct { mint_settings: MintMeltSettings, melt_settings: MintMeltSettings, ) !FakeWallet { - const ch = try Channel(std.ArrayList(u8)).init(allocator, 10); - errdefer ch.deinit(); + const ch = try ref.arc(allocator, mpmc.UnboundedChannel(std.ArrayList(u8)).init(allocator)); + errdefer ch.releaseWithFn((struct { + fn deinit(_self: mpmc.UnboundedChannel(std.ArrayList(u8))) void { + _self.deinit(); + } + }).deinit); return .{ .chan = ch, @@ -74,7 +87,11 @@ pub const FakeWallet = struct { } pub fn deinit(self: *FakeWallet) void { - self.chan.deinit(); + self.chan.releaseWithFn((struct { + fn deinit(_self: mpmc.UnboundedChannel(std.ArrayList(u8))) void { + _self.deinit(); + } + }).deinit); self.thread_pool.deinit(self.allocator); } @@ -89,9 +106,9 @@ pub const FakeWallet = struct { // Result is channel with invoices, caller must free result pub fn waitAnyInvoice( - self: *const Self, - ) !Channel(std.ArrayList(u8)).Rx { - return self.chan.getRx(); + self: *Self, + ) ref.Arc(mpmc.UnboundedChannel(std.ArrayList(u8))) { + return self.chan.retain(); } /// caller responsible to deallocate result @@ -127,6 +144,7 @@ pub const FakeWallet = struct { .request_lookup_id = req_lookup_id, .amount = amount, .fee = fee, + .state = .unpaid, }; } @@ -144,6 +162,7 @@ pub const FakeWallet = struct { _ = _max_fee_msats; // autofix return .{ + .unit = .msat, .payment_preimage = &.{}, .payment_hash = &.{}, // empty slice - safe to free .status = .paid, @@ -162,7 +181,7 @@ pub const FakeWallet = struct { /// creating invoice - caller own response and responsible to free pub fn createInvoice( - self: *const Self, + self: *Self, gpa: std.mem.Allocator, amount: Amount, unit: core.nuts.CurrencyUnit, @@ -221,7 +240,7 @@ pub const FakeWallet = struct { const label_clone = try self.allocator.dupe(u8, label); errdefer self.allocator.free(label_clone); - try self.thread_pool.spawn(.{ std.ArrayList(u8).fromOwnedSlice(self.allocator, label_clone), self.chan.getTx(), duration }); + try self.thread_pool.spawn(.{ std.ArrayList(u8).fromOwnedSlice(self.allocator, label_clone), self.chan.retain(), duration }); } const expiry = signed_invoice.expiresAtSecs(); diff --git a/src/lib.zig b/src/lib.zig index e20b04b..3460077 100644 --- a/src/lib.zig +++ b/src/lib.zig @@ -1,6 +1,7 @@ const std = @import("std"); pub const core = @import("core/lib.zig"); pub const lightning_invoices = @import("lightning_invoices/invoice.zig"); +pub const http_router = @import("misc/http_router/http_router.zig"); test { std.testing.log_level = .warn; diff --git a/src/mint.zig b/src/mint.zig index 2c3682a..c745bee 100644 --- a/src/mint.zig +++ b/src/mint.zig @@ -9,6 +9,9 @@ const builtin = @import("builtin"); const config = @import("mintd/config.zig"); const clap = @import("clap"); const zul = @import("zul"); +const ref = @import("sync/ref.zig"); +const mpmc = @import("sync/mpmc.zig"); +const http_router = @import("misc/http_router/http_router.zig"); const MintLightning = core.lightning.MintLightning; const MintState = @import("router/router.zig").MintState; @@ -22,16 +25,14 @@ const ContactInfo = core.nuts.ContactInfo; const MintVersion = core.nuts.MintVersion; const MintInfo = core.nuts.MintInfo; const Channel = @import("channels/channels.zig").Channel; +const Lnbits = @import("core/mint/lightning/lnbits.zig").LnBits; const default_quote_ttl_secs: u64 = 1800; /// Update mint quote when called for a paid invoice fn handlePaidInvoice(mint: *Mint, request_lookup_id: []const u8) !void { - const request_lookup = try zul.UUID.parse(request_lookup_id); - - std.log.debug("Invoice with lookup id paid: {s}", .{request_lookup}); - - try mint.payMintQuoteForRequestId(request_lookup); + std.log.debug("Invoice with lookup id paid: {s}", .{request_lookup_id}); + try mint.payMintQuoteForRequestId(request_lookup_id); } pub fn main() !void { @@ -135,6 +136,31 @@ pub fn main() !void { .percent_fee_reserve = relative_ln_fee, }; + const mint_url = parsed_settings.value.info.url; + const listen_addr = parsed_settings.value.info.listen_host; + const listen_port = parsed_settings.value.info.listen_port; + + var global_handler = http_router.GlobalRouter{ + .router = std.ArrayList(http_router.Router).init(gpa.allocator()), + }; + defer global_handler.router.deinit(); + + var srv = try httpz.Server(*http_router.GlobalRouter).init(gpa.allocator(), .{ + .address = listen_addr, + .port = listen_port, + }, &global_handler); + defer srv.deinit(); + + const cors_middleware = try srv.middleware(httpz.middleware.Cors, .{ + .origin = "*", + .headers = "*", + }); + + const s_router = srv.router(.{ + .middlewares = &.{cors_middleware}, + }); + _ = s_router; // autofix + const input_fee_ppk = parsed_settings.value.info.input_fee_ppk orelse 0; var supported_units = std.AutoHashMap(core.nuts.CurrencyUnit, std.meta.Tuple(&.{ u64, u8 })).init(gpa.allocator()); @@ -150,9 +176,70 @@ pub fn main() !void { ln_backends.deinit(); } + var arena = std.heap.ArenaAllocator.init(gpa.allocator()); + defer arena.deinit(); + // TODO set ln router // additional routers for httpz server switch (parsed_settings.value.ln.ln_backend) { + .lnbits => { + const lnbits_settings = parsed_settings.value.lnbits orelse unreachable; + const admin_api_key = lnbits_settings.admin_api_key; + const invoice_api_key = lnbits_settings.invoice_api_key; + + const webhook_endpoint = "/webhook/lnbits/sat/invoice"; + var webhook_url = zul.StringBuilder.init(gpa.allocator()); + + try webhook_url + .write(mint_url); + try webhook_url + .write(webhook_endpoint); + + var chan = val: { + var ch = mpmc.UnboundedChannel(std.ArrayList(u8)).init(gpa.allocator()); + errdefer ch.deinit(); + + break :val try ref.Arc(mpmc.UnboundedChannel(std.ArrayList(u8))).init(gpa.allocator(), ch); + }; + errdefer chan.releaseWithFn((struct { + fn deinit(self: mpmc.UnboundedChannel(std.ArrayList(u8))) void { + self.deinit(); + } + }).deinit); + + // TODO add fee reserve, webhooh, receiver + var lnbits = try Lnbits.init( + gpa.allocator(), + admin_api_key, + invoice_api_key, + lnbits_settings.lnbits_api, + .{}, + .{}, + fee_reserve, + chan.retain(), + webhook_url.string(), + ); + errdefer lnbits.deinit(); + + const ln_mint = try lnbits.toMintLightning(gpa.allocator()); + errdefer ln_mint.deinit(); + + const unit = core.nuts.CurrencyUnit.sat; + + const ln_key = LnKey.init(unit, .bolt11); + + try ln_backends.put(ln_key, ln_mint); + + try supported_units.put(unit, .{ input_fee_ppk, 64 }); + + const webhook_router = try lnbits.client.createInvoiceWebhookRouter( + arena.allocator(), + webhook_endpoint, + chan.retain(), + ); + + try global_handler.router.append(webhook_router); + }, .fake_wallet => { const units = (parsed_settings.value.fake_wallet orelse config.FakeWallet{}).supported_units; @@ -170,10 +257,10 @@ pub fn main() !void { try supported_units.put(unit, .{ input_fee_ppk, 64 }); } }, - else => { - // not implemented backends - unreachable; - }, + // else => { + // // not implemented backends + // unreachable; + // }, } var nuts = core.nuts.Nuts{}; @@ -256,26 +343,22 @@ pub fn main() !void { // } // TODO - const mint_url = parsed_settings.value.info.url; - const listen_addr = parsed_settings.value.info.listen_host; - const listen_port = parsed_settings.value.info.listen_port; const quote_ttl = parsed_settings.value .info .seconds_quote_is_valid_for orelse default_quote_ttl_secs; // start serevr - var srv = try router.createMintServer(gpa.allocator(), mint_url, &mint, ln_backends, quote_ttl, .{ - .port = listen_port, - .address = listen_addr, - }, &.{ - .{ - httpz.middleware.Cors, .{ - .origin = "*", - .headers = "*", - }, - }, - }); - defer srv.deinit(); + const mint_router = try router.createMintServer( + arena.allocator(), + mint_url, + &mint, + ln_backends, + quote_ttl, + ); + + // adding routers + try global_handler.router.append(mint_router); + // TODO add router for backend with webhooks // add lnn router here to server try handleInterrupt(&srv); @@ -288,13 +371,25 @@ pub fn main() !void { errdefer for (threads.items) |t| t.detach(); const thread_fn = (struct { - fn handleLnInvoice(m: *Mint, wait_ch: Channel(std.ArrayList(u8)).Rx) void { - while (true) { - var request_lookup_id = wait_ch.recv(); - defer request_lookup_id.deinit(); + fn handleLnInvoice(m: *Mint, wait_ch: ref.Arc(mpmc.UnboundedChannel(std.ArrayList(u8)))) void { + defer wait_ch.releaseWithFn((struct { + fn deinit(_self: mpmc.UnboundedChannel(std.ArrayList(u8))) void { + _self.deinit(); + } + }).deinit); + + var receiver = wait_ch.value.receiver() catch return; - handlePaidInvoice(m, request_lookup_id.items) catch |err| { - std.log.warn("handle paid invoice error, lookup_id {s}, err={s}", .{ request_lookup_id.items, @errorName(err) }); + while (true) { + var request_lookup_id = receiver.recv() orelse unreachable; + defer request_lookup_id.releaseWithFn((struct { + fn deinit(_self: std.ArrayList(u8)) void { + _self.deinit(); + } + }).deinit); + + handlePaidInvoice(m, request_lookup_id.value.items) catch |err| { + std.log.warn("handle paid invoice error, lookup_id {s}, err={s}", .{ request_lookup_id.value.items, @errorName(err) }); continue; }; } @@ -304,7 +399,7 @@ pub fn main() !void { var it = ln_backends.iterator(); while (it.next()) |ln_entry| { threads.appendAssumeCapacity(try std.Thread.spawn(.{}, thread_fn, .{ - &mint, try ln_entry.value_ptr.waitAnyInvoice(), + &mint, ln_entry.value_ptr.waitAnyInvoice(), })); } break :v threads; @@ -320,9 +415,9 @@ pub fn main() !void { std.log.info("Stopped server", .{}); } -pub fn handleInterrupt(srv: *httpz.Server(MintState)) !void { +pub fn handleInterrupt(srv: *httpz.Server(*http_router.GlobalRouter)) !void { const signal = struct { - var _srv: *httpz.Server(MintState) = undefined; + var _srv: *httpz.Server(*http_router.GlobalRouter) = undefined; fn handler(sig: c_int) callconv(.C) void { std.debug.assert(sig == std.posix.SIG.INT); diff --git a/src/mintd/config.example.toml b/src/mintd/config.example.toml index e68a07f..023eeed 100644 --- a/src/mintd/config.example.toml +++ b/src/mintd/config.example.toml @@ -1,4 +1,3 @@ - [info] url = "https://mint.thesimplekid.dev/" listen_host = "127.0.0.1" @@ -10,12 +9,12 @@ mnemonic = "" [mint_info] # name = "cdk-mintd mutiney net mint" -# Hex publey of mint +# Hex pubkey of mint # pubkey = "" # description = "These are not real sats for testing only" # description_long = "A longer mint for testing" # motd = "Hello world" -# mint_icon_url = "https://this-is-a-mint-icon-url.com/icon.png" +# icon_url = "https://this-is-a-mint-icon-url.com/icon.png" # contact_email = "hello@cashu.me" # Nostr pubkey of mint (Hex) # contact_nostr_public_key = "" @@ -26,14 +25,23 @@ mnemonic = "" # engine = "sqlite" [ln] -# Required ln backend `cln`, `strike`, `fakewallet` -ln_backend = "cln" - -# [cln] -# Required if using cln backend path to rpc -# cln_path = "" - -# [strike] -# api_key="" -# Optional default sats -# supported_units=[""] +# Required ln backend `cln`, `lnd`, `strike`, `fakewallet`, 'lnbits', 'phoenixd' +ln_backend = "lnbits" +# For 'phoenixd' backend, also specify fee_percent (% fee of the ln payment that mint will put in the melt quote) and reserve_fee_min (absolute amount-higher of fee_percent or reserve_fee_min is the fee reserve). +# fee_percent=0.04 +# reserve_fee_min=4 + + +# [lnbits] +# admin_api_key = "" +# invoice_api_key = "" +# lnbits_api = "" + +# [phoenixd] +# api_password = "" +# api_url = "" + +# [lnd] +# address = "" +# macaroon_file = "" +# cert_file = "" diff --git a/src/mintd/config.zig b/src/mintd/config.zig index eb89ab4..7b1bd6f 100644 --- a/src/mintd/config.zig +++ b/src/mintd/config.zig @@ -11,8 +11,7 @@ pub const Settings = struct { info: Info, mint_info: MintInfo, ln: Ln, - cln: ?Cln, - strike: ?Strike, + lnbits: ?Lnbits, fake_wallet: ?FakeWallet, database: Database, @@ -66,15 +65,12 @@ pub const MintInfo = struct { pub const LnBackend = enum { // default - cln, - strike, fake_wallet, - // Greenlight, - // Ldk, + lnbits, }; pub const Ln = struct { - ln_backend: LnBackend = .cln, + ln_backend: LnBackend = .lnbits, invoice_description: ?[]const u8, fee_percent: f32, reserve_fee_min: Amount, @@ -89,6 +85,12 @@ pub const Cln = struct { rpc_path: []const u8, }; +pub const Lnbits = struct { + admin_api_key: []const u8, + invoice_api_key: []const u8, + lnbits_api: []const u8, +}; + pub const FakeWallet = struct { supported_units: []const CurrencyUnit = &.{.sat}, }; diff --git a/src/misc/http_router/http_router.zig b/src/misc/http_router/http_router.zig new file mode 100644 index 0000000..596e3ef --- /dev/null +++ b/src/misc/http_router/http_router.zig @@ -0,0 +1,139 @@ +const httpz = @import("httpz"); +const std = @import("std"); + +pub const Router = struct { + const Self = @This(); + + ptr: *anyopaque, + allocator: std.mem.Allocator, + + handleFn: *const fn (ptr: *anyopaque, req: *httpz.Request, res: *httpz.Response) anyerror!?void, + deinitFn: *const fn (ptr: *anyopaque, allocator: std.mem.Allocator) void, + + pub fn initFrom(comptime T: type, _allocator: std.mem.Allocator, router: httpz.Router(T, httpz.Action(T))) !Self { + // implement gen structure + const gen = struct { + pub fn handle(pointer: *anyopaque, req: *httpz.Request, res: *httpz.Response) anyerror!?void { + const self: *httpz.Router(T, httpz.Action(T)) = @ptrCast(@alignCast(pointer)); + + const act = self.route(req.method, req.url.raw, req.params) orelse return null; + + return try act.action(self.handler, req, res); + } + + pub fn deinit(pointer: *anyopaque, allocator: std.mem.Allocator) void { + const self: *httpz.Router(T, httpz.Action(T)) = @ptrCast(@alignCast(pointer)); + + if (std.meta.hasFn(T, "deinit")) { + self.deinit(); + } + + allocator.destroy(self); + } + }; + + const ptr: *httpz.Router(T, httpz.Action(T)) align(1) = try _allocator.create(httpz.Router(T, httpz.Action(T))); + ptr.* = router; + + return .{ + .ptr = ptr, + .allocator = _allocator, + // .size = @sizeOf(T), + // .align_of = @alignOf(T), + .handleFn = gen.handle, + .deinitFn = gen.deinit, + }; + } + + /// free resources of database + pub fn deinit(self: Self) void { + self.deinitFn(self.ptr, self.allocator); + // clearing pointer + } + + pub fn handle(self: Self, req: *httpz.Request, res: *httpz.Response) anyerror!?void { + return self.handleFn(self.ptr, req, res); + } +}; + +pub const GlobalRouter = struct { + router: std.ArrayList(Router), + + pub fn notFound(self: *GlobalRouter, req: *httpz.Request, res: *httpz.Response) !void { + std.log.debug("trying to found {} {s} ", .{ + req.method, + req.url.path, + }); + + for (self.router.items) |*r| { + if (try r.handle(req, res)) |_| break; + } else { + return error.NotFound; + } + } +}; + +pub fn DefaultDispatcher(comptime T: type) type { + return struct { + pub fn dispatcher(h: T, action: httpz.Action(T), req: *httpz.Request, res: *httpz.Response) !void { + _ = h; // autofix + _ = action; // autofix + _ = req; // autofix + _ = res; // autofix + } + }; +} + +test "ttt" { + const SomeHandler = struct { + s: []const u8, + + pub usingnamespace DefaultDispatcher(@This()); + + pub fn testik(self: @This(), req: *httpz.Request, res: *httpz.Response) !void { + _ = req; // autofix + res.body = self.s; + } + }; + + const sh = SomeHandler{ + .s = "some_custom_response", + }; + + const SomeHandlerRouter = httpz.Router(SomeHandler, httpz.Action(SomeHandler)); + + var arena = std.heap.ArenaAllocator.init(std.testing.allocator); + defer arena.deinit(); + + var some_router = try SomeHandlerRouter.init(arena.allocator(), SomeHandler.dispatcher, sh); + some_router.get("/test14", SomeHandler.testik, .{}); + + const ht = @import("httpz").testing; + var router = GlobalRouter{ + .router = std.ArrayList(Router).init(std.testing.allocator), + }; + defer router.router.deinit(); + + const c_handler = try Router.initFrom(SomeHandler, std.testing.allocator, some_router); + defer c_handler.deinit(); + + try router.router.append(c_handler); + + var srv = try httpz.Server(*GlobalRouter).init(std.testing.allocator, .{}, &router); + + defer srv.deinit(); + + // httpz.Router(*Router, httpz.Action(*Router)).init(std.testing.allocator, , ) + + var _router = srv.router(.{}); + _router.get("/test1234", GlobalRouter.notFound, .{}); + + var web_test = ht.init(.{}); + defer web_test.deinit(); + + web_test.url("/test14"); + + try router.notFound(web_test.req, web_test.res); + + try web_test.expectBody(sh.s); +} diff --git a/src/router/router.zig b/src/router/router.zig index e0fccf5..91bef62 100644 --- a/src/router/router.zig +++ b/src/router/router.zig @@ -5,6 +5,7 @@ const std = @import("std"); const router_handlers = @import("router_handlers.zig"); const zul = @import("zul"); const fake_wallet = @import("../fake_wallet/fake_wallet.zig"); +const http_router = @import("../misc/http_router/http_router.zig"); const MintLightning = core.lightning.MintLightning; const Mint = core.mint.Mint; @@ -21,9 +22,7 @@ pub fn createMintServer( mint: *Mint, ln: LnBackendsMap, quote_ttl: u64, - server_options: httpz.Config, - middlewares: anytype, -) !httpz.Server(MintState) { +) !http_router.Router { // TODO do we need copy const state = MintState{ .mint = mint, @@ -32,19 +31,7 @@ pub fn createMintServer( .ln = ln, }; - var srv = try httpz.Server(MintState).init(allocator, server_options, state); - errdefer srv.deinit(); - - var _middlewares = try srv.arena.alloc(httpz.Middleware(MintState), middlewares.len); - - inline for (middlewares, 0..) |m, i| { - _middlewares[i] = try srv.middleware(m[0], m[1]); - } - - // apply routes - var router = srv.router(.{ - .middlewares = _middlewares, - }); + var router = try httpz.Router(MintState, httpz.Action(MintState)).init(allocator, MintState.dispatcher, state); router.get("/v1/keys", router_handlers.getKeys, .{}); router.get("/v1/keysets", router_handlers.getKeysets, .{}); @@ -59,7 +46,10 @@ pub fn createMintServer( router.post("/v1/mint/bolt11", router_handlers.postMintBolt11, .{}); router.post("/v1/melt/bolt11", router_handlers.postMeltBolt11, .{}); router.post("/v1/swap", router_handlers.postSwap, .{}); - return srv; + + const _router = try http_router.Router.initFrom(MintState, allocator, router); + + return _router; } pub const LnKeyContext = struct { @@ -87,6 +77,8 @@ pub const MintState = struct { mint_url: []const u8, quote_ttl: u64, + pub usingnamespace http_router.DefaultDispatcher(@This()); + pub fn uncaughtError(self: *const MintState, req: *httpz.Request, res: *httpz.Response, err: anyerror) void { _ = self; // autofix std.log.info("500 {} {s} {}", .{ req.method, req.url.path, err }); diff --git a/src/router/router_handlers.zig b/src/router/router_handlers.zig index 4c930c5..fdfae7e 100644 --- a/src/router/router_handlers.zig +++ b/src/router/router_handlers.zig @@ -114,6 +114,8 @@ pub fn getMintBolt11Quote( return error.InvalidPaymentRequest; }; + std.log.debug("{s} quote_id", .{create_invoice_response.request_lookup_id}); + const quote = state.mint.newMintQuote( res.arena, state.mint_url, @@ -121,7 +123,7 @@ pub fn getMintBolt11Quote( payload.unit, payload.amount, create_invoice_response.expiry orelse 0, - try zul.UUID.parse(create_invoice_response.request_lookup_id), + create_invoice_response.request_lookup_id, ) catch |err| { std.log.err("could not create new mint quote: {any}", .{err}); return error.InternalError; @@ -215,6 +217,8 @@ pub fn postMeltBolt11( req: *httpz.Request, res: *httpz.Response, ) !void { + errdefer std.log.debug("{any}", .{@errorReturnTrace()}); + const payload = try req.json(core.nuts.nut05.MeltBolt11Request) orelse return error.WrongRequest; const quote = state.mint.verifyMeltRequest(res.arena, payload) catch |err| { diff --git a/src/sync/mpmc.zig b/src/sync/mpmc.zig new file mode 100644 index 0000000..4fc3fd3 --- /dev/null +++ b/src/sync/mpmc.zig @@ -0,0 +1,297 @@ +// NOTE: WIP +const std = @import("std"); +const assert = std.debug.assert; +const Atomic = std.atomic.Value; +const ref = @import("ref.zig"); + +/// An UnboundedChannel with the following characteristics: +/// - non-blocking sender +/// - blocking receiver +/// - no mutex/rwlock +/// - data ordering is not guaranteed +/// - atomically reference counted data +/// - thread safe +pub fn UnboundedChannel(comptime T: type) type { + return struct { + arena: std.heap.ArenaAllocator, + stack: Atomic(usize), + cache: ?*Node, + closed: Atomic(bool), + + const Self = @This(); + + pub const Error = error{ + Closed, + }; + + /// Node which holds data. Ideally, these should be recycled. + pub const Node = struct { + next: ?*Node, + data: ref.Arc(T), + + pub fn init(allocator: std.mem.Allocator, data: T) !*Node { + const node = try allocator.create(Node); + node.* = .{ + .next = null, + .data = try ref.arc(allocator, data), + }; + return node; + } + + pub fn deinit(self: *Node, allocator: std.mem.Allocator) void { + self.data.release(); + allocator.destroy(self); + } + }; + + /// A batch of linked list nodes + pub const Batch = struct { + first: *Node, + last: *Node, + + pub fn init(first: *Node, last: *Node) Batch { + return .{ + .first = first, + .last = last, + }; + } + }; + + pub const Sender = struct { + allocator: std.mem.Allocator, + stack: *Atomic(usize), + closed: *Atomic(bool), + + fn init(allocator: std.mem.Allocator, stack: *Atomic(usize), closed: *Atomic(bool)) Sender { + return Sender{ + .allocator = allocator, + .stack = stack, + .closed = closed, + }; + } + + pub fn deinit(_: *Sender) void { + return; + } + + fn flush(self: *Sender, list: Batch) void { + var stack = self.stack.load(.seq_cst); + while (true) { + // Attach the list to the stack (pt. 1) + list.last.next = @as(?*Node, @ptrFromInt(stack & PTR_MASK)); + + // Update the stack with the list (pt. 2). + // Don't change the HAS_CACHE and IS_CONSUMING bits of the consumer. + var new_stack = @intFromPtr(list.first); + assert(new_stack & ~PTR_MASK == 0); + new_stack |= (stack & ~PTR_MASK); + + // Push to the stack with a release barrier for the consumer to see the proper list links. + stack = self.stack.cmpxchgStrong( + stack, + new_stack, + .release, + .monotonic, + ) orelse break; + } + } + + pub fn send(self: *Sender, data: T) error{Closed}!void { + if (self.closed.load(.monotonic)) { + return error.Closed; + } + const node = Node.init(self.allocator, data) catch unreachable; + self.flush(Batch.init(node, node)); + } + }; + + pub const Receiver = struct { + allocator: std.mem.Allocator, + stack: *Atomic(usize), + ref: ?*Node, + cache: *?*Node, + acquired: bool, + closed: *Atomic(bool), + + fn init(allocator: std.mem.Allocator, stack: *Atomic(usize), closed: *Atomic(bool), cache: *?*Node) Receiver { + return Receiver{ .allocator = allocator, .stack = stack, .ref = null, .cache = cache, .acquired = false, .closed = closed }; + } + + pub fn deinit(noalias self: *Receiver) void { + // Stop consuming and remove the HAS_CACHE bit as well if the consumer's cache is empty. + // When HAS_CACHE bit is zeroed, the next consumer will acquire the pushed stack nodes. + var remove = IS_CONSUMING; + if (self.ref == null) + remove |= HAS_CACHE; + + // Release the consumer with a release barrier to ensure cache/node accesses + // happen before the consumer was released and before the next consumer starts using the cache. + self.cache.* = self.ref; + const stack = self.stack.fetchSub(remove, .seq_cst); + assert(stack & remove != 0); + } + + fn tryAcquireConsumer(noalias self: *Receiver) error{ Empty, Contended }!?*Node { + var stack = self.stack.load(.monotonic); + while (true) { + if (stack & IS_CONSUMING != 0) + return error.Contended; // The queue already has a consumer. + if (stack & (HAS_CACHE | PTR_MASK) == 0) + return error.Empty; // The queue is empty when there's nothing cached and nothing in the stack. + + // When we acquire the consumer, also consume the pushed stack if the cache is empty. + var new_stack = stack | HAS_CACHE | IS_CONSUMING; + if (stack & HAS_CACHE == 0) { + assert(stack & PTR_MASK != 0); + new_stack &= ~PTR_MASK; + } + + // Acquire barrier on getting the consumer to see cache/Node updates done by previous consumers + // and to ensure our cache/Node updates in pop() happen after that of previous consumers. + stack = self.stack.cmpxchgStrong( + stack, + new_stack, + .acquire, + .monotonic, + ) orelse return self.cache.* orelse @as(*Node, @ptrFromInt(stack & PTR_MASK)); + } + } + + pub fn recv(noalias self: *Receiver) ?ref.Arc(T) { + while (!self.closed.load(.monotonic)) { + while (!self.acquired) { + if (self.closed.load(.monotonic)) { + return null; + } + self.ref = self.tryAcquireConsumer() catch continue; + self.acquired = true; + } + + // Check the consumer cache (fast path) + if (self.ref) |node| { + self.ref = node.next; + // grab the data and retain arc + const data = node.data.retain(); + // deinit the node + node.deinit(self.allocator); + // return the data + return data; + } + + // Load the stack to see if there was anything pushed that we could grab. + var stack = self.stack.load(.monotonic); + assert(stack & IS_CONSUMING != 0); + if (stack & PTR_MASK == 0) { + continue; + } + + // Nodes have been pushed to the stack, grab then with an Acquire barrier to see the Node links. + stack = self.stack.swap(HAS_CACHE | IS_CONSUMING, .acquire); + assert(stack & IS_CONSUMING != 0); + assert(stack & PTR_MASK != 0); + + const node = @as(*Node, @ptrFromInt(stack & PTR_MASK)); + self.ref = node.next; + const data = node.data.retain(); + node.deinit(self.allocator); + return data; + } + return null; + } + }; + + const HAS_CACHE: usize = 0b01; + const IS_CONSUMING: usize = 0b10; + const PTR_MASK: usize = ~(HAS_CACHE | IS_CONSUMING); + + pub fn init(allocator: std.mem.Allocator) Self { + return .{ + .stack = Atomic(usize).init(0), + .cache = null, + .arena = std.heap.ArenaAllocator.init(allocator), + .closed = Atomic(bool).init(false), + }; + } + + pub fn deinit(self: *const Self) void { + self.arena.deinit(); + } + + comptime { + assert(@alignOf(Node) >= ((IS_CONSUMING | HAS_CACHE) + 1)); + } + + pub fn close(self: *Self) void { + self.closed.store(true, .monotonic); + } + + pub fn sender(self: *Self) error{Closed}!Sender { + if (self.closed.load(.seq_cst)) { + return error.Closed; + } + return Sender.init(self.arena.allocator(), &self.stack, &self.closed); + } + + pub fn receiver(self: *Self) error{Closed}!Receiver { + if (self.closed.load(.seq_cst)) { + return error.Closed; + } + return Receiver.init(self.arena.allocator(), &self.stack, &self.closed, &self.cache); + } + }; +} + +fn startSender(chan: *UnboundedChannel(u32), count: usize) !void { + var sender = try chan.sender(); + defer sender.deinit(); + + for (0..count) |i| { + try sender.send(@intCast(i)); + if ((i % 4) == 0) { + std.time.sleep(2); + } + } + std.time.sleep(std.time.ns_per_s * 1); + chan.close(); +} + +fn startReceiver(chan: *UnboundedChannel(u32), observed: *usize, id: u8) !void { + var receiver = chan.receiver() catch return; + defer receiver.deinit(); + + var i: usize = 0; + while (receiver.recv()) |data| { + i += 1; + defer data.release(); + std.log.info("[{any}] got data: {any}", .{ id, data.value.* }); + observed.* += 1; + if ((i % 10) == 0) { + std.time.sleep(3); + } + } +} + +const testing = std.testing; + +test "sync.mpmc: UnboundedChannel works" { + return error.SkipZigTest; + + // var chan = UnboundedChannel(u32).init(testing.allocator); + // defer chan.deinit(); + + // var items_to_produce: usize = 1000; + // var observed: usize = 0; + + // var send_handle = try std.Thread.spawn(.{}, startSender, .{ &chan, items_to_produce }); + // var recv_handle_1 = try std.Thread.spawn(.{}, startReceiver, .{ &chan, &observed, 1 }); + // // std.time.sleep(1); + // var recv_handle_2 = try std.Thread.spawn(.{}, startReceiver, .{ &chan, &observed, 2 }); + // // std.time.sleep(1); + // var recv_handle_3 = try std.Thread.spawn(.{}, startReceiver, .{ &chan, &observed, 3 }); + + // send_handle.join(); + // recv_handle_1.join(); + // recv_handle_2.join(); + // recv_handle_3.join(); + // try testing.expectEqual(items_to_produce, observed); +} diff --git a/src/sync/ref.zig b/src/sync/ref.zig new file mode 100644 index 0000000..831cf25 --- /dev/null +++ b/src/sync/ref.zig @@ -0,0 +1,517 @@ +// Fork of: https://github.com/Aandreba/zigrc with minor changes +// +// MIT License +// +// Copyright (c) 2023 Alex Andreba +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +const std = @import("std"); +const builtin = @import("builtin"); + +/// This variable is `true` if an atomic reference-counter is used for `Arc`, `false` otherwise. +/// +/// If the target is single-threaded, `Arc` is optimized to a regular `Rc`. +pub const atomic_arc = !builtin.single_threaded or (builtin.target.isWasm() and std.Target.wasm.featureSetHas(builtin.cpu.features, .atomics)); + +/// A single threaded, strong reference to a reference-counted value. +pub fn Rc(comptime T: type) type { + return struct { + value: *T, + alloc: std.mem.Allocator, + + const Self = @This(); + const Inner = struct { + strong: usize, + weak: usize, + value: T, + + fn innerSize() comptime_int { + return @sizeOf(@This()); + } + + fn innerAlign() comptime_int { + return @alignOf(@This()); + } + }; + + /// Creates a new reference-counted value. + pub fn init(alloc: std.mem.Allocator, t: T) error{OutOfMemory}!Self { + const inner = try alloc.create(Inner); + inner.* = Inner{ .strong = 1, .weak = 1, .value = t }; + return Self{ .value = &inner.value, .alloc = alloc }; + } + + /// Constructs a new `Rc` while giving you a `Weak` to the allocation, + /// to allow you to construct a `T` which holds a weak pointer to itself. + pub fn initCyclic(alloc: std.mem.Allocator, comptime data_fn: fn (*Weak) T) error{OutOfMemory}!Self { + const inner = try alloc.create(Inner); + inner.* = Inner{ .strong = 0, .weak = 1, .value = undefined }; + + // Strong references should collectively own a shared weak reference, + // so don't run the destructor for our old weak reference. + var weak = Weak{ .inner = inner, .alloc = alloc }; + + // It's important we don't give up ownership of the weak pointer, or + // else the memory might be freed by the time `data_fn` returns. If + // we really wanted to pass ownership, we could create an additional + // weak pointer for ourselves, but this would result in additional + // updates to the weak reference count which might not be necessary + // otherwise. + inner.value = data_fn(&weak); + + std.debug.assert(inner.strong == 0); + inner.strong = 1; + + return Self{ .value = &inner.value, .alloc = alloc }; + } + + /// Gets the number of strong references to this value. + pub fn strongCount(self: *const Self) usize { + return self.innerPtr().strong; + } + + /// Gets the number of weak references to this value. + pub fn weakCount(self: *const Self) usize { + return self.innerPtr().weak - 1; + } + + /// Increments the strong count. + pub fn retain(self: *Self) Self { + self.innerPtr().strong += 1; + return self.*; + } + + /// Creates a new weak reference to the pointed value + pub fn downgrade(self: *Self) Weak { + return Weak.init(self); + } + + /// Decrements the reference count, deallocating if the weak count reaches zero. + /// The continued use of the pointer after calling `release` is undefined behaviour. + pub fn release(self: Self) void { + const ptr = self.innerPtr(); + + ptr.strong -= 1; + if (ptr.strong == 0) { + ptr.weak -= 1; + if (ptr.weak == 0) { + self.alloc.destroy(ptr); + } + } + } + + /// Decrements the reference count, deallocating the weak count reaches zero, + /// and executing `f` if the strong count reaches zero. + /// The continued use of the pointer after calling `release` is undefined behaviour. + pub fn releaseWithFn(self: Self, comptime f: fn (T) void) void { + const ptr = self.innerPtr(); + + ptr.strong -= 1; + if (ptr.strong == 0) { + f(self.value.*); + + ptr.weak -= 1; + if (ptr.weak == 0) { + self.alloc.destroy(ptr); + } + } + } + + /// Returns the inner value, if the `Rc` has exactly one strong reference. + /// Otherwise, `null` is returned. + /// This will succeed even if there are outstanding weak references. + /// The continued use of the pointer if the method successfully returns `T` is undefined behaviour. + pub fn tryUnwrap(self: Self) ?T { + const ptr = self.innerPtr(); + + if (ptr.strong == 1) { + ptr.strong = 0; + const tmp = self.value.*; + + ptr.weak -= 1; + if (ptr.weak == 0) { + self.alloc.destroy(ptr); + } + + return tmp; + } + + return null; + } + + /// Total size (in bytes) of the reference counted value on the heap. + /// This value accounts for the extra memory required to count the references. + pub fn innerSize() comptime_int { + return Inner.innerSize(); + } + + /// Alignment (in bytes) of the reference counted value on the heap. + /// This value accounts for the extra memory required to count the references. + pub fn innerAlign() comptime_int { + return Inner.innerAlign(); + } + + inline fn innerPtr(self: *const Self) *Inner { + return @fieldParentPtr("value", self.value); + } + + /// A single threaded, weak reference to a reference-counted value. + pub const Weak = struct { + inner: ?*Inner = null, + alloc: std.mem.Allocator, + + /// Creates a new weak reference. + pub fn init(parent: *Rc(T)) Weak { + const ptr = parent.innerPtr(); + ptr.weak += 1; + return Weak{ .inner = ptr, .alloc = parent.alloc }; + } + + /// Creates a new weak reference object from a pointer to it's underlying value, + /// without increasing the weak count. + pub fn fromValuePtr(value: *T, alloc: std.mem.Allocator) Weak { + return .{ .inner = @fieldParentPtr("value", value), .alloc = alloc }; + } + + /// Gets the number of strong references to this value. + pub fn strongCount(self: *const Weak) usize { + return (self.innerPtr() orelse return 0).strong; + } + + /// Gets the number of weak references to this value. + pub fn weakCount(self: *const Weak) usize { + const ptr = self.innerPtr() orelse return 1; + if (ptr.strong == 0) { + return ptr.weak; + } else { + return ptr.weak - 1; + } + } + + /// Increments the weak count. + pub fn retain(self: *Weak) Weak { + if (self.innerPtr()) |ptr| { + ptr.weak += 1; + } + return self.*; + } + + /// Attempts to upgrade the weak pointer to an `Rc`, delaying dropping of the inner value if successful. + /// + /// Returns `null` if the inner value has since been dropped. + pub fn upgrade(self: *Weak) ?Rc(T) { + const ptr = self.innerPtr() orelse return null; + + if (ptr.strong == 0) { + ptr.weak -= 1; + if (ptr.weak == 0) { + self.alloc.destroy(ptr); + self.inner = null; + } + return null; + } + + ptr.strong += 1; + return Rc(T){ + .value = &ptr.value, + .alloc = self.alloc, + }; + } + + /// Decrements the weak reference count, deallocating if it reaches zero. + /// The continued use of the pointer after calling `release` is undefined behaviour. + pub fn release(self: Weak) void { + if (self.innerPtr()) |ptr| { + ptr.weak -= 1; + if (ptr.weak == 0) { + self.alloc.destroy(ptr); + } + } + } + + /// Total size (in bytes) of the reference counted value on the heap. + /// This value accounts for the extra memory required to count the references, + /// and is valid for single and multi-threaded refrence counters. + pub fn innerSize() comptime_int { + return Inner.innerSize(); + } + + /// Alignment (in bytes) of the reference counted value on the heap. + /// This value accounts for the extra memory required to count the references, + /// and is valid for single and multi-threaded refrence counters. + pub fn innerAlign() comptime_int { + return Inner.innerAlign(); + } + + inline fn innerPtr(self: *const Weak) ?*Inner { + return @as(?*Inner, @ptrCast(self.inner)); + } + }; + }; +} + +/// A multi-threaded, strong reference to a reference-counted value. +pub fn Arc(comptime T: type) type { + if (!atomic_arc) { + return Rc(T); + } + + return struct { + value: *T, + alloc: std.mem.Allocator, + + const Self = @This(); + const Inner = struct { + strong: usize align(std.atomic.cache_line), + weak: usize align(std.atomic.cache_line), + value: T, + + fn innerSize() comptime_int { + return @sizeOf(@This()); + } + + fn innerAlign() comptime_int { + return @alignOf(@This()); + } + }; + + /// Creates a new reference-counted value. + pub fn init(alloc: std.mem.Allocator, t: T) error{OutOfMemory}!Self { + const inner = try alloc.create(Inner); + inner.* = Inner{ .strong = 1, .weak = 1, .value = t }; + return Self{ .value = &inner.value, .alloc = alloc }; + } + + /// Constructs a new `Arc` while giving you a `weak` to the allocation, + /// to allow you to construct a `T` which holds a weak pointer to itself. + pub fn initCyclic(alloc: std.mem.Allocator, comptime data_fn: fn (*Weak) T) error{OutOfMemory}!Self { + const inner = try alloc.create(Inner); + inner.* = Inner{ .strong = 0, .weak = 1, .value = undefined }; + + // Strong references should collectively own a shared weak reference, + // so don't run the destructor for our old weak reference. + var weak = Weak{ .inner = inner, .alloc = alloc }; + + // It's important we don't give up ownership of the weak pointer, or + // else the memory might be freed by the time `data_fn` returns. If + // we really wanted to pass ownership, we could create an additional + // weak pointer for ourselves, but this would result in additional + // updates to the weak reference count which might not be necessary + // otherwise. + inner.value = data_fn(&weak); + + std.debug.assert(@atomicRmw(usize, &inner.strong, .Add, 1, .release) == 0); + return Self{ .value = &inner.value, .alloc = alloc }; + } + + /// Gets the number of strong references to this value. + pub fn strongCount(self: *const Self) usize { + return @atomicLoad(usize, &self.innerPtr().strong, .acquire); + } + + /// Gets the number of weak references to this value. + pub fn weakCount(self: *const Self) usize { + return @atomicLoad(usize, &self.innerPtr().weak, .acquire) - 1; + } + + /// Increments the strong count. + pub fn retain(self: *Self) Self { + _ = @atomicRmw(usize, &self.innerPtr().strong, .Add, 1, .acq_rel); + return self.*; + } + + /// Creates a new weak reference to the pointed value. + pub fn downgrade(self: *Self) Weak { + return Weak.init(self); + } + + /// Decrements the reference count, deallocating if the weak count reaches zero. + /// The continued use of the pointer after calling `release` is undefined behaviour. + pub fn release(self: Self) void { + const ptr = self.innerPtr(); + + if (@atomicRmw(usize, &ptr.strong, .Sub, 1, .acq_rel) == 1) { + if (@atomicRmw(usize, &ptr.weak, .Sub, 1, .acq_rel) == 1) { + self.alloc.destroy(ptr); + } + } + } + + /// Decrements the reference count, deallocating the weak count reaches zero, + /// and executing `f` if the strong count reaches zero. + /// The continued use of the pointer after calling `release` is undefined behaviour. + pub fn releaseWithFn(self: Self, comptime f: fn (T) void) void { + const ptr = self.innerPtr(); + + if (@atomicRmw(usize, &ptr.strong, .Sub, 1, .acq_rel) == 1) { + f(self.value.*); + if (@atomicRmw(usize, &ptr.weak, .Sub, 1, .acq_rel) == 1) { + self.alloc.destroy(ptr); + } + } + } + + /// Returns the inner value, if the `Arc` has exactly one strong reference. + /// Otherwise, `null` is returned. + /// This will succeed even if there are outstanding weak references. + /// The continued use of the pointer if the method successfully returns `T` is undefined behaviour. + pub fn tryUnwrap(self: Self) ?T { + const ptr = self.innerPtr(); + + if (@cmpxchgStrong(usize, &ptr.strong, 1, 0, .monotonic, .monotonic) == null) { + ptr.strong = 0; + const tmp = self.value.*; + if (@atomicRmw(usize, &ptr.weak, .Sub, 1, .acq_rel) == 1) { + self.alloc.destroy(ptr); + } + return tmp; + } + + return null; + } + + /// Total size (in bytes) of the reference counted value on the heap. + /// This value accounts for the extra memory required to count the references. + pub fn innerSize() comptime_int { + return Inner.innerSize(); + } + + /// Alignment (in bytes) of the reference counted value on the heap. + /// This value accounts for the extra memory required to count the references. + pub fn innerAlign() comptime_int { + return Inner.innerAlign(); + } + + inline fn innerPtr(self: *const Self) *Inner { + return @alignCast(@fieldParentPtr("value", self.value)); + } + + /// A multi-threaded, weak reference to a reference-counted value. + pub const Weak = struct { + inner: ?*Inner = null, + alloc: std.mem.Allocator, + + /// Creates a new weak reference. + pub fn init(parent: *Arc(T)) Weak { + const ptr = parent.innerPtr(); + _ = @atomicRmw(usize, &ptr.weak, .Add, 1, .acq_rel); + return Weak{ .inner = ptr, .alloc = parent.alloc }; + } + + /// Creates a new weak reference object from a pointer to it's underlying value, + /// without increasing the weak count. + pub fn fromValuePtr(value: *T, alloc: std.mem.Allocator) Weak { + return .{ .inner = @fieldParentPtr("value", value), .alloc = alloc }; + } + + /// Gets the number of strong references to this value. + pub fn strongCount(self: *const Weak) usize { + const ptr = self.innerPtr() orelse return 0; + return @atomicLoad(usize, &ptr.strong, .acquire); + } + + /// Gets the number of weak references to this value. + pub fn weakCount(self: *const Weak) usize { + const ptr = self.innerPtr() orelse return 1; + const weak = @atomicLoad(usize, &ptr.weak, .acquire); + + if (@atomicLoad(usize, &ptr.strong, .acquire) == 0) { + return weak; + } else { + return weak - 1; + } + } + + /// Increments the weak count. + pub fn retain(self: *Weak) Weak { + if (self.innerPtr()) |ptr| { + _ = @atomicRmw(usize, &ptr.weak, .Add, 1, .acq_rel); + } + return self.*; + } + + /// Attempts to upgrade the weak pointer to an `Arc`, delaying dropping of the inner value if successful. + /// + /// Returns `null` if the inner value has since been dropped. + pub fn upgrade(self: *Weak) ?Arc(T) { + const ptr = self.innerPtr() orelse return null; + + while (true) { + const prev = @atomicLoad(usize, &ptr.strong, .acquire); + + if (prev == 0) { + if (@atomicRmw(usize, &ptr.weak, .Sub, 1, .acq_rel) == 1) { + self.alloc.destroy(ptr); + self.inner = null; + } + return null; + } + + if (@cmpxchgStrong(usize, &ptr.strong, prev, prev + 1, .acquire, .monotonic) == null) { + return Arc(T){ + .value = &ptr.value, + .alloc = self.alloc, + }; + } + + std.atomic.spinLoopHint(); + } + } + + /// Decrements the weak reference count, deallocating if it reaches zero. + /// The continued use of the pointer after calling `release` is undefined behaviour. + pub fn release(self: Weak) void { + if (self.innerPtr()) |ptr| { + if (@atomicRmw(usize, &ptr.weak, .Sub, 1, .acq_rel) == 1) { + self.alloc.destroy(ptr); + } + } + } + + /// Total size (in bytes) of the reference counted value on the heap. + /// This value accounts for the extra memory required to count the references. + pub fn innerSize() comptime_int { + return Inner.innerSize(); + } + + /// Alignment (in bytes) of the reference counted value on the heap. + /// This value accounts for the extra memory required to count the references. + pub fn innerAlign() comptime_int { + return Inner.innerAlign(); + } + + inline fn innerPtr(self: *const Weak) ?*Inner { + return @as(?*Inner, @ptrCast(self.inner)); + } + }; + }; +} + +const Allocator = std.mem.Allocator; + +/// Creates a new `Rc` inferring the type of `value` +pub fn rc(alloc: Allocator, value: anytype) Allocator.Error!Rc(@TypeOf(value)) { + return Rc(@TypeOf(value)).init(alloc, value); +} + +/// Creates a new `Arc` inferring the type of `value` +pub fn arc(alloc: Allocator, value: anytype) Allocator.Error!Arc(@TypeOf(value)) { + return Arc(@TypeOf(value)).init(alloc, value); +}