Skip to content

Commit

Permalink
Merge branch 'master' into crispeheaney/compute-units
Browse files Browse the repository at this point in the history
  • Loading branch information
0xbigz committed Nov 17, 2023
2 parents 465c796 + 15eb39c commit a037e12
Show file tree
Hide file tree
Showing 15 changed files with 105 additions and 55 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.6.42
current_version = 0.6.43
commit = True
tag = True
tag_name = {new_version}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ on:
branches: [master]

env:
solana_verion: 1.8.5
anchor_version: 0.19.0
solana_verion: 1.14.7
anchor_version: 0.26.0

jobs:
tests:
Expand Down
2 changes: 1 addition & 1 deletion docs/accounts.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ print(

# get usdc spot market info
usdc_spot_market_index = 0
usdc_market = await get_spot_market_account(clearing_house.program, usdc_spot_market_index)
usdc_market = await get_spot_market_account(drift_client.program, usdc_spot_market_index)
print(
usdc.market_index,
usdc.deposit_balance,
Expand Down
8 changes: 4 additions & 4 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ wallet = Wallet(kp)
connection = AsyncClient(config.default_http)
provider = Provider(connection, wallet)

drfit_client = DriftClient.from_config(config, provider)
drift_user = User(clearing_house)
drift_client = DriftClient.from_config(config, provider)
drift_user = User(drift_client)

# open a 10 SOL long position
sig = await drift_client.open_position(
Expand All @@ -68,12 +68,12 @@ leverage = await drift_user.get_leverage()
print('current leverage:', leverage / 10_000)

# you can also inspect other accounts information using the (authority=) flag
bigz_acc = ser(clearing_house, authority=PublicKey('bigZ'))
bigz_acc = User(drift_client, authority=PublicKey('bigZ'))
leverage = await bigz_acc.get_leverage()
print('bigZs leverage:', leverage / 10_000)

# clearing house user calls can be expensive on the rpc so we can cache them
drift_user = User(clearing_house, use_cache=True)
drift_user = User(drift_client, use_cache=True)
await drift_user.set_cache()

# works without any rpc calls (uses the cached data)
Expand Down
2 changes: 1 addition & 1 deletion examples/if_stake.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from driftpy.constants.config import configs
from driftpy.drift_client import DriftClient
from driftpy.accounts import *
from driftpy.drift_user import User
from driftpy.drift_user import DriftUser

async def view_logs(
sig: str,
Expand Down
6 changes: 3 additions & 3 deletions examples/limit_order_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from driftpy.accounts.oracle import get_oracle_price_data_and_slot
from driftpy.math.spot_market import get_signed_token_amount, get_token_amount
from driftpy.drift_client import DriftClient
from driftpy.drift_user import User
from driftpy.drift_user import DriftUser
from driftpy.constants.numeric_constants import BASE_PRECISION, PRICE_PRECISION
from borsh_construct.enum import _rust_enum
import time
Expand Down Expand Up @@ -118,7 +118,7 @@ async def main(
drift_acct.program, market_index
)
try:
oracle_data = await get_oracle_price_data_and_slot(connection, market.amm.oracle)
oracle_data = (await get_oracle_price_data_and_slot(connection, market.amm.oracle)).data
current_price = oracle_data.price/PRICE_PRECISION
except:
current_price = market.amm.historical_oracle_data.last_oracle_price/PRICE_PRECISION
Expand All @@ -132,7 +132,7 @@ async def main(
else:
market = await get_spot_market_account( drift_acct.program, market_index)
try:
oracle_data = await get_oracle_price_data_and_slot(connection, market.oracle)
oracle_data = (await get_oracle_price_data_and_slot(connection, market.oracle)).data
current_price = oracle_data.price/PRICE_PRECISION
except:
current_price = market.historical_oracle_data.last_oracle_price/PRICE_PRECISION
Expand Down
2 changes: 1 addition & 1 deletion examples/start_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

# todo: airdrop udsc + init account for any kp
# rn do it through UI
from driftpy.drift_user import User
from driftpy.drift_user import DriftUser
from driftpy.constants.numeric_constants import AMM_RESERVE_PRECISION
from solana.rpc import commitment
import pprint
Expand Down
2 changes: 1 addition & 1 deletion examples/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from driftpy.math.positions import is_available
from driftpy.constants.numeric_constants import *

from driftpy.drift_user import User
from driftpy.drift_user import DriftUser

async def main(
authority,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "driftpy"
version = "0.6.42"
version = "0.6.43"
description = "A Python client for the Drift DEX"
authors = ["x19 <https://twitter.com/[email protected]>", "bigz <https://twitter.com/bigz_pubkey>"]
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion src/driftpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.6.42"
__version__ = "0.6.43"
3 changes: 2 additions & 1 deletion src/driftpy/accounts/cache/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ def __init__(
self.user_and_slot = None

async def update_cache(self):
user_and_slot = await get_user_account_and_slot(self.program)
user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey)
self.user_and_slot = user_and_slot

async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]:
await self.cache_if_needed()
return self.user_and_slot

async def cache_if_needed(self):
Expand Down
67 changes: 34 additions & 33 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ def from_config(config: Config, provider: Provider, authority: Keypair = None):
provider,
)

drift_client = DriftClient
(program, authority)
drift_client = DriftClient(program, authority)
drift_client.config = config
drift_client.idl = idl

Expand Down Expand Up @@ -710,7 +709,7 @@ async def place_spot_order(
return await self.send_ixs(
[
self.get_increase_compute_ix(),
await self.get_place_spot_order_ix(order_params, maker_info, user_id),
await self.get_place_spot_order_ix(order_params, user_id),
]
)

Expand Down Expand Up @@ -788,62 +787,64 @@ async def place_perp_order(
user_id: int = 0,
):
return await self.send_ixs(
[
[
self.get_increase_compute_ix(),
await self.get_place_perp_order_ix(order_params, maker_info, user_id),
(await self.get_place_perp_order_ix(order_params, user_id))[-1]
]

)

async def get_place_perp_order_ix(
self,
order_params: OrderParams,
user_id: int = 0,
):
) -> TransactionInstruction:
user_account_public_key = self.get_user_account_public_key(user_id)
remaining_accounts = await self.get_remaining_accounts(
writable_market_index=order_params.market_index, user_id=user_id
)

ix = self.program.instruction["place_perp_order"](
order_params,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": user_account_public_key,
"authority": self.signer.public_key,
},
remaining_accounts=remaining_accounts,
),
)
order_params,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": user_account_public_key,
"authority": self.signer.public_key,
},
remaining_accounts=remaining_accounts,
),
)

