Skip to content

Commit

Permalink
Merge pull request #465 from getAlby/task-limits-jwt
Browse files Browse the repository at this point in the history
feat: fetch and use the limits from the JWT Token
  • Loading branch information
kiwiidb authored Dec 8, 2023
2 parents dbc5993 + 478bc3b commit 110d09a
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 33 deletions.
2 changes: 1 addition & 1 deletion controllers/addinvoice.ctrl.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func AddInvoice(c echo.Context, svc *service.LndhubService, userID int64) error
return c.JSON(http.StatusBadRequest, responses.BadArgumentsError)
}

resp, err := svc.CheckIncomingPaymentAllowed(c.Request().Context(), amount, userID)
resp, err := svc.CheckIncomingPaymentAllowed(c, amount, userID)
if err != nil {
return c.JSON(http.StatusInternalServerError, responses.GeneralServerError)
}
Expand Down
2 changes: 1 addition & 1 deletion controllers/keysend.ctrl.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error {
})
}

resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID)
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c, lnPayReq, userID)
if err != nil {
return c.JSON(http.StatusInternalServerError, responses.GeneralServerError)
}
Expand Down
2 changes: 1 addition & 1 deletion controllers/payinvoice.ctrl.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error {
lnPayReq.PayReq.NumSatoshis = amt
}

resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID)
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c, lnPayReq, userID)
if err != nil {
return c.JSON(http.StatusInternalServerError, responses.GeneralServerError)
}
Expand Down
2 changes: 1 addition & 1 deletion controllers_v2/invoice.ctrl.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (controller *InvoiceController) AddInvoice(c echo.Context) error {
return c.JSON(http.StatusBadRequest, responses.BadArgumentsError)
}

