diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index f86b0702..f364ef6c 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -3,6 +3,9 @@ from typing import Optional from solders.keypair import Keypair from solana.transaction import Transaction +from solders.transaction import VersionedTransaction +from solders.transaction import TransactionVersion, Legacy +from solders.message import MessageV0 from solders.instruction import Instruction from solders.system_program import ID from solders.sysvar import RENT @@ -35,7 +38,7 @@ class DriftClient: depositing, opening new positions, closing positions, placing orders, etc. """ - def __init__(self, program: Program, signer: Keypair = None, authority: Pubkey = None, account_subscriber: Optional[DriftClientAccountSubscriber] = None, tx_params: Optional[TxParams] = None): + def __init__(self, program: Program, signer: Keypair = None, authority: Pubkey = None, account_subscriber: Optional[DriftClientAccountSubscriber] = None, tx_params: Optional[TxParams] = None, tx_version: Optional[TransactionVersion] = None): """Initializes the drift client object -- likely want to use the .from_config method instead of this one Args: @@ -69,6 +72,8 @@ def __init__(self, program: Program, signer: Keypair = None, authority: Pubkey = self.tx_params = tx_params + self.tx_version = tx_version if tx_version is not None else Legacy + @staticmethod def from_config(config: Config, provider: Provider, authority: Keypair = None): """Initializes the drift client object from a Config @@ -139,24 +144,32 @@ async def send_ixs( if isinstance(ixs, Instruction): ixs = [ixs] - tx = Transaction() - if self.tx_params.compute_units is not None: - tx.add(set_compute_unit_limit(self.tx_params.compute_units)) + ixs.insert(0, set_compute_unit_limit(self.tx_params.compute_units)) if self.tx_params.compute_units_price is not None: - tx.add(set_compute_unit_price(self.tx_params.compute_units_price)) - - [tx.add(ix) for ix in ixs] + ixs.insert(1, set_compute_unit_price(self.tx_params.compute_units_price)) latest_blockhash = (await self.program.provider.connection.get_latest_blockhash()).value.blockhash - tx.recent_blockhash = latest_blockhash - tx.fee_payer = self.signer.pubkey() - tx.sign_partial(self.signer) + if self.tx_version == Legacy: + tx = Transaction( + instructions=ixs, + recent_blockhash=latest_blockhash, + fee_payer=self.signer.pubkey() + ) - if signers is not None: - [tx.sign_partial(signer) for signer in signers] + tx.sign_partial(self.signer) + + if signers is not None: + [tx.sign_partial(signer) for signer in signers] + elif self.tx_version == 0: + msg = MessageV0.try_compile( + self.signer.pubkey(), ixs, [], latest_blockhash + ) + tx = VersionedTransaction(msg, [self.signer]) + else: + raise NotImplementedError("unknown tx version", self.tx_version) return await self.program.provider.send(tx)