return ix

async def get_place_perp_orders_ix(
self,
order_params: List[OrderParams],
user_id: int = 0,
cancel_all=True
):
user_account_public_key = self.get_user_account_public_key(user_id)
writeable_market_indexes = list(set([x.market_index for x in order_params]))
remaining_accounts = await self.get_remaining_accounts(
writable_market_index=writeable_market_indexes, user_id=user_id
)

ixs = [
self.program.instruction["cancel_orders"](
None,
None,
None,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": self.get_user_account_public_key(user_id),
"authority": self.signer.public_key,
},
remaining_accounts=remaining_accounts,
),
)
]
ixs = []
if cancel_all:
ixs.append(
self.program.instruction["cancel_orders"](
None,
None,
None,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": self.get_user_account_public_key(user_id),
"authority": self.signer.public_key,
},
remaining_accounts=remaining_accounts,
),
))
for order_param in order_params:
ix = self.program.instruction["place_perp_order"](
order_param,
Expand Down Expand Up @@ -1025,7 +1026,7 @@ def default_order_params(
price=0,
market_index=market_index,
reduce_only=False,
post_only=PostOnlyParam.NONE(),
post_only=PostOnlyParams.NONE(),
immediate_or_cancel=False,
trigger_price=0,
trigger_condition=OrderTriggerCondition.ABOVE(),
Expand Down
14 changes: 12 additions & 2 deletions src/driftpy/drift_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from driftpy.types import OraclePriceData


class User:
"""This class is the main way to retrieve and inspect user account data."""
class DriftUser:
"""This class is the main way to retrieve and inspect drift user account data."""

def __init__(
self,
Expand Down Expand Up @@ -68,6 +68,16 @@ async def get_perp_market(self, market_index: int) -> PerpMarket:
async def get_user(self) -> User:
return (await self.account_subscriber.get_user_account_and_slot()).data


async def get_open_orders(self,
# market_type: MarketType,
# market_index: int,
# position_direction: PositionDirection
):
user: User = await self.get_user()
return user.orders


async def get_spot_market_liability(
self,
market_index=None,
Expand Down
6 changes: 3 additions & 3 deletions src/driftpy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ class OracleSource:


@_rust_enum
class PostOnlyParam:
class PostOnlyParams:
NONE = constructor()
MUST_POST_ONLY = constructor()
TRY_POST_ONLY = constructor()
Expand Down Expand Up @@ -340,7 +340,7 @@ class OrderParams:
price: int
market_index: int
reduce_only: bool
post_only: PostOnlyParam
post_only: PostOnlyParams
immediate_or_cancel: bool
max_ts: Optional[int]
trigger_price: Optional[int]
Expand All @@ -357,7 +357,7 @@ class ModifyOrderParams:
base_asset_amount: Optional[int]
price: Optional[int]
reduce_only: Optional[bool]
post_only: Optional[PostOnlyParam]
post_only: Optional[PostOnlyParams]
immediate_or_cancel: Optional[bool]
max_ts: Optional[int]
trigger_price: Optional[int]
Expand Down
38 changes: 38 additions & 0 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from driftpy.constants.numeric_constants import (
PRICE_PRECISION,
AMM_RESERVE_PRECISION,
BASE_PRECISION,
QUOTE_PRECISION,
SPOT_BALANCE_PRECISION,
SPOT_WEIGHT_PRECISION,
)
from math import sqrt

from driftpy.drift_user import DriftUser
from driftpy.drift_client import DriftClient
from driftpy.setup.helpers import (
_create_mint,
Expand All @@ -28,6 +30,8 @@
PositionDirection,
OracleSource,
PerpMarket,
OrderType,
OrderParams
# SwapDirection,
)
from driftpy.accounts import (
Expand Down Expand Up @@ -198,6 +202,40 @@ async def test_usdc_deposit(
== USDC_AMOUNT / QUOTE_PRECISION * SPOT_BALANCE_PRECISION
)

@mark.asyncio
async def test_open_orders(
drift_client: Admin,
):

drift_user = DriftUser(drift_client)
user_account = await drift_client.get_user(0)

assert(len(user_account.orders)==32)
assert(user_account.orders[0].market_index == 0)

open_orders = await drift_user.get_open_orders()
assert(len(open_orders)==32)
assert(open_orders==user_account.orders)

order_params: OrderParams = drift_client.default_order_params(
OrderType.MARKET(), 0, int(1 * BASE_PRECISION), PositionDirection.LONG()
)
order_params.user_order_id = 169
ixs = await drift_client.get_place_perp_orders_ix([order_params])
await drift_client.send_ixs(ixs)
await drift_user.account_subscriber.update_cache()
open_orders_after = await drift_user.get_open_orders()
assert(open_orders_after[0].base_asset_amount == BASE_PRECISION)
assert(open_orders_after[0].order_id == 1)
assert(open_orders_after[0].user_order_id == 169)

await drift_client.cancel_order(1, 0)
await drift_user.account_subscriber.update_cache()
open_orders_after2 = await drift_user.get_open_orders()
assert(open_orders_after2[0].base_asset_amount == 0)




@mark.asyncio
async def test_update_curve(
Expand Down

0 comments on commit a037e12

Please sign in to comment.