Skip to content

Commit

Permalink
add versioned tx
Browse files Browse the repository at this point in the history
  • Loading branch information
crispheaney committed Nov 17, 2023
1 parent 71190ee commit 99708f3
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from typing import Optional
from solders.keypair import Keypair
from solana.transaction import Transaction
from solders.transaction import VersionedTransaction
from solders.transaction import TransactionVersion, Legacy
from solders.message import MessageV0
from solders.instruction import Instruction
from solders.system_program import ID
from solders.sysvar import RENT
Expand Down Expand Up @@ -35,7 +38,7 @@ class DriftClient:
depositing, opening new positions, closing positions, placing orders, etc.
"""

def __init__(self, program: Program, signer: Keypair = None, authority: Pubkey = None, account_subscriber: Optional[DriftClientAccountSubscriber] = None, tx_params: Optional[TxParams] = None):
def __init__(self, program: Program, signer: Keypair = None, authority: Pubkey = None, account_subscriber: Optional[DriftClientAccountSubscriber] = None, tx_params: Optional[TxParams] = None, tx_version: Optional[TransactionVersion] = None):
"""Initializes the drift client object -- likely want to use the .from_config method instead of this one
Args:
Expand Down Expand Up @@ -69,6 +72,8 @@ def __init__(self, program: Program, signer: Keypair = None, authority: Pubkey =

self.tx_params = tx_params

self.tx_version = tx_version if tx_version is not None else Legacy

@staticmethod
def from_config(config: Config, provider: Provider, authority: Keypair = None):
"""Initializes the drift client object from a Config
Expand Down Expand Up @@ -139,24 +144,32 @@ async def send_ixs(
if isinstance(ixs, Instruction):
ixs = [ixs]

tx = Transaction()

if self.tx_params.compute_units is not None:
tx.add(set_compute_unit_limit(self.tx_params.compute_units))
ixs.insert(0, set_compute_unit_limit(self.tx_params.compute_units))

if self.tx_params.compute_units_price is not None:
tx.add(set_compute_unit_price(self.tx_params.compute_units_price))

[tx.add(ix) for ix in ixs]
ixs.insert(1, set_compute_unit_price(self.tx_params.compute_units_price))

latest_blockhash = (await self.program.provider.connection.get_latest_blockhash()).value.blockhash
tx.recent_blockhash = latest_blockhash
tx.fee_payer = self.signer.pubkey()

tx.sign_partial(self.signer)
if self.tx_version == Legacy:
tx = Transaction(
instructions=ixs,
recent_blockhash=latest_blockhash,
fee_payer=self.signer.pubkey()
)

if signers is not None:
[tx.sign_partial(signer) for signer in signers]
tx.sign_partial(self.signer)

if signers is not None:
[tx.sign_partial(signer) for signer in signers]
elif self.tx_version == 0:
msg = MessageV0.try_compile(
self.signer.pubkey(), ixs, [], latest_blockhash
)
tx = VersionedTransaction(msg, [self.signer])
else:
raise NotImplementedError("unknown tx version", self.tx_version)

return await self.program.provider.send(tx)

Expand Down

0 comments on commit 99708f3

Please sign in to comment.