From 401cd4a458e9082ee4921a6f87e1d504462b1c53 Mon Sep 17 00:00:00 2001 From: im-adithya Date: Mon, 4 Dec 2023 16:21:06 +0530 Subject: [PATCH 1/4] chore: refactor payment checks --- controllers/addinvoice.ctrl.go | 9 +++++ controllers/keysend.ctrl.go | 15 ++------ controllers/payinvoice.ctrl.go | 13 ++----- controllers_v2/invoice.ctrl.go | 9 +++++ controllers_v2/keysend.ctrl.go | 11 +++--- controllers_v2/payinvoice.ctrl.go | 15 ++------ lib/service/invoices.go | 46 ----------------------- lib/service/user.go | 61 ++++++++++++++++++++++++++++++- 8 files changed, 93 insertions(+), 86 deletions(-) diff --git a/controllers/addinvoice.ctrl.go b/controllers/addinvoice.ctrl.go index 8eda455d..90ccbdda 100644 --- a/controllers/addinvoice.ctrl.go +++ b/controllers/addinvoice.ctrl.go @@ -61,6 +61,15 @@ func AddInvoice(c echo.Context, svc *service.LndhubService, userID int64) error return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) } + resp, err := svc.CheckIncomingPaymentAllowed(c.Request().Context(), amount, userID) + if err != nil { + return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) + } + if resp != nil { + c.Logger().Errorf("Error: %v user_id:%v amount:%v", resp.Message, userID, amount) + return c.JSON(resp.HttpStatusCode, resp) + } + c.Logger().Infof("Adding invoice: user_id:%v memo:%s value:%v description_hash:%s", userID, body.Memo, amount, body.DescriptionHash) invoice, errResp := svc.AddIncomingInvoice(c.Request().Context(), userID, amount, body.Memo, body.DescriptionHash) diff --git a/controllers/keysend.ctrl.go b/controllers/keysend.ctrl.go index f9fdb92c..0eda4062 100644 --- a/controllers/keysend.ctrl.go +++ b/controllers/keysend.ctrl.go @@ -75,20 +75,13 @@ func (controller *KeySendController) KeySend(c echo.Context) error { }) } - resp, err := controller.svc.CheckPaymentAllowed(c.Request().Context(), lnPayReq, userID) + resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID) if err != nil { - c.Logger().Errorj( - log.JSON{ - "message": "failed to check balance", - "error": err, - "lndhub_user_id": userID, - }, - ) - return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) + return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) } if resp != nil { - c.Logger().Errorf("User does not have enough balance user_id:%v amount:%v", userID, lnPayReq.PayReq.NumSatoshis) - return c.JSON(http.StatusBadRequest, resp) + c.Logger().Errorf("Error: %v user_id:%v amount:%v", resp.Message, userID, lnPayReq.PayReq.NumSatoshis) + return c.JSON(resp.HttpStatusCode, resp) } invoice, errResp := controller.svc.AddOutgoingInvoice(c.Request().Context(), userID, "", lnPayReq) if errResp != nil { diff --git a/controllers/payinvoice.ctrl.go b/controllers/payinvoice.ctrl.go index 827de5d5..04dd3976 100644 --- a/controllers/payinvoice.ctrl.go +++ b/controllers/payinvoice.ctrl.go @@ -90,19 +90,12 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { lnPayReq.PayReq.NumSatoshis = amt } - resp, err := controller.svc.CheckPaymentAllowed(c.Request().Context(), lnPayReq, userID) + resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID) if err != nil { - c.Logger().Errorj( - log.JSON{ - "message": "error checking balance", - "error": err, - "lndhub_user_id": userID, - }, - ) - return c.JSON(http.StatusBadRequest, responses.GeneralServerError) + return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) } if resp != nil { - c.Logger().Errorf("User does not have enough balance user_id:%v amount:%v", userID, lnPayReq.PayReq.NumSatoshis) + c.Logger().Errorf("Error: %v user_id:%v amount:%v", resp.Message, userID, lnPayReq.PayReq.NumSatoshis) return c.JSON(http.StatusBadRequest, resp) } diff --git a/controllers_v2/invoice.ctrl.go b/controllers_v2/invoice.ctrl.go index fff42a9f..d38a3530 100644 --- a/controllers_v2/invoice.ctrl.go +++ b/controllers_v2/invoice.ctrl.go @@ -177,6 +177,15 @@ func (controller *InvoiceController) AddInvoice(c echo.Context) error { return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) } + resp, err := controller.svc.CheckIncomingPaymentAllowed(c.Request().Context(), body.Amount, userID) + if err != nil { + return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) + } + if resp != nil { + c.Logger().Errorf("Error: %v user_id:%v amount:%v", resp.Message, userID, body.Amount) + return c.JSON(resp.HttpStatusCode, resp) + } + c.Logger().Infof("Adding invoice: user_id:%v memo:%s value:%v description_hash:%s", userID, body.Description, body.Amount, body.DescriptionHash) invoice, errResp := controller.svc.AddIncomingInvoice(c.Request().Context(), userID, body.Amount, body.Description, body.DescriptionHash) diff --git a/controllers_v2/keysend.ctrl.go b/controllers_v2/keysend.ctrl.go index f49b711a..1c9e28c8 100644 --- a/controllers_v2/keysend.ctrl.go +++ b/controllers_v2/keysend.ctrl.go @@ -169,15 +169,14 @@ func (controller *KeySendController) checkKeysendPaymentAllowed(ctx context.Cont }, Keysend: true, } - resp, err := controller.svc.CheckPaymentAllowed(ctx, syntheticPayReq, userID) - if resp != nil { - controller.svc.Logger.Errorf("User does not have enough balance user_id:%v amount:%v", userID, syntheticPayReq.PayReq.NumSatoshis) - return resp - } + resp, err := controller.svc.CheckOutgoingPaymentAllowed(ctx, syntheticPayReq, userID) if err != nil { - controller.svc.Logger.Error(err) return &responses.GeneralServerError } + if resp != nil { + controller.svc.Logger.Errorf("Error: %v user_id:%v amount:%v", resp.Message, userID, syntheticPayReq.PayReq.NumSatoshis) + return resp + } return nil } diff --git a/controllers_v2/payinvoice.ctrl.go b/controllers_v2/payinvoice.ctrl.go index 627c95fe..b1375072 100644 --- a/controllers_v2/payinvoice.ctrl.go +++ b/controllers_v2/payinvoice.ctrl.go @@ -98,20 +98,13 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { } lnPayReq.PayReq.NumSatoshis = amt } - resp, err := controller.svc.CheckPaymentAllowed(c.Request().Context(), lnPayReq, userID) + resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID) if err != nil { - c.Logger().Errorj( - log.JSON{ - "message": "error checking balance", - "error": err, - "lndhub_user_id": userID, - }, - ) - return err + return c.JSON(http.StatusBadRequest, responses.GeneralServerError) } if resp != nil { - c.Logger().Errorf("User does not have enough balance user_id:%v amount:%v", userID, lnPayReq.PayReq.NumSatoshis) - return c.JSON(http.StatusInternalServerError, resp) + c.Logger().Errorf("Error: %v user_id:%v amount:%v", resp.Message, userID, lnPayReq.PayReq.NumSatoshis) + return c.JSON(resp.HttpStatusCode, resp) } invoice, errResp := controller.svc.AddOutgoingInvoice(c.Request().Context(), userID, paymentRequest, lnPayReq) if errResp != nil { diff --git a/lib/service/invoices.go b/lib/service/invoices.go index 873ae491..0b2ba4b1 100644 --- a/lib/service/invoices.go +++ b/lib/service/invoices.go @@ -16,7 +16,6 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lnd" "github.com/getsentry/sentry-go" - "github.com/labstack/gommon/log" "github.com/lightningnetwork/lnd/lnrpc" "github.com/uptrace/bun" "github.com/uptrace/bun/schema" @@ -442,13 +441,6 @@ func (svc *LndhubService) HandleSuccessfulPayment(ctx context.Context, invoice * } func (svc *LndhubService) AddOutgoingInvoice(ctx context.Context, userID int64, paymentRequest string, lnPayReq *lnd.LNPayReq) (*models.Invoice, *responses.ErrorResponse) { - if svc.Config.MaxSendAmount > 0 { - if lnPayReq.PayReq.NumSatoshis > svc.Config.MaxSendAmount { - svc.Logger.Errorf("Max send amount exceeded for user_id %v (amount:%v)", userID, lnPayReq.PayReq.NumSatoshis) - return nil, &responses.SendExceededError - } - } - // Initialize new DB invoice invoice := models.Invoice{ Type: common.InvoiceTypeOutgoing, @@ -487,44 +479,6 @@ func (svc *LndhubService) AddOutgoingInvoice(ctx context.Context, userID int64, } func (svc *LndhubService) AddIncomingInvoice(ctx context.Context, userID int64, amount int64, memo, descriptionHashStr string) (*models.Invoice, *responses.ErrorResponse) { - - if svc.Config.MaxReceiveAmount > 0 { - if amount > svc.Config.MaxReceiveAmount { - svc.Logger.Errorf("Max receive amount exceeded for user_id %d", userID) - return nil, &responses.ReceiveExceededError - } - } - - if svc.Config.MaxAccountBalance > 0 { - currentBalance, err := svc.CurrentUserBalance(ctx, userID) - if err != nil { - svc.Logger.Errorj( - log.JSON{ - "message": "error fetching balance", - "lndhub_user_id": userID, - "error": err, - }, - ) - return nil, &responses.GeneralServerError - } - if currentBalance+amount > svc.Config.MaxAccountBalance { - svc.Logger.Errorf("Max account balance exceeded for user_id %d", userID) - return nil, &responses.BalanceExceededError - } - } - - if svc.Config.MaxVolume > 0 { - volume, err := svc.GetVolumeOverPeriod(ctx, userID, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) - if err != nil { - return nil, &responses.GeneralServerError - } - if volume > svc.Config.MaxVolume { - svc.Logger.Errorf("Transaction volume exceeded for user_id %d", userID) - sentry.CaptureMessage(fmt.Sprintf("transaction volume exceeded for user %d", userID)) - return nil, &responses.TooMuchVolumeError - } - } - preimage, err := makePreimageHex() if err != nil { return nil, &responses.GeneralServerError diff --git a/lib/service/user.go b/lib/service/user.go index 26d6d4db..84f1c51f 100644 --- a/lib/service/user.go +++ b/lib/service/user.go @@ -13,6 +13,7 @@ import ( "github.com/getAlby/lndhub.go/lib/security" "github.com/getAlby/lndhub.go/lnd" "github.com/getsentry/sentry-go" + "github.com/labstack/gommon/log" "github.com/uptrace/bun" passwordvalidator "github.com/wagslane/go-password-validator" ) @@ -124,9 +125,23 @@ func (svc *LndhubService) FindUserByLogin(ctx context.Context, login string) (*m return &user, nil } -func (svc *LndhubService) CheckPaymentAllowed(ctx context.Context, lnpayReq *lnd.LNPayReq, userId int64) (result *responses.ErrorResponse, err error) { +func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpayReq *lnd.LNPayReq, userId int64) (result *responses.ErrorResponse, err error) { + if svc.Config.MaxSendAmount > 0 { + if lnpayReq.PayReq.NumSatoshis > svc.Config.MaxSendAmount { + svc.Logger.Errorf("Max send amount exceeded for user_id %v (amount:%v)", userId, lnpayReq.PayReq.NumSatoshis) + return &responses.SendExceededError, nil + } + } + currentBalance, err := svc.CurrentUserBalance(ctx, userId) if err != nil { + svc.Logger.Errorj( + log.JSON{ + "message": "error checking balance", + "error": err, + "lndhub_user_id": userId, + }, + ) return nil, err } @@ -137,13 +152,54 @@ func (svc *LndhubService) CheckPaymentAllowed(ctx context.Context, lnpayReq *lnd if currentBalance < minimumBalance { return &responses.NotEnoughBalanceError, nil } - //only check for volume if configured + + return svc.CheckVolumeAllowed(ctx, userId) +} + +func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amount, userId int64) (result *responses.ErrorResponse, err error) { + if svc.Config.MaxReceiveAmount > 0 { + if amount > svc.Config.MaxReceiveAmount { + svc.Logger.Errorf("Max receive amount exceeded for user_id %d", userId) + return &responses.ReceiveExceededError, nil + } + } + + if svc.Config.MaxAccountBalance > 0 { + currentBalance, err := svc.CurrentUserBalance(ctx, userId) + if err != nil { + svc.Logger.Errorj( + log.JSON{ + "message": "error fetching balance", + "lndhub_user_id": userId, + "error": err, + }, + ) + return nil, err + } + if currentBalance+amount > svc.Config.MaxAccountBalance { + svc.Logger.Errorf("Max account balance exceeded for user_id %d", userId) + return &responses.BalanceExceededError, nil + } + } + + return svc.CheckVolumeAllowed(ctx, userId) +} + +func (svc *LndhubService) CheckVolumeAllowed(ctx context.Context, userId int64) (result *responses.ErrorResponse, err error) { if svc.Config.MaxVolume > 0 { volume, err := svc.GetVolumeOverPeriod(ctx, userId, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) if err != nil { + svc.Logger.Errorj( + log.JSON{ + "message": "error fetching volume", + "error": err, + "lndhub_user_id": userId, + }, + ) return nil, err } if volume > svc.Config.MaxVolume { + svc.Logger.Errorf("Transaction volume exceeded for user_id %d", userId) sentry.CaptureMessage(fmt.Sprintf("transaction volume exceeded for user %d", userId)) return &responses.TooMuchVolumeError, nil } @@ -151,6 +207,7 @@ func (svc *LndhubService) CheckPaymentAllowed(ctx context.Context, lnpayReq *lnd return nil, nil } + func (svc *LndhubService) CalcFeeLimit(destination string, amount int64) int64 { if svc.LndClient.IsIdentityPubkey(destination) { return 0 From f54246681120fcc359591f4a8e5504a84527571c Mon Sep 17 00:00:00 2001 From: im-adithya Date: Tue, 5 Dec 2023 12:22:38 +0530 Subject: [PATCH 2/4] chore: split volume for send and receive --- README.md | 3 +- integration_tests/internal_payment_test.go | 106 ++++++++++----------- lib/service/config.go | 3 +- lib/service/user.go | 22 +++-- 4 files changed, 72 insertions(+), 62 deletions(-) diff --git a/README.md b/README.md index 0156251c..664c46a6 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,8 @@ vim .env # edit your config + `MAX_RECEIVE_AMOUNT`: (default: 0 = no limit) Set maximum amount (in satoshi) for which an invoice can be created + `MAX_SEND_AMOUNT`: (default: 0 = no limit) Set maximum amount (in satoshi) of an invoice that can be paid + `MAX_ACCOUNT_BALANCE`: (default: 0 = no limit) Set maximum balance (in satoshi) for each account -+ `MAX_VOLUME`: (default: 0 = no limit) Set maximum volume (in satoshi) for each account ++ `MAX_SEND_VOLUME`: (default: 0 = no limit) Set maximum volume (in satoshi) for sending for each account ++ `MAX_RECEIVE_VOLUME`: (default: 0 = no limit) Set maximum volume (in satoshi) for receiving for each account ### Macaroon diff --git a/integration_tests/internal_payment_test.go b/integration_tests/internal_payment_test.go index aa688286..d68833ad 100644 --- a/integration_tests/internal_payment_test.go +++ b/integration_tests/internal_payment_test.go @@ -136,8 +136,6 @@ func (suite *PaymentTestSuite) TestPaymentFeeReserve() { func (suite *PaymentTestSuite) TestIncomingExceededChecks() { //this will cause the payment to fail as the account was already funded //with 1000 sats - suite.service.Config.MaxVolume = 999 - suite.service.Config.MaxVolumePeriod = 2592000 aliceFundingSats := 1000 //fund alice account invoiceResponse := suite.createAddInvoiceReq(aliceFundingSats, "integration test internal payment alice", suite.aliceToken) @@ -146,52 +144,32 @@ func (suite *PaymentTestSuite) TestIncomingExceededChecks() { //wait a bit for the payment to be processed time.Sleep(10 * time.Millisecond) - - //try to make external payment - //which should fail - //create external invoice - externalSatRequested := 500 - externalInvoice := lnrpc.Invoice{ - Memo: "integration tests: external pay from user", - Value: int64(externalSatRequested), - } - invoice, err := suite.externalLND.AddInvoice(context.Background(), &externalInvoice) - assert.NoError(suite.T(), err) - //pay external invoice - rec := httptest.NewRecorder() var buf bytes.Buffer - assert.NoError(suite.T(), json.NewEncoder(&buf).Encode(&ExpectedPayInvoiceRequestBody{ - Invoice: invoice.PaymentRequest, + suite.service.Config.MaxReceiveAmount = 21 + rec := httptest.NewRecorder() + assert.NoError(suite.T(), json.NewEncoder(&buf).Encode(&ExpectedAddInvoiceRequestBody{ + Amount: aliceFundingSats, + Memo: "memo", })) - req := httptest.NewRequest(http.MethodPost, "/payinvoice", &buf) + req := httptest.NewRequest(http.MethodPost, "/addinvoice", &buf) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.aliceToken)) suite.echo.ServeHTTP(rec, req) - //should fail because max volume check + //should fail because max receive amount check assert.Equal(suite.T(), http.StatusBadRequest, rec.Code) resp := &responses.ErrorResponse{} err = json.NewDecoder(rec.Body).Decode(resp) assert.NoError(suite.T(), err) - assert.Equal(suite.T(), responses.TooMuchVolumeError.Message, resp.Message) + assert.Equal(suite.T(), responses.ReceiveExceededError.Message, resp.Message) - //change the period to be 1 second, sleep for 2 seconds, try to make another payment, this should work - suite.service.Config.MaxVolumePeriod = 1 - time.Sleep(2 * time.Second) - rec = httptest.NewRecorder() - externalInvoice = lnrpc.Invoice{ - Memo: "integration tests: external pay from user", - Value: int64(externalSatRequested), - } - invoice, err = suite.externalLND.AddInvoice(context.Background(), &externalInvoice) + // remove volume and receive config and check if it works + suite.service.Config.MaxReceiveAmount = 0 + invoiceResponse = suite.createAddInvoiceReq(aliceFundingSats, "integration test internal payment alice", suite.aliceToken) + err = suite.mlnd.mockPaidInvoice(invoiceResponse, 0, false, nil) assert.NoError(suite.T(), err) - assert.NoError(suite.T(), json.NewEncoder(&buf).Encode(&ExpectedPayInvoiceRequestBody{ - Invoice: invoice.PaymentRequest, - })) - suite.echo.ServeHTTP(rec, req) - assert.Equal(suite.T(), http.StatusOK, rec.Code) - suite.service.Config.MaxReceiveAmount = 21 - rec = httptest.NewRecorder() + // add max account + suite.service.Config.MaxAccountBalance = 500 assert.NoError(suite.T(), json.NewEncoder(&buf).Encode(&ExpectedAddInvoiceRequestBody{ Amount: aliceFundingSats, Memo: "memo", @@ -200,23 +178,22 @@ func (suite *PaymentTestSuite) TestIncomingExceededChecks() { req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.aliceToken)) suite.echo.ServeHTTP(rec, req) - //should fail because max receive amount check + //should fail because max balance check assert.Equal(suite.T(), http.StatusBadRequest, rec.Code) resp = &responses.ErrorResponse{} err = json.NewDecoder(rec.Body).Decode(resp) assert.NoError(suite.T(), err) - assert.Equal(suite.T(), responses.ReceiveExceededError.Message, resp.Message) + assert.Equal(suite.T(), responses.BalanceExceededError.Message, resp.Message) - // remove volume and receive config and check if it works - suite.service.Config.MaxVolume = 0 - suite.service.Config.MaxVolumePeriod = 0 - suite.service.Config.MaxReceiveAmount = 0 + //change the config back and add sats, it should work now + suite.service.Config.MaxAccountBalance = 0 invoiceResponse = suite.createAddInvoiceReq(aliceFundingSats, "integration test internal payment alice", suite.aliceToken) err = suite.mlnd.mockPaidInvoice(invoiceResponse, 0, false, nil) assert.NoError(suite.T(), err) - // add max account - suite.service.Config.MaxAccountBalance = 500 + // add max receive volume + suite.service.Config.MaxReceiveVolume = 1999 // because the volume till here is 1000+500+500 + suite.service.Config.MaxVolumePeriod = 2592000 assert.NoError(suite.T(), json.NewEncoder(&buf).Encode(&ExpectedAddInvoiceRequestBody{ Amount: aliceFundingSats, Memo: "memo", @@ -225,15 +202,16 @@ func (suite *PaymentTestSuite) TestIncomingExceededChecks() { req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.aliceToken)) suite.echo.ServeHTTP(rec, req) - //should fail because max balance check + //should fail because max volume check assert.Equal(suite.T(), http.StatusBadRequest, rec.Code) resp = &responses.ErrorResponse{} err = json.NewDecoder(rec.Body).Decode(resp) assert.NoError(suite.T(), err) - assert.Equal(suite.T(), responses.BalanceExceededError.Message, resp.Message) + assert.Equal(suite.T(), responses.TooMuchVolumeError.Message, resp.Message) - //change the config back and add sats, it should work now - suite.service.Config.MaxAccountBalance = 0 + //change the config back, it should work now + suite.service.Config.MaxReceiveVolume = 0 + suite.service.Config.MaxVolumePeriod = 0 invoiceResponse = suite.createAddInvoiceReq(aliceFundingSats, "integration test internal payment alice", suite.aliceToken) err = suite.mlnd.mockPaidInvoice(invoiceResponse, 0, false, nil) assert.NoError(suite.T(), err) @@ -272,7 +250,8 @@ func (suite *PaymentTestSuite) TestOutgoingExceededChecks() { req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.aliceToken)) suite.echo.ServeHTTP(rec, req) - //should fail because max volume check + + //should fail because max send check assert.Equal(suite.T(), http.StatusBadRequest, rec.Code) resp := &responses.ErrorResponse{} err = json.NewDecoder(rec.Body).Decode(resp) @@ -280,11 +259,8 @@ func (suite *PaymentTestSuite) TestOutgoingExceededChecks() { assert.Equal(suite.T(), responses.SendExceededError.Message, resp.Message) suite.service.Config.MaxSendAmount = 2000 + //should work now rec = httptest.NewRecorder() - externalInvoice = lnrpc.Invoice{ - Memo: "integration tests: external pay from user", - Value: int64(externalSatRequested), - } invoice, err = suite.externalLND.AddInvoice(context.Background(), &externalInvoice) assert.NoError(suite.T(), err) assert.NoError(suite.T(), json.NewEncoder(&buf).Encode(&ExpectedPayInvoiceRequestBody{ @@ -293,8 +269,32 @@ func (suite *PaymentTestSuite) TestOutgoingExceededChecks() { suite.echo.ServeHTTP(rec, req) assert.Equal(suite.T(), http.StatusOK, rec.Code) + suite.service.Config.MaxSendVolume = 100 + suite.service.Config.MaxVolumePeriod = 2592000 + //volume + invoice, err = suite.externalLND.AddInvoice(context.Background(), &externalInvoice) + assert.NoError(suite.T(), err) + //pay external invoice + rec = httptest.NewRecorder() + assert.NoError(suite.T(), json.NewEncoder(&buf).Encode(&ExpectedPayInvoiceRequestBody{ + Invoice: invoice.PaymentRequest, + })) + req = httptest.NewRequest(http.MethodPost, "/payinvoice", &buf) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.aliceToken)) + suite.echo.ServeHTTP(rec, req) + + //should fail because maximum volume check + assert.Equal(suite.T(), http.StatusBadRequest, rec.Code) + resp = &responses.ErrorResponse{} + err = json.NewDecoder(rec.Body).Decode(resp) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), responses.TooMuchVolumeError.Message, resp.Message) + //change the config back suite.service.Config.MaxSendAmount = 0 + suite.service.Config.MaxSendVolume = 0 + suite.service.Config.MaxVolumePeriod = 0 } func (suite *PaymentTestSuite) TestInternalPayment() { diff --git a/lib/service/config.go b/lib/service/config.go index fa539711..366e1f6f 100644 --- a/lib/service/config.go +++ b/lib/service/config.go @@ -37,7 +37,8 @@ type Config struct { MaxSendAmount int64 `envconfig:"MAX_SEND_AMOUNT" default:"0"` MaxAccountBalance int64 `envconfig:"MAX_ACCOUNT_BALANCE" default:"0"` MaxFeeAmount int64 `envconfig:"MAX_FEE_AMOUNT" default:"5000"` - MaxVolume int64 `envconfig:"MAX_VOLUME" default:"0"` //0 means the volume check is disabled by default + MaxSendVolume int64 `envconfig:"MAX_SEND_VOLUME" default:"0"` //0 means the volume check is disabled by default + MaxReceiveVolume int64 `envconfig:"MAX_RECEIVE_VOLUME" default:"0"` //0 means the volume check is disabled by default MaxVolumePeriod int64 `envconfig:"MAX_VOLUME_PERIOD" default:"2592000"` //in seconds, default 1 month RabbitMQUri string `envconfig:"RABBITMQ_URI"` RabbitMQLndhubInvoiceExchange string `envconfig:"RABBITMQ_INVOICE_EXCHANGE" default:"lndhub_invoice"` diff --git a/lib/service/user.go b/lib/service/user.go index 84f1c51f..1ed89d14 100644 --- a/lib/service/user.go +++ b/lib/service/user.go @@ -153,7 +153,7 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay return &responses.NotEnoughBalanceError, nil } - return svc.CheckVolumeAllowed(ctx, userId) + return svc.CheckVolumeAllowed(ctx, userId, common.InvoiceTypeOutgoing) } func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amount, userId int64) (result *responses.ErrorResponse, err error) { @@ -182,12 +182,18 @@ func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amoun } } - return svc.CheckVolumeAllowed(ctx, userId) + return svc.CheckVolumeAllowed(ctx, userId, common.InvoiceTypeIncoming) } -func (svc *LndhubService) CheckVolumeAllowed(ctx context.Context, userId int64) (result *responses.ErrorResponse, err error) { - if svc.Config.MaxVolume > 0 { - volume, err := svc.GetVolumeOverPeriod(ctx, userId, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) +func (svc *LndhubService) CheckVolumeAllowed(ctx context.Context, userId int64, invoiceType string) (result *responses.ErrorResponse, err error) { + var maxVolume int64 + if invoiceType == common.InvoiceTypeIncoming { + maxVolume = svc.Config.MaxReceiveVolume + } else { + maxVolume = svc.Config.MaxSendVolume + } + if maxVolume > 0 { + volume, err := svc.GetVolumeOverPeriod(ctx, userId, invoiceType, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) if err != nil { svc.Logger.Errorj( log.JSON{ @@ -198,7 +204,7 @@ func (svc *LndhubService) CheckVolumeAllowed(ctx context.Context, userId int64) ) return nil, err } - if volume > svc.Config.MaxVolume { + if volume > maxVolume { svc.Logger.Errorf("Transaction volume exceeded for user_id %d", userId) sentry.CaptureMessage(fmt.Sprintf("transaction volume exceeded for user %d", userId)) return &responses.TooMuchVolumeError, nil @@ -260,15 +266,17 @@ func (svc *LndhubService) InvoicesFor(ctx context.Context, userId int64, invoice return invoices, nil } -func (svc *LndhubService) GetVolumeOverPeriod(ctx context.Context, userId int64, period time.Duration) (result int64, err error) { +func (svc *LndhubService) GetVolumeOverPeriod(ctx context.Context, userId int64, invoiceType string, period time.Duration) (result int64, err error) { err = svc.DB.NewSelect().Table("invoices"). ColumnExpr("sum(invoices.amount) as result"). Where("invoices.user_id = ?", userId). + Where("invoices.type = ?", invoiceType). Where("invoices.settled_at >= ?", time.Now().Add(-1*period)). Scan(ctx, &result) if err != nil { return 0, err } + fmt.Println(result, "volume ") return result, nil } From 62d7ffb7fe2631ad4f53b92c7b7b121b6626f30f Mon Sep 17 00:00:00 2001 From: im-adithya Date: Tue, 5 Dec 2023 12:29:49 +0530 Subject: [PATCH 3/4] fix: outgoing exceeding tests --- integration_tests/internal_payment_test.go | 2 +- lib/service/user.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/integration_tests/internal_payment_test.go b/integration_tests/internal_payment_test.go index d68833ad..dcc9c7f0 100644 --- a/integration_tests/internal_payment_test.go +++ b/integration_tests/internal_payment_test.go @@ -233,7 +233,7 @@ func (suite *PaymentTestSuite) TestOutgoingExceededChecks() { //try to make external payment //which should fail //create external invoice - externalSatRequested := 500 + externalSatRequested := 400 externalInvoice := lnrpc.Invoice{ Memo: "integration tests: external pay from user", Value: int64(externalSatRequested), diff --git a/lib/service/user.go b/lib/service/user.go index 1ed89d14..107408c2 100644 --- a/lib/service/user.go +++ b/lib/service/user.go @@ -277,6 +277,5 @@ func (svc *LndhubService) GetVolumeOverPeriod(ctx context.Context, userId int64, if err != nil { return 0, err } - fmt.Println(result, "volume ") return result, nil } From bd92e6ff8571d94dd88509a20776366d3e77c123 Mon Sep 17 00:00:00 2001 From: im-adithya Date: Thu, 7 Dec 2023 14:15:47 +0530 Subject: [PATCH 4/4] chore: remove check volume allowed function --- lib/service/user.go | 63 ++++++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/lib/service/user.go b/lib/service/user.go index 107408c2..002bef1d 100644 --- a/lib/service/user.go +++ b/lib/service/user.go @@ -133,6 +133,25 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay } } + if svc.Config.MaxSendVolume > 0 { + volume, err := svc.GetVolumeOverPeriod(ctx, userId, common.InvoiceTypeOutgoing, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) + if err != nil { + svc.Logger.Errorj( + log.JSON{ + "message": "error fetching volume", + "error": err, + "lndhub_user_id": userId, + }, + ) + return nil, err + } + if volume > svc.Config.MaxSendVolume { + svc.Logger.Errorf("Transaction volume exceeded for user_id %d", userId) + sentry.CaptureMessage(fmt.Sprintf("transaction volume exceeded for user %d", userId)) + return &responses.TooMuchVolumeError, nil + } + } + currentBalance, err := svc.CurrentUserBalance(ctx, userId) if err != nil { svc.Logger.Errorj( @@ -153,7 +172,7 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay return &responses.NotEnoughBalanceError, nil } - return svc.CheckVolumeAllowed(ctx, userId, common.InvoiceTypeOutgoing) + return nil, nil } func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amount, userId int64) (result *responses.ErrorResponse, err error) { @@ -164,56 +183,46 @@ func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amoun } } - if svc.Config.MaxAccountBalance > 0 { - currentBalance, err := svc.CurrentUserBalance(ctx, userId) + if svc.Config.MaxReceiveVolume > 0 { + volume, err := svc.GetVolumeOverPeriod(ctx, userId, common.InvoiceTypeIncoming, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) if err != nil { svc.Logger.Errorj( log.JSON{ - "message": "error fetching balance", - "lndhub_user_id": userId, + "message": "error fetching volume", "error": err, + "lndhub_user_id": userId, }, ) return nil, err } - if currentBalance+amount > svc.Config.MaxAccountBalance { - svc.Logger.Errorf("Max account balance exceeded for user_id %d", userId) - return &responses.BalanceExceededError, nil + if volume > svc.Config.MaxReceiveVolume { + svc.Logger.Errorf("Transaction volume exceeded for user_id %d", userId) + sentry.CaptureMessage(fmt.Sprintf("transaction volume exceeded for user %d", userId)) + return &responses.TooMuchVolumeError, nil } } - return svc.CheckVolumeAllowed(ctx, userId, common.InvoiceTypeIncoming) -} - -func (svc *LndhubService) CheckVolumeAllowed(ctx context.Context, userId int64, invoiceType string) (result *responses.ErrorResponse, err error) { - var maxVolume int64 - if invoiceType == common.InvoiceTypeIncoming { - maxVolume = svc.Config.MaxReceiveVolume - } else { - maxVolume = svc.Config.MaxSendVolume - } - if maxVolume > 0 { - volume, err := svc.GetVolumeOverPeriod(ctx, userId, invoiceType, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) + if svc.Config.MaxAccountBalance > 0 { + currentBalance, err := svc.CurrentUserBalance(ctx, userId) if err != nil { svc.Logger.Errorj( log.JSON{ - "message": "error fetching volume", - "error": err, + "message": "error fetching balance", "lndhub_user_id": userId, + "error": err, }, ) return nil, err } - if volume > maxVolume { - svc.Logger.Errorf("Transaction volume exceeded for user_id %d", userId) - sentry.CaptureMessage(fmt.Sprintf("transaction volume exceeded for user %d", userId)) - return &responses.TooMuchVolumeError, nil + if currentBalance+amount > svc.Config.MaxAccountBalance { + svc.Logger.Errorf("Max account balance exceeded for user_id %d", userId) + return &responses.BalanceExceededError, nil } } + return nil, nil } - func (svc *LndhubService) CalcFeeLimit(destination string, amount int64) int64 { if svc.LndClient.IsIdentityPubkey(destination) { return 0