diff --git a/Cargo.lock b/Cargo.lock index b7407c89e2e..209c7331a96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -933,6 +933,18 @@ dependencies = [ "cfg-if 1.0.0", ] +[[package]] +name = "enum_dispatch" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8946e241a7774d5327d92749c50806f275f57d031d2229ecbfd65469a8ad338e" +dependencies = [ + "once_cell", + "proc-macro2 1.0.24", + "quote 1.0.8", + "syn 1.0.58", +] + [[package]] name = "enumflags2" version = "0.6.4" @@ -3882,6 +3894,7 @@ version = "0.1.0" dependencies = [ "arbitrary", "arrayref", + "enum_dispatch", "num-derive", "num-traits", "proptest", diff --git a/token-swap/js/client/token-swap.js b/token-swap/js/client/token-swap.js index 5032055a7e5..f2d01023c3e 100644 --- a/token-swap/js/client/token-swap.js +++ b/token-swap/js/client/token-swap.js @@ -58,6 +58,7 @@ export class Numberu64 extends BN { */ export const TokenSwapLayout: typeof BufferLayout.Structure = BufferLayout.struct( [ + BufferLayout.u8('version'), BufferLayout.u8('isInitialized'), BufferLayout.u8('nonce'), Layout.publicKey('tokenProgramId'), diff --git a/token-swap/program/Cargo.toml b/token-swap/program/Cargo.toml index 12ca9c017cb..8c5010a4e43 100644 --- a/token-swap/program/Cargo.toml +++ b/token-swap/program/Cargo.toml @@ -14,6 +14,7 @@ fuzz = ["arbitrary"] [dependencies] arrayref = "0.3.6" +enum_dispatch = "0.3.4" num-derive = "0.3" num-traits = "0.2" solana-program = "1.5.1" diff --git a/token-swap/program/fuzz/src/native_token_swap.rs b/token-swap/program/fuzz/src/native_token_swap.rs index 3116343cc0e..fc20c8998e6 100644 --- a/token-swap/program/fuzz/src/native_token_swap.rs +++ b/token-swap/program/fuzz/src/native_token_swap.rs @@ -10,14 +10,12 @@ use spl_token_swap::{ self, DepositAllTokenTypes, DepositSingleTokenTypeExactAmountIn, Swap, WithdrawAllTokenTypes, WithdrawSingleTokenTypeExactAmountOut, }, - state::SwapInfo, + state::SwapVersion, }; use spl_token::instruction::approve; -use solana_program::{ - bpf_loader, entrypoint::ProgramResult, program_pack::Pack, pubkey::Pubkey, system_program, -}; +use solana_program::{bpf_loader, entrypoint::ProgramResult, pubkey::Pubkey, system_program}; pub struct NativeTokenSwap { pub user_account: NativeAccountData, @@ -51,7 +49,8 @@ impl NativeTokenSwap { ) -> Self { let mut user_account = NativeAccountData::new(0, system_program::id()); user_account.is_signer = true; - let mut swap_account = NativeAccountData::new(SwapInfo::LEN, spl_token_swap::id()); + let mut swap_account = + NativeAccountData::new(SwapVersion::LATEST_LEN, spl_token_swap::id()); let (authority_key, nonce) = Pubkey::find_program_address( &[&swap_account.key.to_bytes()[..]], &spl_token_swap::id(), diff --git a/token-swap/program/src/instruction.rs b/token-swap/program/src/instruction.rs index acebf505a53..73e18775a75 100644 --- a/token-swap/program/src/instruction.rs +++ b/token-swap/program/src/instruction.rs @@ -92,11 +92,11 @@ pub struct WithdrawSingleTokenTypeExactAmountOut { pub maximum_pool_token_amount: u64, } -/// Instructions supported by the SwapInfo program. +/// Instructions supported by the token swap program. #[repr(C)] #[derive(Debug, PartialEq)] pub enum SwapInstruction { - /// Initializes a new SwapInfo. + /// Initializes a new swap /// /// 0. `[writable, signer]` New Token-swap to create. /// 1. `[]` swap authority derived from `create_program_address(&[Token-swap account])` diff --git a/token-swap/program/src/processor.rs b/token-swap/program/src/processor.rs index 4ff786e6f4b..571d7eae31c 100644 --- a/token-swap/program/src/processor.rs +++ b/token-swap/program/src/processor.rs @@ -12,7 +12,7 @@ use crate::{ DepositAllTokenTypes, DepositSingleTokenTypeExactAmountIn, Initialize, Swap, SwapInstruction, WithdrawAllTokenTypes, WithdrawSingleTokenTypeExactAmountOut, }, - state::SwapInfo, + state::{SwapState, SwapV1, SwapVersion}, }; use num_traits::FromPrimitive; use solana_program::{ @@ -152,7 +152,7 @@ impl Processor { #[allow(clippy::too_many_arguments)] fn check_accounts( - token_swap: &SwapInfo, + token_swap: &dyn SwapState, program_id: &Pubkey, swap_account_info: &AccountInfo, authority_info: &AccountInfo, @@ -168,20 +168,20 @@ impl Processor { return Err(ProgramError::IncorrectProgramId); } if *authority_info.key - != Self::authority_id(program_id, swap_account_info.key, token_swap.nonce)? + != Self::authority_id(program_id, swap_account_info.key, token_swap.nonce())? { return Err(SwapError::InvalidProgramAddress.into()); } - if *token_a_info.key != token_swap.token_a { + if *token_a_info.key != *token_swap.token_a_account() { return Err(SwapError::IncorrectSwapAccount.into()); } - if *token_b_info.key != token_swap.token_b { + if *token_b_info.key != *token_swap.token_b_account() { return Err(SwapError::IncorrectSwapAccount.into()); } - if *pool_mint_info.key != token_swap.pool_mint { + if *pool_mint_info.key != *token_swap.pool_mint() { return Err(SwapError::IncorrectPoolMint.into()); } - if *token_program_info.key != token_swap.token_program_id { + if *token_program_info.key != *token_swap.token_program_id() { return Err(SwapError::IncorrectTokenProgramId.into()); } if let Some(user_token_a_info) = user_token_a_info { @@ -195,7 +195,7 @@ impl Processor { } } if let Some(pool_fee_account_info) = pool_fee_account_info { - if *pool_fee_account_info.key != token_swap.pool_fee_account { + if *pool_fee_account_info.key != *token_swap.pool_fee_account() { return Err(SwapError::IncorrectFeeAccount.into()); } } @@ -222,8 +222,7 @@ impl Processor { let token_program_info = next_account_info(account_info_iter)?; let token_program_id = *token_program_info.key; - let token_swap = SwapInfo::unpack_unchecked(&swap_info.data.borrow())?; - if token_swap.is_initialized { + if SwapVersion::is_initialized(&swap_info.data.borrow()) { return Err(SwapError::AlreadyInUse.into()); } @@ -306,7 +305,7 @@ impl Processor { to_u64(initial_amount)?, )?; - let obj = SwapInfo { + let obj = SwapVersion::SwapV1(SwapV1 { is_initialized: true, nonce, token_program_id, @@ -318,8 +317,8 @@ impl Processor { pool_fee_account: *fee_account_info.key, fees, swap_curve, - }; - SwapInfo::pack(obj, &mut swap_info.data.borrow_mut())?; + }); + SwapVersion::pack(obj, &mut swap_info.data.borrow_mut())?; Ok(()) } @@ -345,18 +344,19 @@ impl Processor { if swap_info.owner != program_id { return Err(ProgramError::IncorrectProgramId); } - let token_swap = SwapInfo::unpack(&swap_info.data.borrow())?; + let token_swap = SwapVersion::unpack(&swap_info.data.borrow())?; - if *authority_info.key != Self::authority_id(program_id, swap_info.key, token_swap.nonce)? { + if *authority_info.key != Self::authority_id(program_id, swap_info.key, token_swap.nonce())? + { return Err(SwapError::InvalidProgramAddress.into()); } - if !(*swap_source_info.key == token_swap.token_a - || *swap_source_info.key == token_swap.token_b) + if !(*swap_source_info.key == *token_swap.token_a_account() + || *swap_source_info.key == *token_swap.token_b_account()) { return Err(SwapError::IncorrectSwapAccount.into()); } - if !(*swap_destination_info.key == token_swap.token_a - || *swap_destination_info.key == token_swap.token_b) + if !(*swap_destination_info.key == *token_swap.token_a_account() + || *swap_destination_info.key == *token_swap.token_b_account()) { return Err(SwapError::IncorrectSwapAccount.into()); } @@ -369,35 +369,35 @@ impl Processor { if swap_destination_info.key == destination_info.key { return Err(SwapError::InvalidInput.into()); } - if *pool_mint_info.key != token_swap.pool_mint { + if *pool_mint_info.key != *token_swap.pool_mint() { return Err(SwapError::IncorrectPoolMint.into()); } - if *pool_fee_account_info.key != token_swap.pool_fee_account { + if *pool_fee_account_info.key != *token_swap.pool_fee_account() { return Err(SwapError::IncorrectFeeAccount.into()); } - if *token_program_info.key != token_swap.token_program_id { + if *token_program_info.key != *token_swap.token_program_id() { return Err(SwapError::IncorrectTokenProgramId.into()); } let source_account = - Self::unpack_token_account(swap_source_info, &token_swap.token_program_id)?; + Self::unpack_token_account(swap_source_info, &token_swap.token_program_id())?; let dest_account = - Self::unpack_token_account(swap_destination_info, &token_swap.token_program_id)?; - let pool_mint = Self::unpack_mint(pool_mint_info, &token_swap.token_program_id)?; + Self::unpack_token_account(swap_destination_info, &token_swap.token_program_id())?; + let pool_mint = Self::unpack_mint(pool_mint_info, &token_swap.token_program_id())?; - let trade_direction = if *swap_source_info.key == token_swap.token_a { + let trade_direction = if *swap_source_info.key == *token_swap.token_a_account() { TradeDirection::AtoB } else { TradeDirection::BtoA }; let result = token_swap - .swap_curve + .swap_curve() .swap( to_u128(amount_in)?, to_u128(source_account.amount)?, to_u128(dest_account.amount)?, trade_direction, - &token_swap.fees, + token_swap.fees(), ) .ok_or(SwapError::ZeroTradingTokens)?; if result.destination_amount_swapped < to_u128(minimum_amount_out)? { @@ -421,12 +421,12 @@ impl Processor { source_info.clone(), swap_source_info.clone(), user_transfer_authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), to_u64(result.source_amount_swapped)?, )?; let mut pool_token_amount = token_swap - .swap_curve + .swap_curve() .trading_tokens_to_pool_tokens( result.owner_fee, swap_token_a_amount, @@ -434,7 +434,7 @@ impl Processor { to_u128(pool_mint.supply)?, trade_direction, RoundDirection::Ceiling, - &token_swap.fees, + token_swap.fees(), ) .ok_or(SwapError::FeeCalculationFailure)?; @@ -443,13 +443,13 @@ impl Processor { if let Ok(host_fee_account_info) = next_account_info(account_info_iter) { let host_fee_account = Self::unpack_token_account( host_fee_account_info, - &token_swap.token_program_id, + token_swap.token_program_id(), )?; if *pool_mint_info.key != host_fee_account.mint { return Err(SwapError::IncorrectPoolMint.into()); } let host_fee = token_swap - .fees + .fees() .host_fee(pool_token_amount) .ok_or(SwapError::FeeCalculationFailure)?; if host_fee > 0 { @@ -462,7 +462,7 @@ impl Processor { pool_mint_info.clone(), host_fee_account_info.clone(), authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), to_u64(host_fee)?, )?; } @@ -473,7 +473,7 @@ impl Processor { pool_mint_info.clone(), pool_fee_account_info.clone(), authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), to_u64(pool_token_amount)?, )?; } @@ -484,7 +484,7 @@ impl Processor { swap_destination_info.clone(), destination_info.clone(), authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), to_u64(result.destination_amount_swapped)?, )?; @@ -511,12 +511,13 @@ impl Processor { let dest_info = next_account_info(account_info_iter)?; let token_program_info = next_account_info(account_info_iter)?; - let token_swap = SwapInfo::unpack(&swap_info.data.borrow())?; - if !token_swap.swap_curve.calculator.allows_deposits() { + let token_swap = SwapVersion::unpack(&swap_info.data.borrow())?; + let calculator = &token_swap.swap_curve().calculator; + if !calculator.allows_deposits() { return Err(SwapError::UnsupportedCurveOperation.into()); } Self::check_accounts( - &token_swap, + token_swap.as_ref(), program_id, swap_info, authority_info, @@ -529,14 +530,12 @@ impl Processor { None, )?; - let token_a = Self::unpack_token_account(token_a_info, &token_swap.token_program_id)?; - let token_b = Self::unpack_token_account(token_b_info, &token_swap.token_program_id)?; - let pool_mint = Self::unpack_mint(pool_mint_info, &token_swap.token_program_id)?; + let token_a = Self::unpack_token_account(token_a_info, &token_swap.token_program_id())?; + let token_b = Self::unpack_token_account(token_b_info, &token_swap.token_program_id())?; + let pool_mint = Self::unpack_mint(pool_mint_info, &token_swap.token_program_id())?; let pool_token_amount = to_u128(pool_token_amount)?; let pool_mint_supply = to_u128(pool_mint.supply)?; - let calculator = token_swap.swap_curve.calculator; - let results = calculator .pool_tokens_to_trading_tokens( pool_token_amount, @@ -569,7 +568,7 @@ impl Processor { source_a_info.clone(), token_a_info.clone(), user_transfer_authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), token_a_amount, )?; Self::token_transfer( @@ -578,7 +577,7 @@ impl Processor { source_b_info.clone(), token_b_info.clone(), user_transfer_authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), token_b_amount, )?; Self::token_mint_to( @@ -587,7 +586,7 @@ impl Processor { pool_mint_info.clone(), dest_info.clone(), authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), pool_token_amount, )?; @@ -615,9 +614,9 @@ impl Processor { let pool_fee_account_info = next_account_info(account_info_iter)?; let token_program_info = next_account_info(account_info_iter)?; - let token_swap = SwapInfo::unpack(&swap_info.data.borrow())?; + let token_swap = SwapVersion::unpack(&swap_info.data.borrow())?; Self::check_accounts( - &token_swap, + token_swap.as_ref(), program_id, swap_info, authority_info, @@ -630,18 +629,18 @@ impl Processor { Some(pool_fee_account_info), )?; - let token_a = Self::unpack_token_account(token_a_info, &token_swap.token_program_id)?; - let token_b = Self::unpack_token_account(token_b_info, &token_swap.token_program_id)?; - let pool_mint = Self::unpack_mint(pool_mint_info, &token_swap.token_program_id)?; + let token_a = Self::unpack_token_account(token_a_info, token_swap.token_program_id())?; + let token_b = Self::unpack_token_account(token_b_info, token_swap.token_program_id())?; + let pool_mint = Self::unpack_mint(pool_mint_info, token_swap.token_program_id())?; - let calculator = token_swap.swap_curve.calculator; + let calculator = &token_swap.swap_curve().calculator; let withdraw_fee: u128 = if *pool_fee_account_info.key == *source_info.key { // withdrawing from the fee account, don't assess withdraw fee 0 } else { token_swap - .fees + .fees() .owner_withdraw_fee(to_u128(pool_token_amount)?) .ok_or(SwapError::FeeCalculationFailure)? }; @@ -680,7 +679,7 @@ impl Processor { source_info.clone(), pool_fee_account_info.clone(), user_transfer_authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), to_u64(withdraw_fee)?, )?; } @@ -690,7 +689,7 @@ impl Processor { source_info.clone(), pool_mint_info.clone(), user_transfer_authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), to_u64(pool_token_amount)?, )?; @@ -702,7 +701,7 @@ impl Processor { token_a_info.clone(), dest_token_a_info.clone(), authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), token_a_amount, )?; } @@ -714,7 +713,7 @@ impl Processor { token_b_info.clone(), dest_token_b_info.clone(), authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), token_b_amount, )?; } @@ -739,12 +738,13 @@ impl Processor { let destination_info = next_account_info(account_info_iter)?; let token_program_info = next_account_info(account_info_iter)?; - let token_swap = SwapInfo::unpack(&swap_info.data.borrow())?; - let source_account = Self::unpack_token_account(source_info, &token_swap.token_program_id)?; + let token_swap = SwapVersion::unpack(&swap_info.data.borrow())?; + let source_account = + Self::unpack_token_account(source_info, &token_swap.token_program_id())?; let swap_token_a = - Self::unpack_token_account(swap_token_a_info, &token_swap.token_program_id)?; + Self::unpack_token_account(swap_token_a_info, &token_swap.token_program_id())?; let swap_token_b = - Self::unpack_token_account(swap_token_b_info, &token_swap.token_program_id)?; + Self::unpack_token_account(swap_token_b_info, &token_swap.token_program_id())?; let trade_direction = if source_account.mint == swap_token_a.mint { TradeDirection::AtoB @@ -760,7 +760,7 @@ impl Processor { }; Self::check_accounts( - &token_swap, + token_swap.as_ref(), program_id, swap_info, authority_info, @@ -773,11 +773,11 @@ impl Processor { None, )?; - let pool_mint = Self::unpack_mint(pool_mint_info, &token_swap.token_program_id)?; + let pool_mint = Self::unpack_mint(pool_mint_info, &token_swap.token_program_id())?; let pool_mint_supply = to_u128(pool_mint.supply)?; let pool_token_amount = token_swap - .swap_curve + .swap_curve() .trading_tokens_to_pool_tokens( to_u128(source_token_amount)?, to_u128(swap_token_a.amount)?, @@ -785,7 +785,7 @@ impl Processor { pool_mint_supply, trade_direction, RoundDirection::Floor, - &token_swap.fees, + token_swap.fees(), ) .ok_or(SwapError::ZeroTradingTokens)?; @@ -805,7 +805,7 @@ impl Processor { source_info.clone(), swap_token_a_info.clone(), user_transfer_authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), source_token_amount, )?; } @@ -816,7 +816,7 @@ impl Processor { source_info.clone(), swap_token_b_info.clone(), user_transfer_authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), source_token_amount, )?; } @@ -827,7 +827,7 @@ impl Processor { pool_mint_info.clone(), destination_info.clone(), authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), pool_token_amount, )?; @@ -853,13 +853,13 @@ impl Processor { let pool_fee_account_info = next_account_info(account_info_iter)?; let token_program_info = next_account_info(account_info_iter)?; - let token_swap = SwapInfo::unpack(&swap_info.data.borrow())?; + let token_swap = SwapVersion::unpack(&swap_info.data.borrow())?; let destination_account = - Self::unpack_token_account(destination_info, &token_swap.token_program_id)?; + Self::unpack_token_account(destination_info, &token_swap.token_program_id())?; let swap_token_a = - Self::unpack_token_account(swap_token_a_info, &token_swap.token_program_id)?; + Self::unpack_token_account(swap_token_a_info, &token_swap.token_program_id())?; let swap_token_b = - Self::unpack_token_account(swap_token_b_info, &token_swap.token_program_id)?; + Self::unpack_token_account(swap_token_b_info, &token_swap.token_program_id())?; let trade_direction = if destination_account.mint == swap_token_a.mint { TradeDirection::AtoB @@ -874,7 +874,7 @@ impl Processor { TradeDirection::BtoA => (None, Some(destination_info)), }; Self::check_accounts( - &token_swap, + token_swap.as_ref(), program_id, swap_info, authority_info, @@ -887,7 +887,7 @@ impl Processor { Some(pool_fee_account_info), )?; - let pool_mint = Self::unpack_mint(pool_mint_info, &token_swap.token_program_id)?; + let pool_mint = Self::unpack_mint(pool_mint_info, &token_swap.token_program_id())?; let pool_mint_supply = to_u128(pool_mint.supply)?; let (swap_token_a_amount, swap_token_b_amount) = match trade_direction { TradeDirection::AtoB => ( @@ -911,7 +911,7 @@ impl Processor { }; let burn_pool_token_amount = token_swap - .swap_curve + .swap_curve() .trading_tokens_to_pool_tokens( to_u128(destination_token_amount)?, swap_token_a_amount, @@ -919,7 +919,7 @@ impl Processor { pool_mint_supply, trade_direction, RoundDirection::Ceiling, - &token_swap.fees, + token_swap.fees(), ) .ok_or(SwapError::ZeroTradingTokens)?; @@ -928,7 +928,7 @@ impl Processor { 0 } else { token_swap - .fees + .fees() .owner_withdraw_fee(burn_pool_token_amount) .ok_or(SwapError::FeeCalculationFailure)? }; @@ -950,7 +950,7 @@ impl Processor { source_info.clone(), pool_fee_account_info.clone(), user_transfer_authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), to_u64(withdraw_fee)?, )?; } @@ -960,7 +960,7 @@ impl Processor { source_info.clone(), pool_mint_info.clone(), user_transfer_authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), to_u64(burn_pool_token_amount)?, )?; @@ -972,7 +972,7 @@ impl Processor { swap_token_a_info.clone(), destination_info.clone(), authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), destination_token_amount, )?; } @@ -983,7 +983,7 @@ impl Processor { swap_token_b_info.clone(), destination_info.clone(), authority_info.clone(), - token_swap.nonce, + token_swap.nonce(), destination_token_amount, )?; } @@ -1275,7 +1275,7 @@ mod tests { token_b_amount: u64, ) -> Self { let swap_key = Pubkey::new_unique(); - let swap_account = Account::new(0, SwapInfo::get_packed_len(), &SWAP_PROGRAM_ID); + let swap_account = Account::new(0, SwapVersion::LATEST_LEN, &SWAP_PROGRAM_ID); let (authority_key, nonce) = Pubkey::find_program_address(&[&swap_key.to_bytes()[..]], &SWAP_PROGRAM_ID); @@ -2803,19 +2803,19 @@ mod tests { accounts.initialize_swap() ); } - let swap_info = SwapInfo::unpack(&accounts.swap_account.data).unwrap(); - assert_eq!(swap_info.is_initialized, true); - assert_eq!(swap_info.nonce, accounts.nonce); + let swap_state = SwapVersion::unpack(&accounts.swap_account.data).unwrap(); + assert_eq!(swap_state.is_initialized(), true); + assert_eq!(swap_state.nonce(), accounts.nonce); assert_eq!( - swap_info.swap_curve.curve_type, + swap_state.swap_curve().curve_type, accounts.swap_curve.curve_type ); - assert_eq!(swap_info.token_a, accounts.token_a_key); - assert_eq!(swap_info.token_b, accounts.token_b_key); - assert_eq!(swap_info.pool_mint, accounts.pool_mint_key); - assert_eq!(swap_info.token_a_mint, accounts.token_a_mint_key); - assert_eq!(swap_info.token_b_mint, accounts.token_b_mint_key); - assert_eq!(swap_info.pool_fee_account, accounts.pool_fee_key); + assert_eq!(*swap_state.token_a_account(), accounts.token_a_key); + assert_eq!(*swap_state.token_b_account(), accounts.token_b_key); + assert_eq!(*swap_state.pool_mint(), accounts.pool_mint_key); + assert_eq!(*swap_state.token_a_mint(), accounts.token_a_mint_key); + assert_eq!(*swap_state.token_b_mint(), accounts.token_b_mint_key); + assert_eq!(*swap_state.pool_fee_account(), accounts.pool_fee_key); let token_a = spl_token::state::Account::unpack(&accounts.token_a_account.data).unwrap(); assert_eq!(token_a.amount, token_a_amount); let token_b = spl_token::state::Account::unpack(&accounts.token_b_account.data).unwrap(); diff --git a/token-swap/program/src/state.rs b/token-swap/program/src/state.rs index 2d35c905b24..891b2f9cf44 100644 --- a/token-swap/program/src/state.rs +++ b/token-swap/program/src/state.rs @@ -2,16 +2,93 @@ use crate::curve::{base::SwapCurve, fees::Fees}; use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs}; +use enum_dispatch::enum_dispatch; use solana_program::{ program_error::ProgramError, program_pack::{IsInitialized, Pack, Sealed}, pubkey::Pubkey, }; +/// Trait representing access to program state across all versions +#[enum_dispatch] +pub trait SwapState { + /// Is the swap initialized, with data written to it + fn is_initialized(&self) -> bool; + /// Bump seed used to generate the program address / authority + fn nonce(&self) -> u8; + /// Token program ID associated with the swap + fn token_program_id(&self) -> &Pubkey; + /// Address of token A liquidity account + fn token_a_account(&self) -> &Pubkey; + /// Address of token B liquidity account + fn token_b_account(&self) -> &Pubkey; + /// Address of pool token mint + fn pool_mint(&self) -> &Pubkey; + + /// Address of token A mint + fn token_a_mint(&self) -> &Pubkey; + /// Address of token B mint + fn token_b_mint(&self) -> &Pubkey; + + /// Address of pool fee account + fn pool_fee_account(&self) -> &Pubkey; + + /// Fees associated with swap + fn fees(&self) -> &Fees; + /// Curve associated with swap + fn swap_curve(&self) -> &SwapCurve; +} + +/// All versions of SwapState +#[enum_dispatch(SwapState)] +pub enum SwapVersion { + /// Latest version, used for all new swaps + SwapV1, +} + +/// SwapVersion does not implement program_pack::Pack because there are size +/// checks on pack and unpack that would break backwards compatibility, so +/// special implementations are provided here +impl SwapVersion { + /// Size of the latest version of the SwapState + pub const LATEST_LEN: usize = 1 + SwapV1::LEN; // add one for the version enum + + /// Pack a swap into a byte array, based on its version + pub fn pack(src: Self, dst: &mut [u8]) -> Result<(), ProgramError> { + match src { + Self::SwapV1(swap_info) => { + dst[0] = 1; + SwapV1::pack(swap_info, &mut dst[1..]) + } + } + } + + /// Unpack the swap account based on its version, returning the result as a + /// SwapState trait object + pub fn unpack(input: &[u8]) -> Result, ProgramError> { + let (&version, rest) = input + .split_first() + .ok_or(ProgramError::InvalidAccountData)?; + match version { + 1 => Ok(Box::new(SwapV1::unpack(rest)?)), + _ => Err(ProgramError::UninitializedAccount), + } + } + + /// Special check to be done before any instruction processing, works for + /// all versions + pub fn is_initialized(input: &[u8]) -> bool { + match Self::unpack(input) { + Ok(swap) => swap.is_initialized(), + Err(_) => false, + } + } +} + /// Program states. #[repr(C)] #[derive(Debug, Default, PartialEq)] -pub struct SwapInfo { +pub struct SwapV1 { /// Initialized state. pub is_initialized: bool, /// Nonce used in program address. @@ -49,14 +126,60 @@ pub struct SwapInfo { pub swap_curve: SwapCurve, } -impl Sealed for SwapInfo {} -impl IsInitialized for SwapInfo { +impl SwapState for SwapV1 { fn is_initialized(&self) -> bool { self.is_initialized } + + fn nonce(&self) -> u8 { + self.nonce + } + + fn token_program_id(&self) -> &Pubkey { + &self.token_program_id + } + + fn token_a_account(&self) -> &Pubkey { + &self.token_a + } + + fn token_b_account(&self) -> &Pubkey { + &self.token_b + } + + fn pool_mint(&self) -> &Pubkey { + &self.pool_mint + } + + fn token_a_mint(&self) -> &Pubkey { + &self.token_a_mint + } + + fn token_b_mint(&self) -> &Pubkey { + &self.token_b_mint + } + + fn pool_fee_account(&self) -> &Pubkey { + &self.pool_fee_account + } + + fn fees(&self) -> &Fees { + &self.fees + } + + fn swap_curve(&self) -> &SwapCurve { + &self.swap_curve + } } -impl Pack for SwapInfo { +impl Sealed for SwapV1 {} +impl IsInitialized for SwapV1 { + fn is_initialized(&self) -> bool { + self.is_initialized + } +} + +impl Pack for SwapV1 { const LEN: usize = 323; fn pack_into_slice(&self, output: &mut [u8]) { @@ -87,7 +210,7 @@ impl Pack for SwapInfo { self.swap_curve.pack_into_slice(&mut swap_curve[..]); } - /// Unpacks a byte buffer into a [SwapInfo](struct.SwapInfo.html). + /// Unpacks a byte buffer into a [SwapV1](struct.SwapV1.html). fn unpack_from_slice(input: &[u8]) -> Result { let input = array_ref![input, 0, 323]; #[allow(clippy::ptr_offset_with_cast)] @@ -131,98 +254,125 @@ mod tests { use std::convert::TryInto; + const TEST_FEES: Fees = Fees { + trade_fee_numerator: 1, + trade_fee_denominator: 4, + owner_trade_fee_numerator: 3, + owner_trade_fee_denominator: 10, + owner_withdraw_fee_numerator: 2, + owner_withdraw_fee_denominator: 7, + host_fee_numerator: 5, + host_fee_denominator: 20, + }; + + const TEST_NONCE: u8 = 255; + const TEST_TOKEN_PROGRAM_ID: Pubkey = Pubkey::new_from_array([1u8; 32]); + const TEST_TOKEN_A: Pubkey = Pubkey::new_from_array([2u8; 32]); + const TEST_TOKEN_B: Pubkey = Pubkey::new_from_array([3u8; 32]); + const TEST_POOL_MINT: Pubkey = Pubkey::new_from_array([4u8; 32]); + const TEST_TOKEN_A_MINT: Pubkey = Pubkey::new_from_array([5u8; 32]); + const TEST_TOKEN_B_MINT: Pubkey = Pubkey::new_from_array([6u8; 32]); + const TEST_POOL_FEE_ACCOUNT: Pubkey = Pubkey::new_from_array([7u8; 32]); + + const TEST_CURVE_TYPE: u8 = 2; + const TEST_AMP: u64 = 1; + const TEST_CURVE: StableCurve = StableCurve { amp: TEST_AMP }; + #[test] - fn test_swap_info_packing() { - let nonce = 255; - let curve_type_raw: u8 = 2; - let curve_type = curve_type_raw.try_into().unwrap(); - let token_program_id_raw = [1u8; 32]; - let token_a_raw = [1u8; 32]; - let token_b_raw = [2u8; 32]; - let pool_mint_raw = [3u8; 32]; - let token_a_mint_raw = [4u8; 32]; - let token_b_mint_raw = [5u8; 32]; - let pool_fee_account_raw = [6u8; 32]; - let token_program_id = Pubkey::new_from_array(token_program_id_raw); - let token_a = Pubkey::new_from_array(token_a_raw); - let token_b = Pubkey::new_from_array(token_b_raw); - let pool_mint = Pubkey::new_from_array(pool_mint_raw); - let token_a_mint = Pubkey::new_from_array(token_a_mint_raw); - let token_b_mint = Pubkey::new_from_array(token_b_mint_raw); - let pool_fee_account = Pubkey::new_from_array(pool_fee_account_raw); - let trade_fee_numerator = 1; - let trade_fee_denominator = 4; - let owner_trade_fee_numerator = 3; - let owner_trade_fee_denominator = 10; - let owner_withdraw_fee_numerator = 2; - let owner_withdraw_fee_denominator = 7; - let host_fee_numerator = 5; - let host_fee_denominator = 20; - let amp: u64 = 1; - let fees = Fees { - trade_fee_numerator, - trade_fee_denominator, - owner_trade_fee_numerator, - owner_trade_fee_denominator, - owner_withdraw_fee_numerator, - owner_withdraw_fee_denominator, - host_fee_numerator, - host_fee_denominator, + fn swap_version_pack() { + let curve_type = TEST_CURVE_TYPE.try_into().unwrap(); + let calculator = Box::new(TEST_CURVE); + let swap_curve = SwapCurve { + curve_type, + calculator, }; - let calculator = Box::new(StableCurve { amp }); + let swap_info = SwapVersion::SwapV1(SwapV1 { + is_initialized: true, + nonce: TEST_NONCE, + token_program_id: TEST_TOKEN_PROGRAM_ID, + token_a: TEST_TOKEN_A, + token_b: TEST_TOKEN_B, + pool_mint: TEST_POOL_MINT, + token_a_mint: TEST_TOKEN_A_MINT, + token_b_mint: TEST_TOKEN_B_MINT, + pool_fee_account: TEST_POOL_FEE_ACCOUNT, + fees: TEST_FEES, + swap_curve: swap_curve.clone(), + }); + + let mut packed = [0u8; SwapVersion::LATEST_LEN]; + SwapVersion::pack(swap_info, &mut packed).unwrap(); + let unpacked = SwapVersion::unpack(&packed).unwrap(); + + assert_eq!(unpacked.is_initialized(), true); + assert_eq!(unpacked.nonce(), TEST_NONCE); + assert_eq!(*unpacked.token_program_id(), TEST_TOKEN_PROGRAM_ID); + assert_eq!(*unpacked.token_a_account(), TEST_TOKEN_A); + assert_eq!(*unpacked.token_b_account(), TEST_TOKEN_B); + assert_eq!(*unpacked.pool_mint(), TEST_POOL_MINT); + assert_eq!(*unpacked.token_a_mint(), TEST_TOKEN_A_MINT); + assert_eq!(*unpacked.token_b_mint(), TEST_TOKEN_B_MINT); + assert_eq!(*unpacked.pool_fee_account(), TEST_POOL_FEE_ACCOUNT); + assert_eq!(*unpacked.fees(), TEST_FEES); + assert_eq!(*unpacked.swap_curve(), swap_curve); + } + + #[test] + fn swap_v1_pack() { + let curve_type = TEST_CURVE_TYPE.try_into().unwrap(); + let calculator = Box::new(TEST_CURVE); let swap_curve = SwapCurve { curve_type, calculator, }; - let is_initialized = true; - let swap_info = SwapInfo { - is_initialized, - nonce, - token_program_id, - token_a, - token_b, - pool_mint, - token_a_mint, - token_b_mint, - pool_fee_account, - fees, + let swap_info = SwapV1 { + is_initialized: true, + nonce: TEST_NONCE, + token_program_id: TEST_TOKEN_PROGRAM_ID, + token_a: TEST_TOKEN_A, + token_b: TEST_TOKEN_B, + pool_mint: TEST_POOL_MINT, + token_a_mint: TEST_TOKEN_A_MINT, + token_b_mint: TEST_TOKEN_B_MINT, + pool_fee_account: TEST_POOL_FEE_ACCOUNT, + fees: TEST_FEES, swap_curve, }; - let mut packed = [0u8; SwapInfo::LEN]; - SwapInfo::pack_into_slice(&swap_info, &mut packed); - let unpacked = SwapInfo::unpack(&packed).unwrap(); + let mut packed = [0u8; SwapV1::LEN]; + SwapV1::pack_into_slice(&swap_info, &mut packed); + let unpacked = SwapV1::unpack(&packed).unwrap(); assert_eq!(swap_info, unpacked); let mut packed = vec![]; packed.push(1u8); - packed.push(nonce); - packed.extend_from_slice(&token_program_id_raw); - packed.extend_from_slice(&token_a_raw); - packed.extend_from_slice(&token_b_raw); - packed.extend_from_slice(&pool_mint_raw); - packed.extend_from_slice(&token_a_mint_raw); - packed.extend_from_slice(&token_b_mint_raw); - packed.extend_from_slice(&pool_fee_account_raw); - packed.extend_from_slice(&trade_fee_numerator.to_le_bytes()); - packed.extend_from_slice(&trade_fee_denominator.to_le_bytes()); - packed.extend_from_slice(&owner_trade_fee_numerator.to_le_bytes()); - packed.extend_from_slice(&owner_trade_fee_denominator.to_le_bytes()); - packed.extend_from_slice(&owner_withdraw_fee_numerator.to_le_bytes()); - packed.extend_from_slice(&owner_withdraw_fee_denominator.to_le_bytes()); - packed.extend_from_slice(&host_fee_numerator.to_le_bytes()); - packed.extend_from_slice(&host_fee_denominator.to_le_bytes()); - packed.push(curve_type_raw); - packed.extend_from_slice(&.to_le_bytes()); + packed.push(TEST_NONCE); + packed.extend_from_slice(&TEST_TOKEN_PROGRAM_ID.to_bytes()); + packed.extend_from_slice(&TEST_TOKEN_A.to_bytes()); + packed.extend_from_slice(&TEST_TOKEN_B.to_bytes()); + packed.extend_from_slice(&TEST_POOL_MINT.to_bytes()); + packed.extend_from_slice(&TEST_TOKEN_A_MINT.to_bytes()); + packed.extend_from_slice(&TEST_TOKEN_B_MINT.to_bytes()); + packed.extend_from_slice(&TEST_POOL_FEE_ACCOUNT.to_bytes()); + packed.extend_from_slice(&TEST_FEES.trade_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&TEST_FEES.trade_fee_denominator.to_le_bytes()); + packed.extend_from_slice(&TEST_FEES.owner_trade_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&TEST_FEES.owner_trade_fee_denominator.to_le_bytes()); + packed.extend_from_slice(&TEST_FEES.owner_withdraw_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&TEST_FEES.owner_withdraw_fee_denominator.to_le_bytes()); + packed.extend_from_slice(&TEST_FEES.host_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&TEST_FEES.host_fee_denominator.to_le_bytes()); + packed.push(TEST_CURVE_TYPE); + packed.extend_from_slice(&TEST_AMP.to_le_bytes()); packed.extend_from_slice(&[0u8; 24]); - let unpacked = SwapInfo::unpack(&packed).unwrap(); + let unpacked = SwapV1::unpack(&packed).unwrap(); assert_eq!(swap_info, unpacked); - let packed = [0u8; SwapInfo::LEN]; - let swap_info: SwapInfo = Default::default(); - let unpack_unchecked = SwapInfo::unpack_unchecked(&packed).unwrap(); + let packed = [0u8; SwapV1::LEN]; + let swap_info: SwapV1 = Default::default(); + let unpack_unchecked = SwapV1::unpack_unchecked(&packed).unwrap(); assert_eq!(unpack_unchecked, swap_info); - let err = SwapInfo::unpack(&packed).unwrap_err(); + let err = SwapV1::unpack(&packed).unwrap_err(); assert_eq!(err, ProgramError::UninitializedAccount); } }