Skip to content

Commit

Permalink
pack local variables into structs; remove need for IR
Browse files Browse the repository at this point in the history
  • Loading branch information
saucepoint committed Jan 11, 2024
1 parent 2ca6a04 commit 3691e3d
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 69 deletions.
101 changes: 60 additions & 41 deletions contracts/lens/Quoter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -147,43 +147,61 @@ contract Quoter is IQuoter, ILockCallback {
abi.decode(reason, (int128[], uint160[], uint32[]));
}

struct QuoteResult {
int128[] deltaAmounts;
uint160[] sqrtPriceX96AfterList;
uint32[] initializedTicksLoadedList;
}

struct QuoteCache {
Currency prevCurrency;
uint128 prevAmount;
int128 deltaIn;
int128 deltaOut;
BalanceDelta curDeltas;
uint160 sqrtPriceX96After;
int24 tickBefore;
int24 tickAfter;
}

/// @dev quote an ExactInput swap along a path of tokens, then revert with the result
function _quoteExactInput(QuoteExactParams memory params) public selfOnly returns (bytes memory) {
uint256 pathLength = params.path.length;

int128[] memory deltaAmounts = new int128[](pathLength + 1);
uint160[] memory sqrtPriceX96AfterList = new uint160[](pathLength);
uint32[] memory initializedTicksLoadedList = new uint32[](pathLength);
Currency prevCurrencyOut;
uint128 prevAmountOut;
QuoteResult memory result = QuoteResult({
deltaAmounts: new int128[](pathLength + 1),
sqrtPriceX96AfterList: new uint160[](pathLength),
initializedTicksLoadedList: new uint32[](pathLength)
});
QuoteCache memory cache;

for (uint256 i = 0; i < pathLength; i++) {
(PoolKey memory poolKey, bool zeroForOne) =
params.path[i].getPoolAndSwapDirection(i == 0 ? params.exactCurrency : prevCurrencyOut);
(, int24 tickBefore,) = manager.getSlot0(poolKey.toId());
params.path[i].getPoolAndSwapDirection(i == 0 ? params.exactCurrency : cache.prevCurrency);
(, cache.tickBefore,) = manager.getSlot0(poolKey.toId());

(BalanceDelta curDeltas, uint160 sqrtPriceX96After, int24 tickAfter) = _swap(
(cache.curDeltas, cache.sqrtPriceX96After, cache.tickAfter) = _swap(
poolKey,
zeroForOne,
int256(int128(i == 0 ? params.exactAmount : prevAmountOut)),
int256(int128(i == 0 ? params.exactAmount : cache.prevAmount)),
0,
params.path[i].hookData
);

(int128 deltaIn, int128 deltaOut) =
zeroForOne ? (curDeltas.amount0(), curDeltas.amount1()) : (curDeltas.amount1(), curDeltas.amount0());
deltaAmounts[i] += deltaIn;
deltaAmounts[i + 1] += deltaOut;
(cache.deltaIn, cache.deltaOut) =
zeroForOne ? (cache.curDeltas.amount0(), cache.curDeltas.amount1()) : (cache.curDeltas.amount1(), cache.curDeltas.amount0());
result.deltaAmounts[i] += cache.deltaIn;
result.deltaAmounts[i + 1] += cache.deltaOut;

prevAmountOut = zeroForOne ? uint128(-curDeltas.amount1()) : uint128(-curDeltas.amount0());
prevCurrencyOut = params.path[i].intermediateCurrency;
sqrtPriceX96AfterList[i] = sqrtPriceX96After;
initializedTicksLoadedList[i] =
PoolTicksCounter.countInitializedTicksLoaded(manager, poolKey, tickBefore, tickAfter);
cache.prevAmount = zeroForOne ? uint128(-cache.curDeltas.amount1()) : uint128(-cache.curDeltas.amount0());
cache.prevCurrency = params.path[i].intermediateCurrency;
result.sqrtPriceX96AfterList[i] = cache.sqrtPriceX96After;
result.initializedTicksLoadedList[i] =
PoolTicksCounter.countInitializedTicksLoaded(manager, poolKey, cache.tickBefore, cache.tickAfter);
}
bytes memory result = abi.encode(deltaAmounts, sqrtPriceX96AfterList, initializedTicksLoadedList);
bytes memory r = abi.encode(result.deltaAmounts, result.sqrtPriceX96AfterList, result.initializedTicksLoadedList);
assembly {
revert(add(0x20, result), mload(result))
revert(add(0x20, r), mload(r))
}
}

Expand Down Expand Up @@ -216,42 +234,43 @@ contract Quoter is IQuoter, ILockCallback {
function _quoteExactOutput(QuoteExactParams memory params) public selfOnly returns (bytes memory) {
uint256 pathLength = params.path.length;

int128[] memory deltaAmounts = new int128[](pathLength + 1);
uint160[] memory sqrtPriceX96AfterList = new uint160[](pathLength);
uint32[] memory initializedTicksLoadedList = new uint32[](pathLength);
Currency prevCurrencyIn;
uint128 prevAmountIn;
QuoteResult memory result = QuoteResult({
deltaAmounts: new int128[](pathLength + 1),
sqrtPriceX96AfterList: new uint160[](pathLength),
initializedTicksLoadedList: new uint32[](pathLength)
});
QuoteCache memory cache;
uint128 curAmountOut;

for (uint256 i = pathLength; i > 0; i--) {
curAmountOut = i == pathLength ? params.exactAmount : prevAmountIn;
curAmountOut = i == pathLength ? params.exactAmount : cache.prevAmount;
amountOutCached = curAmountOut;

(PoolKey memory poolKey, bool oneForZero) = PathKeyLib.getPoolAndSwapDirection(
params.path[i - 1], i == pathLength ? params.exactCurrency : prevCurrencyIn
params.path[i - 1], i == pathLength ? params.exactCurrency : cache.prevCurrency
);

(, int24 tickBefore,) = manager.getSlot0(poolKey.toId());
(, cache.tickBefore,) = manager.getSlot0(poolKey.toId());

(BalanceDelta curDeltas, uint160 sqrtPriceX96After, int24 tickAfter) =
(cache.curDeltas, cache.sqrtPriceX96After, cache.tickAfter) =
_swap(poolKey, !oneForZero, -int256(uint256(curAmountOut)), 0, params.path[i - 1].hookData);

// always clear because sqrtPriceLimitX96 is set to 0 always
delete amountOutCached;
(int128 deltaIn, int128 deltaOut) =
!oneForZero ? (curDeltas.amount0(), curDeltas.amount1()) : (curDeltas.amount1(), curDeltas.amount0());
deltaAmounts[i - 1] += deltaIn;
deltaAmounts[i] += deltaOut;

prevAmountIn = !oneForZero ? uint128(curDeltas.amount0()) : uint128(curDeltas.amount1());
prevCurrencyIn = params.path[i - 1].intermediateCurrency;
sqrtPriceX96AfterList[i - 1] = sqrtPriceX96After;
initializedTicksLoadedList[i - 1] =
PoolTicksCounter.countInitializedTicksLoaded(manager, poolKey, tickBefore, tickAfter);
(cache.deltaIn, cache.deltaOut) =
!oneForZero ? (cache.curDeltas.amount0(), cache.curDeltas.amount1()) : (cache.curDeltas.amount1(), cache.curDeltas.amount0());
result.deltaAmounts[i - 1] += cache.deltaIn;
result.deltaAmounts[i] += cache.deltaOut;

cache.prevAmount = !oneForZero ? uint128(cache.curDeltas.amount0()) : uint128(cache.curDeltas.amount1());
cache.prevCurrency = params.path[i - 1].intermediateCurrency;
result.sqrtPriceX96AfterList[i - 1] = cache.sqrtPriceX96After;
result.initializedTicksLoadedList[i - 1] =
PoolTicksCounter.countInitializedTicksLoaded(manager, poolKey, cache.tickBefore, cache.tickAfter);
}
bytes memory result = abi.encode(deltaAmounts, sqrtPriceX96AfterList, initializedTicksLoadedList);
bytes memory r = abi.encode(result.deltaAmounts, result.sqrtPriceX96AfterList, result.initializedTicksLoadedList);
assembly {
revert(add(0x20, result), mload(result))
revert(add(0x20, r), mload(r))
}
}

Expand Down
54 changes: 29 additions & 25 deletions contracts/libraries/PoolTicksCounter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol";
library PoolTicksCounter {
using PoolIdLibrary for PoolKey;

struct TickCache {
int16 wordPosLower;
int16 wordPosHigher;
uint8 bitPosLower;
uint8 bitPosHigher;
bool tickBeforeInitialized;
bool tickAfterInitialized;
}

/// @dev This function counts the number of initialized ticks that would incur a gas cost between tickBefore and tickAfter.
/// When tickBefore and/or tickAfter themselves are initialized, the logic over whether we should count them depends on the
/// direction of the swap. If we are swapping upwards (tickAfter > tickBefore) we don't want to count tickBefore but we do
Expand All @@ -18,12 +27,7 @@ library PoolTicksCounter {
view
returns (uint32 initializedTicksLoaded)
{
int16 wordPosLower;
int16 wordPosHigher;
uint8 bitPosLower;
uint8 bitPosHigher;
bool tickBeforeInitialized;
bool tickAfterInitialized;
TickCache memory cache;

{
// Get the key and offset in the tick bitmap of the active tick before and after the swap.
Expand All @@ -39,53 +43,53 @@ library PoolTicksCounter {
// and we shouldn't count it.
uint256 bmAfter = self.getPoolBitmapInfo(key.toId(), wordPosAfter);
//uint256 bmAfter = PoolGetters.getTickBitmapAtWord(self, key.toId(), wordPosAfter);
tickAfterInitialized =
cache.tickAfterInitialized =
((bmAfter & (1 << bitPosAfter)) > 0) && ((tickAfter % key.tickSpacing) == 0) && (tickBefore > tickAfter);

// In the case where tickBefore is initialized, we only want to count it if we are swapping upwards.
// Use the same logic as above to decide whether we should count tickBefore or not.
uint256 bmBefore = self.getPoolBitmapInfo(key.toId(), wordPos);
//uint256 bmBefore = PoolGetters.getTickBitmapAtWord(self, key.toId(), wordPos);
tickBeforeInitialized =
cache.tickBeforeInitialized =
((bmBefore & (1 << bitPos)) > 0) && ((tickBefore % key.tickSpacing) == 0) && (tickBefore < tickAfter);

if (wordPos < wordPosAfter || (wordPos == wordPosAfter && bitPos <= bitPosAfter)) {
wordPosLower = wordPos;
bitPosLower = bitPos;
wordPosHigher = wordPosAfter;
bitPosHigher = bitPosAfter;
cache.wordPosLower = wordPos;
cache.bitPosLower = bitPos;
cache.wordPosHigher = wordPosAfter;
cache.bitPosHigher = bitPosAfter;
} else {
wordPosLower = wordPosAfter;
bitPosLower = bitPosAfter;
wordPosHigher = wordPos;
bitPosHigher = bitPos;
cache.wordPosLower = wordPosAfter;
cache.bitPosLower = bitPosAfter;
cache.wordPosHigher = wordPos;
cache.bitPosHigher = bitPos;
}
}

// Count the number of initialized ticks crossed by iterating through the tick bitmap.
// Our first mask should include the lower tick and everything to its left.
uint256 mask = type(uint256).max << bitPosLower;
while (wordPosLower <= wordPosHigher) {
uint256 mask = type(uint256).max << cache.bitPosLower;
while (cache.wordPosLower <= cache.wordPosHigher) {
// If we're on the final tick bitmap page, ensure we only count up to our
// ending tick.
if (wordPosLower == wordPosHigher) {
mask = mask & (type(uint256).max >> (255 - bitPosHigher));
if (cache.wordPosLower == cache.wordPosHigher) {
mask = mask & (type(uint256).max >> (255 - cache.bitPosHigher));
}

//uint256 bmLower = PoolGetters.getTickBitmapAtWord(self, key.toId(), wordPosLower);
uint256 bmLower = self.getPoolBitmapInfo(key.toId(), wordPosLower);
//uint256 bmLower = PoolGetters.getTickBitmapAtWord(self, key.toId(), cache.wordPosLower);
uint256 bmLower = self.getPoolBitmapInfo(key.toId(), cache.wordPosLower);
uint256 masked = bmLower & mask;
initializedTicksLoaded += countOneBits(masked);
wordPosLower++;
cache.wordPosLower++;
// Reset our mask so we consider all bits on the next iteration.
mask = type(uint256).max;
}

if (tickAfterInitialized) {
if (cache.tickAfterInitialized) {
initializedTicksLoaded -= 1;
}

if (tickBeforeInitialized) {
if (cache.tickBeforeInitialized) {
initializedTicksLoaded -= 1;
}

Expand Down
4 changes: 1 addition & 3 deletions foundry.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
[profile.default]
src = 'contracts'
out = 'foundry-out'
solc_version = '0.8.20'
via_ir = true
optimizer_runs = 1000000
solc_version = '0.8.22'
ffi = true
fs_permissions = [{ access = "read-write", path = ".forge-snapshots/"}]
cancun = true
Expand Down

0 comments on commit 3691e3d

Please sign in to comment.