resp, err := controller.svc.CheckIncomingPaymentAllowed(c.Request().Context(), body.Amount, userID)
resp, err := controller.svc.CheckIncomingPaymentAllowed(c, body.Amount, userID)
if err != nil {
return c.JSON(http.StatusInternalServerError, responses.GeneralServerError)
}
Expand Down
8 changes: 4 additions & 4 deletions controllers_v2/keysend.ctrl.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error {
c.Logger().Errorf("Invalid keysend request body: %v", err)
return c.JSON(http.StatusBadRequest, responses.BadArgumentsError)
}
errResp := controller.checkKeysendPaymentAllowed(context.Background(), reqBody.Amount, userID)
errResp := controller.checkKeysendPaymentAllowed(c, reqBody.Amount, userID)
if errResp != nil {
c.Logger().Errorf("Failed to send keysend: %s", errResp.Message)
return c.JSON(errResp.HttpStatusCode, errResp)
Expand Down Expand Up @@ -127,7 +127,7 @@ func (controller *KeySendController) MultiKeySend(c echo.Context) error {
for _, keysend := range reqBody.Keysends {
totalAmount += keysend.Amount
}
errResp := controller.checkKeysendPaymentAllowed(context.Background(), totalAmount, userID)
errResp := controller.checkKeysendPaymentAllowed(c, totalAmount, userID)
if errResp != nil {
c.Logger().Errorf("Failed to make keysend split payments: %s", errResp.Message)
return c.JSON(errResp.HttpStatusCode, errResp)
Expand Down Expand Up @@ -162,14 +162,14 @@ func (controller *KeySendController) MultiKeySend(c echo.Context) error {
return c.JSON(status, result)
}

func (controller *KeySendController) checkKeysendPaymentAllowed(ctx context.Context, amount, userID int64) (resp *responses.ErrorResponse) {
func (controller *KeySendController) checkKeysendPaymentAllowed(c echo.Context, amount, userID int64) (resp *responses.ErrorResponse) {
syntheticPayReq := &lnd.LNPayReq{
PayReq: &lnrpc.PayReq{
NumSatoshis: amount,
},
Keysend: true,
}
resp, err := controller.svc.CheckOutgoingPaymentAllowed(ctx, syntheticPayReq, userID)
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c, syntheticPayReq, userID)
if err != nil {
return &responses.GeneralServerError
}
Expand Down
2 changes: 1 addition & 1 deletion controllers_v2/payinvoice.ctrl.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error {
}
lnPayReq.PayReq.NumSatoshis = amt
}
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID)
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c, lnPayReq, userID)
if err != nil {
return c.JSON(http.StatusBadRequest, responses.GeneralServerError)
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ require (
github.com/lightningnetwork/lnd v0.16.4-beta.rc1
github.com/rabbitmq/amqp091-go v1.8.1
github.com/rs/zerolog v1.29.1
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.8.4
github.com/uptrace/bun v1.1.14
github.com/uptrace/bun/dialect/pgdialect v1.1.14
Expand Down Expand Up @@ -137,7 +138,6 @@ require (
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rogpeppe/fastuuid v1.2.0 // indirect
github.com/secure-systems-lab/go-securesystemslib v0.6.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.0 // indirect
Expand Down
2 changes: 1 addition & 1 deletion integration_tests/rabbitmq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (suite *RabbitMQTestSuite) SetupSuite() {
e.Validator = &lib.CustomValidator{Validator: validator.New()}

suite.echo = e
suite.echo.Use(tokens.Middleware(suite.svc.Config.JWTSecret))
suite.echo.Use(tokens.Middleware([]byte(suite.svc.Config.JWTSecret)))
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.svc).AddInvoice)
suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.svc).PayInvoice)
go func() {
Expand Down
12 changes: 9 additions & 3 deletions lib/service/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ type Config struct {
MaxSendAmount int64 `envconfig:"MAX_SEND_AMOUNT" default:"0"`
MaxAccountBalance int64 `envconfig:"MAX_ACCOUNT_BALANCE" default:"0"`
MaxFeeAmount int64 `envconfig:"MAX_FEE_AMOUNT" default:"5000"`
MaxSendVolume int64 `envconfig:"MAX_SEND_VOLUME" default:"0"` //0 means the volume check is disabled by default
MaxReceiveVolume int64 `envconfig:"MAX_RECEIVE_VOLUME" default:"0"` //0 means the volume check is disabled by default
MaxSendVolume int64 `envconfig:"MAX_SEND_VOLUME" default:"0"` //0 means the volume check is disabled by default
MaxReceiveVolume int64 `envconfig:"MAX_RECEIVE_VOLUME" default:"0"` //0 means the volume check is disabled by default
MaxVolumePeriod int64 `envconfig:"MAX_VOLUME_PERIOD" default:"2592000"` //in seconds, default 1 month
RabbitMQUri string `envconfig:"RABBITMQ_URI"`
RabbitMQLndhubInvoiceExchange string `envconfig:"RABBITMQ_INVOICE_EXCHANGE" default:"lndhub_invoice"`
Expand All @@ -48,7 +48,13 @@ type Config struct {
RabbitMQPaymentConsumerQueueName string `envconfig:"RABBITMQ_PAYMENT_CONSUMER_QUEUE_NAME" default:"lnd_payment_consumer"`
Branding BrandingConfig
}

type Limits struct {
MaxSendVolume int64
MaxSendAmount int64
MaxReceiveVolume int64
MaxReceiveAmount int64
MaxAccountBalance int64
}
type BrandingConfig struct {
Title string `envconfig:"BRANDING_TITLE" default:"LndHub.go - Alby Lightning"`
Desc string `envconfig:"BRANDING_DESC" default:"Alby server for the Lightning Network"`
Expand Down
64 changes: 47 additions & 17 deletions lib/service/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/getAlby/lndhub.go/lib/security"
"github.com/getAlby/lndhub.go/lnd"
"github.com/getsentry/sentry-go"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
"github.com/uptrace/bun"
passwordvalidator "github.com/wagslane/go-password-validator"
Expand Down Expand Up @@ -125,16 +126,17 @@ func (svc *LndhubService) FindUserByLogin(ctx context.Context, login string) (*m
return &user, nil
}

func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpayReq *lnd.LNPayReq, userId int64) (result *responses.ErrorResponse, err error) {
if svc.Config.MaxSendAmount > 0 {
if lnpayReq.PayReq.NumSatoshis > svc.Config.MaxSendAmount {
func (svc *LndhubService) CheckOutgoingPaymentAllowed(c echo.Context, lnpayReq *lnd.LNPayReq, userId int64) (result *responses.ErrorResponse, err error) {
limits := svc.GetLimits(c)
if limits.MaxSendAmount > 0 {
if lnpayReq.PayReq.NumSatoshis > limits.MaxSendAmount {
svc.Logger.Errorf("Max send amount exceeded for user_id %v (amount:%v)", userId, lnpayReq.PayReq.NumSatoshis)
return &responses.SendExceededError, nil
}
}

if svc.Config.MaxSendVolume > 0 {
volume, err := svc.GetVolumeOverPeriod(ctx, userId, common.InvoiceTypeOutgoing, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second)))
if limits.MaxSendVolume > 0 {
volume, err := svc.GetVolumeOverPeriod(c.Request().Context(), userId, common.InvoiceTypeOutgoing, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second)))
if err != nil {
svc.Logger.Errorj(
log.JSON{
Expand All @@ -145,14 +147,14 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay
)
return nil, err
}
if volume > svc.Config.MaxSendVolume {
if volume > limits.MaxSendVolume {
svc.Logger.Errorf("Transaction volume exceeded for user_id %d", userId)
sentry.CaptureMessage(fmt.Sprintf("transaction volume exceeded for user %d", userId))
return &responses.TooMuchVolumeError, nil
}
}

currentBalance, err := svc.CurrentUserBalance(ctx, userId)
currentBalance, err := svc.CurrentUserBalance(c.Request().Context(), userId)
if err != nil {
svc.Logger.Errorj(
log.JSON{
Expand All @@ -175,16 +177,17 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay
return nil, nil
}

func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amount, userId int64) (result *responses.ErrorResponse, err error) {
if svc.Config.MaxReceiveAmount > 0 {
if amount > svc.Config.MaxReceiveAmount {
func (svc *LndhubService) CheckIncomingPaymentAllowed(c echo.Context, amount, userId int64) (result *responses.ErrorResponse, err error) {
limits := svc.GetLimits(c)
if limits.MaxReceiveAmount > 0 {
if amount > limits.MaxReceiveAmount {
svc.Logger.Errorf("Max receive amount exceeded for user_id %d", userId)
return &responses.ReceiveExceededError, nil
return &responses.ReceiveExceededError, nil
}
}

if svc.Config.MaxReceiveVolume > 0 {
volume, err := svc.GetVolumeOverPeriod(ctx, userId, common.InvoiceTypeIncoming, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second)))
if limits.MaxReceiveVolume > 0 {
volume, err := svc.GetVolumeOverPeriod(c.Request().Context(), userId, common.InvoiceTypeIncoming, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second)))
if err != nil {
svc.Logger.Errorj(
log.JSON{
Expand All @@ -195,15 +198,15 @@ func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amoun
)
return nil, err
}
if volume > svc.Config.MaxReceiveVolume {
if volume > limits.MaxReceiveVolume {
svc.Logger.Errorf("Transaction volume exceeded for user_id %d", userId)
sentry.CaptureMessage(fmt.Sprintf("transaction volume exceeded for user %d", userId))
return &responses.TooMuchVolumeError, nil
}
}

if svc.Config.MaxAccountBalance > 0 {
currentBalance, err := svc.CurrentUserBalance(ctx, userId)
if limits.MaxAccountBalance > 0 {
currentBalance, err := svc.CurrentUserBalance(c.Request().Context(), userId)
if err != nil {
svc.Logger.Errorj(
log.JSON{
Expand All @@ -214,7 +217,7 @@ func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amoun
)
return nil, err
}
if currentBalance+amount > svc.Config.MaxAccountBalance {
if currentBalance+amount > limits.MaxAccountBalance {
svc.Logger.Errorf("Max account balance exceeded for user_id %d", userId)
return &responses.BalanceExceededError, nil
}
Expand Down Expand Up @@ -288,3 +291,30 @@ func (svc *LndhubService) GetVolumeOverPeriod(ctx context.Context, userId int64,
}
return result, nil
}

func (svc *LndhubService) GetLimits(c echo.Context) (limits *Limits) {
limits = &Limits{
MaxSendVolume: svc.Config.MaxSendVolume,
MaxSendAmount: svc.Config.MaxSendAmount,
MaxReceiveVolume: svc.Config.MaxReceiveVolume,
MaxReceiveAmount: svc.Config.MaxReceiveAmount,
MaxAccountBalance: svc.Config.MaxAccountBalance,
}
if val, ok := c.Get("MaxSendVolume").(int64); ok && val > 0 {
limits.MaxSendVolume = val
}
if val, ok := c.Get("MaxSendAmount").(int64); ok && val > 0 {
limits.MaxSendAmount = val
}
if val, ok := c.Get("MaxReceiveVolume").(int64); ok && val > 0 {
limits.MaxReceiveVolume = val
}
if val, ok := c.Get("MaxReceiveAmount").(int64); ok && val > 0 {
limits.MaxReceiveAmount = val
}
if val, ok := c.Get("MaxAccountBalance").(int64); ok && val > 0 {
limits.MaxAccountBalance = val
}

return limits
}
14 changes: 12 additions & 2 deletions lib/tokens/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@ import (
)

type jwtCustomClaims struct {
ID int64 `json:"id"`
IsRefresh bool `json:"isRefresh"`
ID int64 `json:"id"`
IsRefresh bool `json:"isRefresh"`
MaxSendVolume int64 `json:"maxSendVolume"`
MaxSendAmount int64 `json:"maxSendAmount"`
MaxReceiveVolume int64 `json:"maxReceiveVolume"`
MaxReceiveAmount int64 `json:"maxReceiveAmount"`
MaxAccountBalance int64 `json:"maxAccountBalance"`
jwt.StandardClaims
}

Expand All @@ -38,6 +43,11 @@ func Middleware(secret []byte) echo.MiddlewareFunc {
token := c.Get("UserJwt").(*jwt.Token)
claims := token.Claims.(*jwtCustomClaims)
c.Set("UserID", claims.ID)
c.Set("MaxSendVolume", claims.MaxSendVolume)
c.Set("MaxSendAmount", claims.MaxSendAmount)
c.Set("MaxReceiveVolume", claims.MaxReceiveVolume)
c.Set("MaxReceiveAmount", claims.MaxReceiveAmount)
c.Set("MaxAccountBalance", claims.MaxAccountBalance)
// pass UserID to sentry for exception notifications
if hub := sentryecho.GetHubFromContext(c); hub != nil {
hub.Scope().SetUser(sentry.User{ID: strconv.FormatInt(claims.ID, 10)})
Expand Down

0 comments on commit 110d09a

Please sign in to comment.