Skip to content

Commit

Permalink
Enforce simulation tax (#517)
Browse files Browse the repository at this point in the history
  • Loading branch information
StrathCole authored Sep 17, 2024
1 parent b06d49f commit a936511
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 16 deletions.
23 changes: 19 additions & 4 deletions custom/auth/ante/fee.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (fd FeeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, nex

msgs := feeTx.GetMsgs()
// Compute taxes
taxes := FilterMsgAndComputeTax(ctx, fd.treasuryKeeper, msgs...)
taxes := FilterMsgAndComputeTax(ctx, fd.treasuryKeeper, simulate, msgs...)

if !simulate {
priority, err = fd.checkTxFee(ctx, tx, taxes)
Expand Down Expand Up @@ -101,9 +101,24 @@ func (fd FeeDecorator) checkDeductFee(ctx sdk.Context, feeTx sdk.FeeTx, taxes sd

feesOrTax := fee

// deduct the fees
if fee.IsZero() && simulate {
feesOrTax = taxes
if simulate {
if fee.IsZero() {
feesOrTax = taxes
}

// even if fee is not zero it might be it is lower than the increased tax from computeTax
// so we need to check if the tax is higher than the fee to not run into deduction errors
for _, tax := range taxes {
feeAmount := feesOrTax.AmountOf(tax.Denom)
// if the fee amount is zero, add the tax amount to feesOrTax
if feeAmount.IsZero() {
feesOrTax = feesOrTax.Add(tax)
} else if feeAmount.LT(tax.Amount) {
// Update feesOrTax if the tax amount is higher
missingAmount := tax.Amount.Sub(feeAmount)
feesOrTax = feesOrTax.Add(sdk.NewCoin(tax.Denom, missingAmount))
}
}
}

if !feesOrTax.IsZero() {
Expand Down
24 changes: 15 additions & 9 deletions custom/auth/ante/fee_tax.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ func isIBCDenom(denom string) bool {
}

// FilterMsgAndComputeTax computes the stability tax on messages.
func FilterMsgAndComputeTax(ctx sdk.Context, tk TreasuryKeeper, msgs ...sdk.Msg) sdk.Coins {
func FilterMsgAndComputeTax(ctx sdk.Context, tk TreasuryKeeper, simulate bool, msgs ...sdk.Msg) sdk.Coins {
taxes := sdk.Coins{}

for _, msg := range msgs {
switch msg := msg.(type) {
case *banktypes.MsgSend:
if !tk.HasBurnTaxExemptionAddress(ctx, msg.FromAddress, msg.ToAddress) {
taxes = taxes.Add(computeTax(ctx, tk, msg.Amount)...)
taxes = taxes.Add(computeTax(ctx, tk, msg.Amount, simulate)...)
}

case *banktypes.MsgMultiSend:
Expand All @@ -47,28 +47,28 @@ func FilterMsgAndComputeTax(ctx sdk.Context, tk TreasuryKeeper, msgs ...sdk.Msg)

if tainted != len(msg.Inputs)+len(msg.Outputs) {
for _, input := range msg.Inputs {
taxes = taxes.Add(computeTax(ctx, tk, input.Coins)...)
taxes = taxes.Add(computeTax(ctx, tk, input.Coins, simulate)...)
}
}

case *marketexported.MsgSwapSend:
taxes = taxes.Add(computeTax(ctx, tk, sdk.NewCoins(msg.OfferCoin))...)
taxes = taxes.Add(computeTax(ctx, tk, sdk.NewCoins(msg.OfferCoin), simulate)...)

case *wasmtypes.MsgInstantiateContract:
taxes = taxes.Add(computeTax(ctx, tk, msg.Funds)...)
taxes = taxes.Add(computeTax(ctx, tk, msg.Funds, simulate)...)

case *wasmtypes.MsgInstantiateContract2:
taxes = taxes.Add(computeTax(ctx, tk, msg.Funds)...)
taxes = taxes.Add(computeTax(ctx, tk, msg.Funds, simulate)...)

case *wasmtypes.MsgExecuteContract:
if !tk.HasBurnTaxExemptionContract(ctx, msg.Contract) {
taxes = taxes.Add(computeTax(ctx, tk, msg.Funds)...)
taxes = taxes.Add(computeTax(ctx, tk, msg.Funds, simulate)...)
}

case *authz.MsgExec:
messages, err := msg.GetMessages()
if err == nil {
taxes = taxes.Add(FilterMsgAndComputeTax(ctx, tk, messages...)...)
taxes = taxes.Add(FilterMsgAndComputeTax(ctx, tk, simulate, messages...)...)
}
}
}
Expand All @@ -77,7 +77,7 @@ func FilterMsgAndComputeTax(ctx sdk.Context, tk TreasuryKeeper, msgs ...sdk.Msg)
}

// computes the stability tax according to tax-rate and tax-cap
func computeTax(ctx sdk.Context, tk TreasuryKeeper, principal sdk.Coins) sdk.Coins {
func computeTax(ctx sdk.Context, tk TreasuryKeeper, principal sdk.Coins, simulate bool) sdk.Coins {
taxRate := tk.GetTaxRate(ctx)
if taxRate.Equal(sdk.ZeroDec()) {
return sdk.Coins{}
Expand All @@ -95,6 +95,12 @@ func computeTax(ctx sdk.Context, tk TreasuryKeeper, principal sdk.Coins) sdk.Coi
}

taxDue := sdk.NewDecFromInt(coin.Amount).Mul(taxRate).TruncateInt()
// we need to check all taxes if they are GTE 100 because otherwise we will not be able to
// simulate the split processes (i.e. BurnTaxSplit and OracleSplit)
// if they are less than 100, we will set them to 100
if simulate && taxDue.LT(sdk.NewInt(100)) {
taxDue = sdk.NewInt(100)
}

// If tax due is greater than the tax cap, cap!
taxCap := tk.GetTaxCap(ctx, coin.Denom)
Expand Down
2 changes: 1 addition & 1 deletion custom/auth/ante/fee_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,7 @@ func (s *AnteTestSuite) runBurnSplitTaxTest(burnSplitRate sdk.Dec, oracleSplitRa

feeCollectorAfter := bk.GetAllBalances(s.ctx, ak.GetModuleAddress(authtypes.FeeCollectorName))
oracleAfter := bk.GetAllBalances(s.ctx, ak.GetModuleAddress(oracletypes.ModuleName))
taxes := ante.FilterMsgAndComputeTax(s.ctx, tk, msg)
taxes := ante.FilterMsgAndComputeTax(s.ctx, tk, false, msg)
communityPoolAfter, _ := dk.GetFeePoolCommunityCoins(s.ctx).TruncateDecimal()
if communityPoolAfter.IsZero() {
communityPoolAfter = sdk.NewCoins(sdk.NewCoin(core.MicroSDRDenom, sdk.ZeroInt()))
Expand Down
2 changes: 1 addition & 1 deletion custom/auth/tx/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (ts txServer) ComputeTax(c context.Context, req *ComputeTaxRequest) (*Compu
return nil, status.Errorf(codes.InvalidArgument, "empty txBytes is not allowed")
}

taxAmount := customante.FilterMsgAndComputeTax(ctx, ts.treasuryKeeper, msgs...)
taxAmount := customante.FilterMsgAndComputeTax(ctx, ts.treasuryKeeper, false, msgs...)
return &ComputeTaxResponse{
TaxAmount: taxAmount,
}, nil
Expand Down
4 changes: 3 additions & 1 deletion custom/wasm/keeper/handler_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ func (h SDKMessageHandler) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddr

for _, sdkMsg := range sdkMsgs {
// Charge tax on result msg
taxes := ante.FilterMsgAndComputeTax(ctx, h.treasuryKeeper, sdkMsg)
// we set simulate to false here as it is not available and we don't need to
// increase the tax amount for simulation inside of wasm
taxes := ante.FilterMsgAndComputeTax(ctx, h.treasuryKeeper, false, sdkMsg)
if !taxes.IsZero() {
eventManager := sdk.NewEventManager()
contractAcc := h.accountKeeper.GetAccount(ctx, contractAddr)
Expand Down

0 comments on commit a936511

Please sign in to comment.