diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..745ff771 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: +- repo: https://github.com/psf/black + rev: 23.11.0 + hooks: + - id: black + language_version: python3.10 \ No newline at end of file diff --git a/Makefile b/Makefile index 5c58e85e..b82c4419 100644 --- a/Makefile +++ b/Makefile @@ -4,5 +4,8 @@ test: lint: poetry run black --check --diff src tests - poetry run flake8 src tests poetry run mypy src tests + +lintfix: + poetry run black src tests + diff --git a/poetry.lock b/poetry.lock index 6ff0858e..b385e42d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -298,6 +298,22 @@ docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"] tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] +[[package]] +name = "autopep8" +version = "2.0.4" +description = "A tool that automatically formats Python code to conform to the PEP 8 style guide" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "autopep8-2.0.4-py2.py3-none-any.whl", hash = "sha256:067959ca4a07b24dbd5345efa8325f5f58da4298dab0dde0443d5ed765de80cb"}, + {file = "autopep8-2.0.4.tar.gz", hash = "sha256:2913064abd97b3419d1cc83ea71f042cb821f87e45b9c88cad5ad3c4ea87fe0c"}, +] + +[package.dependencies] +pycodestyle = ">=2.10.0" +tomli = {version = "*", markers = "python_version < \"3.11\""} + [[package]] name = "backoff" version = "2.2.1" @@ -1237,48 +1253,57 @@ files = [ [[package]] name = "mypy" -version = "0.931" +version = "1.7.0" description = "Optional static typing for Python" -category = "dev" +category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "mypy-0.931-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3c5b42d0815e15518b1f0990cff7a705805961613e701db60387e6fb663fe78a"}, - {file = "mypy-0.931-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c89702cac5b302f0c5d33b172d2b55b5df2bede3344a2fbed99ff96bddb2cf00"}, - {file = "mypy-0.931-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:300717a07ad09525401a508ef5d105e6b56646f7942eb92715a1c8d610149714"}, - {file = "mypy-0.931-cp310-cp310-win_amd64.whl", hash = "sha256:7b3f6f557ba4afc7f2ce6d3215d5db279bcf120b3cfd0add20a5d4f4abdae5bc"}, - {file = "mypy-0.931-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:1bf752559797c897cdd2c65f7b60c2b6969ffe458417b8d947b8340cc9cec08d"}, - {file = "mypy-0.931-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4365c60266b95a3f216a3047f1d8e3f895da6c7402e9e1ddfab96393122cc58d"}, - {file = "mypy-0.931-cp36-cp36m-win_amd64.whl", hash = "sha256:1b65714dc296a7991000b6ee59a35b3f550e0073411ac9d3202f6516621ba66c"}, - {file = "mypy-0.931-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e839191b8da5b4e5d805f940537efcaa13ea5dd98418f06dc585d2891d228cf0"}, - {file = "mypy-0.931-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:50c7346a46dc76a4ed88f3277d4959de8a2bd0a0fa47fa87a4cde36fe247ac05"}, - {file = "mypy-0.931-cp37-cp37m-win_amd64.whl", hash = "sha256:d8f1ff62f7a879c9fe5917b3f9eb93a79b78aad47b533911b853a757223f72e7"}, - {file = "mypy-0.931-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f9fe20d0872b26c4bba1c1be02c5340de1019530302cf2dcc85c7f9fc3252ae0"}, - {file = "mypy-0.931-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1b06268df7eb53a8feea99cbfff77a6e2b205e70bf31743e786678ef87ee8069"}, - {file = "mypy-0.931-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8c11003aaeaf7cc2d0f1bc101c1cc9454ec4cc9cb825aef3cafff8a5fdf4c799"}, - {file = "mypy-0.931-cp38-cp38-win_amd64.whl", hash = "sha256:d9d2b84b2007cea426e327d2483238f040c49405a6bf4074f605f0156c91a47a"}, - {file = "mypy-0.931-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ff3bf387c14c805ab1388185dd22d6b210824e164d4bb324b195ff34e322d166"}, - {file = "mypy-0.931-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5b56154f8c09427bae082b32275a21f500b24d93c88d69a5e82f3978018a0266"}, - {file = "mypy-0.931-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8ca7f8c4b1584d63c9a0f827c37ba7a47226c19a23a753d52e5b5eddb201afcd"}, - {file = "mypy-0.931-cp39-cp39-win_amd64.whl", hash = "sha256:74f7eccbfd436abe9c352ad9fb65872cc0f1f0a868e9d9c44db0893440f0c697"}, - {file = "mypy-0.931-py3-none-any.whl", hash = "sha256:1171f2e0859cfff2d366da2c7092b06130f232c636a3f7301e3feb8b41f6377d"}, - {file = "mypy-0.931.tar.gz", hash = "sha256:0038b21890867793581e4cb0d810829f5fd4441aa75796b53033af3aa30430ce"}, + {file = "mypy-1.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5da84d7bf257fd8f66b4f759a904fd2c5a765f70d8b52dde62b521972a0a2357"}, + {file = "mypy-1.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a3637c03f4025f6405737570d6cbfa4f1400eb3c649317634d273687a09ffc2f"}, + {file = "mypy-1.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b633f188fc5ae1b6edca39dae566974d7ef4e9aaaae00bc36efe1f855e5173ac"}, + {file = "mypy-1.7.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d6ed9a3997b90c6f891138e3f83fb8f475c74db4ccaa942a1c7bf99e83a989a1"}, + {file = "mypy-1.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:1fe46e96ae319df21359c8db77e1aecac8e5949da4773c0274c0ef3d8d1268a9"}, + {file = "mypy-1.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:df67fbeb666ee8828f675fee724cc2cbd2e4828cc3df56703e02fe6a421b7401"}, + {file = "mypy-1.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a79cdc12a02eb526d808a32a934c6fe6df07b05f3573d210e41808020aed8b5d"}, + {file = "mypy-1.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f65f385a6f43211effe8c682e8ec3f55d79391f70a201575def73d08db68ead1"}, + {file = "mypy-1.7.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e81ffd120ee24959b449b647c4b2fbfcf8acf3465e082b8d58fd6c4c2b27e46"}, + {file = "mypy-1.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:f29386804c3577c83d76520abf18cfcd7d68264c7e431c5907d250ab502658ee"}, + {file = "mypy-1.7.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:87c076c174e2c7ef8ab416c4e252d94c08cd4980a10967754f91571070bf5fbe"}, + {file = "mypy-1.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6cb8d5f6d0fcd9e708bb190b224089e45902cacef6f6915481806b0c77f7786d"}, + {file = "mypy-1.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93e76c2256aa50d9c82a88e2f569232e9862c9982095f6d54e13509f01222fc"}, + {file = "mypy-1.7.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cddee95dea7990e2215576fae95f6b78a8c12f4c089d7e4367564704e99118d3"}, + {file = "mypy-1.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:d01921dbd691c4061a3e2ecdbfbfad029410c5c2b1ee88946bf45c62c6c91210"}, + {file = "mypy-1.7.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:185cff9b9a7fec1f9f7d8352dff8a4c713b2e3eea9c6c4b5ff7f0edf46b91e41"}, + {file = "mypy-1.7.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7a7b1e399c47b18feb6f8ad4a3eef3813e28c1e871ea7d4ea5d444b2ac03c418"}, + {file = "mypy-1.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc9fe455ad58a20ec68599139ed1113b21f977b536a91b42bef3ffed5cce7391"}, + {file = "mypy-1.7.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d0fa29919d2e720c8dbaf07d5578f93d7b313c3e9954c8ec05b6d83da592e5d9"}, + {file = "mypy-1.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:2b53655a295c1ed1af9e96b462a736bf083adba7b314ae775563e3fb4e6795f5"}, + {file = "mypy-1.7.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c1b06b4b109e342f7dccc9efda965fc3970a604db70f8560ddfdee7ef19afb05"}, + {file = "mypy-1.7.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:bf7a2f0a6907f231d5e41adba1a82d7d88cf1f61a70335889412dec99feeb0f8"}, + {file = "mypy-1.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:551d4a0cdcbd1d2cccdcc7cb516bb4ae888794929f5b040bb51aae1846062901"}, + {file = "mypy-1.7.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:55d28d7963bef00c330cb6461db80b0b72afe2f3c4e2963c99517cf06454e665"}, + {file = "mypy-1.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:870bd1ffc8a5862e593185a4c169804f2744112b4a7c55b93eb50f48e7a77010"}, + {file = "mypy-1.7.0-py3-none-any.whl", hash = "sha256:96650d9a4c651bc2a4991cf46f100973f656d69edc7faf91844e87fe627f7e96"}, + {file = "mypy-1.7.0.tar.gz", hash = "sha256:1e280b5697202efa698372d2f39e9a6713a0395a756b1c6bd48995f8d72690dc"}, ] [package.dependencies] -mypy-extensions = ">=0.4.3" -tomli = ">=1.1.0" -typing-extensions = ">=3.10" +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=4.1.0" [package.extras] dmypy = ["psutil (>=4.0)"] -python2 = ["typed-ast (>=1.4.0,<2)"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] [[package]] name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." -category = "dev" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -2292,4 +2317,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "fc879d686ae6b2d2eedcde15487706afc24e62d3c5fd2a2ca0181867f18f7593" +content-hash = "5fc096eea3c4b4688a49b7e41849119f07b949b6f941b6f65dc8579ec1b289e8" diff --git a/pyproject.toml b/pyproject.toml index a0b26c83..71449c0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,17 +79,19 @@ websockets = "10.4" yarl = "1.8.2" zstandard = "0.18.0" jinja2 = "<3.1" +mypy = "^1.7.0" [tool.poetry.dev-dependencies] pytest = "^7.2.0" flake8 = "6.0.0" -mypy = "^0.931" black = "^23.3.0" pytest-asyncio = "^0.21.0" mkdocs = "^1.3.0" mkdocstrings = "^0.17.0" mkdocs-material = "^8.1.8" bump2version = "^1.0.1" +autopep8 = "^2.0.4" +mypy = "^1.7.0" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/src/driftpy/_types.py b/src/driftpy/_types.py index 1e11393d..531311e6 100644 --- a/src/driftpy/_types.py +++ b/src/driftpy/_types.py @@ -1,3 +1,5 @@ +from driftpy.constants.numeric_constants import SPOT_RATE_PRECISION +from driftpy.types import OracleSource from typing import Optional, Any from dataclasses import dataclass from sumtypes import constructor # type: ignore @@ -287,9 +289,10 @@ class MarketPosition: padding3: int padding4: int - ## dw why this doesnt register :( + # dw why this doesnt register :( # def is_available(self): - # return self.base_asset_amount == 0 and self.open_orders == 0 and self.lp_shares == 0 + # return self.base_asset_amount == 0 and self.open_orders == 0 and + # self.lp_shares == 0 @dataclass @@ -377,14 +380,14 @@ class AMM: quote_asset_amount_long: int = 0 quote_asset_amount_short: int = 0 - ## lp stuff + # lp stuff cumulative_funding_payment_per_lp: int = 0 cumulative_fee_per_lp: int = 0 cumulative_base_asset_amount_with_amm_per_lp: int = 0 lp_cooldown_time: int = 0 user_lp_shares: int = 0 - ## funding + # funding last_funding_rate: int = 0 last_funding_rate_ts: int = 0 funding_period: int = 0 @@ -397,11 +400,11 @@ class AMM: last_mark_price_twap: int = 0 last_mark_price_twap_ts: int = 0 - ## trade constraints + # trade constraints minimum_quote_asset_trade_size: int = 0 base_asset_amount_step_size: int = 0 - ## market making + # market making base_spread: int = 0 long_spread: int = 0 short_spread: int = 0 @@ -420,7 +423,7 @@ class AMM: short_intensity_volume: int = 0 curve_update_intensity: int = 0 - ## fee tracking + # fee tracking total_fee: int = 0 total_mm_fee: int = 0 total_exchange_fee: int = 0 @@ -462,10 +465,6 @@ class Market: padding4: int = 0 -from driftpy.types import OracleSource -from driftpy.constants.numeric_constants import SPOT_RATE_PRECISION - - @dataclass class SpotMarket: mint: PublicKey # this diff --git a/src/driftpy/accounts/__init__.py b/src/driftpy/accounts/__init__.py index 4c8643a9..a7722442 100644 --- a/src/driftpy/accounts/__init__.py +++ b/src/driftpy/accounts/__init__.py @@ -1,2 +1,2 @@ from .get_accounts import * -from .types import * \ No newline at end of file +from .types import * diff --git a/src/driftpy/accounts/cache/__init__.py b/src/driftpy/accounts/cache/__init__.py index f31516fc..58e298b4 100644 --- a/src/driftpy/accounts/cache/__init__.py +++ b/src/driftpy/accounts/cache/__init__.py @@ -1,2 +1,2 @@ from .drift_client import * -from .user import * \ No newline at end of file +from .user import * diff --git a/src/driftpy/accounts/cache/drift_client.py b/src/driftpy/accounts/cache/drift_client.py index 37db66e3..29c20eb9 100644 --- a/src/driftpy/accounts/cache/drift_client.py +++ b/src/driftpy/accounts/cache/drift_client.py @@ -2,8 +2,11 @@ from solders.pubkey import Pubkey from solana.rpc.commitment import Commitment -from driftpy.accounts import get_state_account_and_slot, get_spot_market_account_and_slot, \ - get_perp_market_account_and_slot +from driftpy.accounts import ( + get_state_account_and_slot, + get_spot_market_account_and_slot, + get_perp_market_account_and_slot, +) from driftpy.accounts.oracle import get_oracle_price_data_and_slot from driftpy.accounts.types import DriftClientAccountSubscriber, DataAndSlot from typing import Optional @@ -28,30 +31,37 @@ async def update_cache(self): spot_markets = [] for i in range(state_and_slot.data.number_of_spot_markets): - spot_market_and_slot = await get_spot_market_account_and_slot(self.program, i) + spot_market_and_slot = await get_spot_market_account_and_slot( + self.program, i + ) spot_markets.append(spot_market_and_slot) oracle_price_data_and_slot = await get_oracle_price_data_and_slot( self.program.provider.connection, spot_market_and_slot.data.oracle, - spot_market_and_slot.data.oracle_source - + spot_market_and_slot.data.oracle_source, ) - oracle_data[str(spot_market_and_slot.data.oracle)] = oracle_price_data_and_slot + oracle_data[ + str(spot_market_and_slot.data.oracle) + ] = oracle_price_data_and_slot self.cache["spot_markets"] = spot_markets perp_markets = [] for i in range(state_and_slot.data.number_of_markets): - perp_market_and_slot = await get_perp_market_account_and_slot(self.program, i) + perp_market_and_slot = await get_perp_market_account_and_slot( + self.program, i + ) perp_markets.append(perp_market_and_slot) oracle_price_data_and_slot = await get_oracle_price_data_and_slot( self.program.provider.connection, perp_market_and_slot.data.amm.oracle, - perp_market_and_slot.data.amm.oracle_source + perp_market_and_slot.data.amm.oracle_source, ) - oracle_data[str(perp_market_and_slot.data.amm.oracle)] = oracle_price_data_and_slot + oracle_data[ + str(perp_market_and_slot.data.amm.oracle) + ] = oracle_price_data_and_slot self.cache["perp_markets"] = perp_markets @@ -61,15 +71,21 @@ async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]: await self.cache_if_needed() return self.cache["state"] - async def get_perp_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[PerpMarket]]: + async def get_perp_market_and_slot( + self, market_index: int + ) -> Optional[DataAndSlot[PerpMarket]]: await self.cache_if_needed() return self.cache["perp_markets"][market_index] - async def get_spot_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[SpotMarket]]: + async def get_spot_market_and_slot( + self, market_index: int + ) -> Optional[DataAndSlot[SpotMarket]]: await self.cache_if_needed() return self.cache["spot_markets"][market_index] - async def get_oracle_data_and_slot(self, oracle: Pubkey) -> Optional[DataAndSlot[OraclePriceData]]: + async def get_oracle_data_and_slot( + self, oracle: Pubkey + ) -> Optional[DataAndSlot[OraclePriceData]]: await self.cache_if_needed() return self.cache["oracle_price_data"][str(oracle)] diff --git a/src/driftpy/accounts/cache/user.py b/src/driftpy/accounts/cache/user.py index 357f393f..18a3c96f 100644 --- a/src/driftpy/accounts/cache/user.py +++ b/src/driftpy/accounts/cache/user.py @@ -10,7 +10,12 @@ class CachedUserAccountSubscriber(UserAccountSubscriber): - def __init__(self, user_pubkey: Pubkey, program: Program, commitment: Commitment = "confirmed"): + def __init__( + self, + user_pubkey: Pubkey, + program: Program, + commitment: Commitment = "confirmed", + ): self.program = program self.commitment = commitment self.user_pubkey = user_pubkey diff --git a/src/driftpy/accounts/get_accounts.py b/src/driftpy/accounts/get_accounts.py index 9a56dfcb..31fc8837 100644 --- a/src/driftpy/accounts/get_accounts.py +++ b/src/driftpy/accounts/get_accounts.py @@ -9,8 +9,9 @@ from .types import DataAndSlot, T -async def get_account_data_and_slot(address: Pubkey, program: Program, commitment: Commitment = "processed") -> Optional[ - DataAndSlot[T]]: +async def get_account_data_and_slot( + address: Pubkey, program: Program, commitment: Commitment = "processed" +) -> Optional[DataAndSlot[T]]: account_info = await program.provider.connection.get_account_info( address, encoding="base64", @@ -38,7 +39,7 @@ async def get_state_account(program: Program) -> State: async def get_if_stake_account( - program: Program, authority: Pubkey, spot_market_index: int + program: Program, authority: Pubkey, spot_market_index: int ) -> InsuranceFundStake: if_stake_pk = get_insurance_fund_stake_public_key( program.program_id, authority, spot_market_index @@ -48,8 +49,8 @@ async def get_if_stake_account( async def get_user_stats_account( - program: Program, - authority: Pubkey, + program: Program, + authority: Pubkey, ) -> UserStats: user_stats_public_key = get_user_stats_account_public_key( program.program_id, @@ -58,21 +59,27 @@ async def get_user_stats_account( response = await program.account["UserStats"].fetch(user_stats_public_key) return cast(UserStats, response) + async def get_user_account_and_slot( - program: Program, - user_public_key: Pubkey, + program: Program, + user_public_key: Pubkey, ) -> DataAndSlot[User]: return await get_account_data_and_slot(user_public_key, program) + async def get_user_account( - program: Program, - user_public_key: Pubkey, + program: Program, + user_public_key: Pubkey, ) -> User: return (await get_user_account_and_slot(program, user_public_key)).data -async def get_perp_market_account_and_slot(program: Program, market_index: int) -> Optional[DataAndSlot[PerpMarket]]: - perp_market_public_key = get_perp_market_public_key(program.program_id, market_index) +async def get_perp_market_account_and_slot( + program: Program, market_index: int +) -> Optional[DataAndSlot[PerpMarket]]: + perp_market_public_key = get_perp_market_public_key( + program.program_id, market_index + ) return await get_account_data_and_slot(perp_market_public_key, program) @@ -85,7 +92,7 @@ async def get_all_perp_market_accounts(program: Program) -> list[ProgramAccount] async def get_spot_market_account_and_slot( - program: Program, spot_market_index: int + program: Program, spot_market_index: int ) -> DataAndSlot[SpotMarket]: spot_market_public_key = get_spot_market_public_key( program.program_id, spot_market_index @@ -94,7 +101,7 @@ async def get_spot_market_account_and_slot( async def get_spot_market_account( - program: Program, spot_market_index: int + program: Program, spot_market_index: int ) -> SpotMarket: return (await get_spot_market_account_and_slot(program, spot_market_index)).data diff --git a/src/driftpy/accounts/oracle.py b/src/driftpy/accounts/oracle.py index f6c8c373..5e2a28c8 100644 --- a/src/driftpy/accounts/oracle.py +++ b/src/driftpy/accounts/oracle.py @@ -10,20 +10,25 @@ import base64 import struct + def convert_pyth_price(price, scale=1): return int(price * PRICE_PRECISION * scale) -async def get_oracle_price_data_and_slot(connection: AsyncClient, address: Pubkey, oracle_source=OracleSource.PYTH()) -> DataAndSlot[ - OraclePriceData]: - if 'Pyth' in str(oracle_source): + +async def get_oracle_price_data_and_slot( + connection: AsyncClient, address: Pubkey, oracle_source=OracleSource.PYTH() +) -> DataAndSlot[OraclePriceData]: + if "Pyth" in str(oracle_source): rpc_reponse = await connection.get_account_info(address) rpc_response_slot = rpc_reponse.context.slot - (pyth_price_info, last_slot, twac, twap) = await _parse_pyth_price_info(rpc_reponse) + (pyth_price_info, last_slot, twac, twap) = await _parse_pyth_price_info( + rpc_reponse + ) scale = 1 - if '1K' in str(oracle_source): + if "1K" in str(oracle_source): scale = 1e3 - elif '1M' in str(oracle_source): + elif "1M" in str(oracle_source): scale = 1e6 oracle_data = OraclePriceData( @@ -36,29 +41,41 @@ async def get_oracle_price_data_and_slot(connection: AsyncClient, address: Pubke ) return DataAndSlot(data=oracle_data, slot=rpc_response_slot) - elif 'Quote' in str(oracle_source): - return DataAndSlot(data=OraclePriceData(PRICE_PRECISION, 0, 1, 1, 0, True), slot=0) + elif "Quote" in str(oracle_source): + return DataAndSlot( + data=OraclePriceData(PRICE_PRECISION, 0, 1, 1, 0, True), slot=0 + ) else: - raise NotImplementedError('Unsupported Oracle Source', str(oracle_source)) + raise NotImplementedError("Unsupported Oracle Source", str(oracle_source)) -async def _parse_pyth_price_info(resp: GetAccountInfoResp) -> (PythPriceInfo, int, int, int): + +async def _parse_pyth_price_info( + resp: GetAccountInfoResp, +) -> (PythPriceInfo, int, int, int): buffer = resp.value.data offset = _ACCOUNT_HEADER_BYTES _, exponent, _ = struct.unpack_from(" Optional[DataAndSlot[State]]: pass @abstractmethod - async def get_perp_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[PerpMarket]]: + async def get_perp_market_and_slot( + self, market_index: int + ) -> Optional[DataAndSlot[PerpMarket]]: pass @abstractmethod - async def get_spot_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[SpotMarket]]: + async def get_spot_market_and_slot( + self, market_index: int + ) -> Optional[DataAndSlot[SpotMarket]]: pass @abstractmethod - async def get_oracle_data_and_slot(self, oracle: Pubkey) -> Optional[DataAndSlot[OraclePriceData]]: + async def get_oracle_data_and_slot( + self, oracle: Pubkey + ) -> Optional[DataAndSlot[OraclePriceData]]: pass + class UserAccountSubscriber: @abstractmethod async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]: - pass \ No newline at end of file + pass diff --git a/src/driftpy/addresses.py b/src/driftpy/addresses.py index 048450bf..8aa58cf1 100644 --- a/src/driftpy/addresses.py +++ b/src/driftpy/addresses.py @@ -77,9 +77,7 @@ def get_user_stats_account_public_key( program_id: Pubkey, authority: Pubkey, ) -> Pubkey: - return Pubkey.find_program_address( - [b"user_stats", bytes(authority)], program_id - )[0] + return Pubkey.find_program_address([b"user_stats", bytes(authority)], program_id)[0] def get_user_account_public_key( diff --git a/src/driftpy/admin.py b/src/driftpy/admin.py index 9c8c634b..ddd4bb4e 100644 --- a/src/driftpy/admin.py +++ b/src/driftpy/admin.py @@ -1,4 +1,3 @@ - from solders.pubkey import Pubkey from solders.signature import Signature from solders.keypair import Keypair @@ -74,9 +73,7 @@ async def initialize( "admin": self.authority, "state": state_public_key, "quote_asset_mint": usdc_mint, - "drift_signer": get_drift_client_signer_public_key( - self.program_id - ), + "drift_signer": get_drift_client_signer_public_key(self.program_id), "rent": RENT, "system_program": ID, "token_program": TOKEN_PROGRAM_ID, @@ -181,9 +178,7 @@ async def initialize_spot_market( "spot_market": spot_public_key, "spot_market_vault": spot_vault_public_key, "insurance_fund_vault": insurance_vault_public_key, - "drift_signer": get_drift_client_signer_public_key( - self.program_id - ), + "drift_signer": get_drift_client_signer_public_key(self.program_id), "spot_market_mint": mint, "oracle": oracle, "rent": RENT, diff --git a/src/driftpy/constants/config.py b/src/driftpy/constants/config.py index 3ac44389..0329373d 100644 --- a/src/driftpy/constants/config.py +++ b/src/driftpy/constants/config.py @@ -25,7 +25,9 @@ class Config: drift_client_program_id=Pubkey.from_string( "dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH" ), - usdc_mint_address=Pubkey.from_string("8zGuJQqwhZafTah7Uc7Z4tXRnguqkn5KLFAP8oV6PHe2"), + usdc_mint_address=Pubkey.from_string( + "8zGuJQqwhZafTah7Uc7Z4tXRnguqkn5KLFAP8oV6PHe2" + ), default_http="https://api.devnet.solana.com", default_ws="wss://api.devnet.solana.com", markets=devnet_markets, @@ -39,7 +41,9 @@ class Config: drift_client_program_id=Pubkey.from_string( "dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH" ), - usdc_mint_address=Pubkey.from_string("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v"), + usdc_mint_address=Pubkey.from_string( + "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v" + ), default_http="https://api.mainnet-beta.solana.com", default_ws="wss://api.mainnet-beta.solana.com", markets=mainnet_markets, diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index f364ef6c..d96e232e 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -33,12 +33,21 @@ DEFAULT_USER_NAME = "Main Account" + class DriftClient: """This class is the main way to interact with Drift Protocol including depositing, opening new positions, closing positions, placing orders, etc. """ - def __init__(self, program: Program, signer: Keypair = None, authority: Pubkey = None, account_subscriber: Optional[DriftClientAccountSubscriber] = None, tx_params: Optional[TxParams] = None, tx_version: Optional[TransactionVersion] = None): + def __init__( + self, + program: Program, + signer: Keypair = None, + authority: Pubkey = None, + account_subscriber: Optional[DriftClientAccountSubscriber] = None, + tx_params: Optional[TxParams] = None, + tx_version: Optional[TransactionVersion] = None, + ): """Initializes the drift client object -- likely want to use the .from_config method instead of this one Args: @@ -112,7 +121,9 @@ def get_user_account_public_key(self, user_id=0) -> Pubkey: return get_user_account_public_key(self.program_id, self.authority, user_id) async def get_user(self, user_id=0) -> User: - return await get_user_account(self.program, self.get_user_account_public_key(user_id)) + return await get_user_account( + self.program, self.get_user_account_public_key(user_id) + ) def get_state_public_key(self): return get_state_public_key(self.program_id) @@ -122,19 +133,25 @@ def get_user_stats_public_key(self): async def get_state(self) -> Optional[State]: state_and_slot = await self.account_subscriber.get_state_account_and_slot() - return getattr(state_and_slot, 'data', None) + return getattr(state_and_slot, "data", None) async def get_perp_market(self, market_index: int) -> Optional[PerpMarket]: - perp_market_and_slot = await self.account_subscriber.get_perp_market_and_slot(market_index) - return getattr(perp_market_and_slot, 'data', None) + perp_market_and_slot = await self.account_subscriber.get_perp_market_and_slot( + market_index + ) + return getattr(perp_market_and_slot, "data", None) async def get_spot_market(self, market_index: int) -> Optional[SpotMarket]: - spot_market_and_slot = await self.account_subscriber.get_spot_market_and_slot(market_index) - return getattr(spot_market_and_slot, 'data', None) + spot_market_and_slot = await self.account_subscriber.get_spot_market_and_slot( + market_index + ) + return getattr(spot_market_and_slot, "data", None) async def get_oracle_price_data(self, oracle: Pubkey) -> Optional[OraclePriceData]: - oracle_price_data_and_slot = await self.account_subscriber.get_oracle_data_and_slot(oracle) - return getattr(oracle_price_data_and_slot, 'data', None) + oracle_price_data_and_slot = ( + await self.account_subscriber.get_oracle_data_and_slot(oracle) + ) + return getattr(oracle_price_data_and_slot, "data", None) async def send_ixs( self, @@ -150,13 +167,15 @@ async def send_ixs( if self.tx_params.compute_units_price is not None: ixs.insert(1, set_compute_unit_price(self.tx_params.compute_units_price)) - latest_blockhash = (await self.program.provider.connection.get_latest_blockhash()).value.blockhash + latest_blockhash = ( + await self.program.provider.connection.get_latest_blockhash() + ).value.blockhash if self.tx_version == Legacy: tx = Transaction( instructions=ixs, recent_blockhash=latest_blockhash, - fee_payer=self.signer.pubkey() + fee_payer=self.signer.pubkey(), ) tx.sign_partial(self.signer) @@ -164,9 +183,7 @@ async def send_ixs( if signers is not None: [tx.sign_partial(signer) for signer in signers] elif self.tx_version == 0: - msg = MessageV0.try_compile( - self.signer.pubkey(), ixs, [], latest_blockhash - ) + msg = MessageV0.try_compile(self.signer.pubkey(), ixs, [], latest_blockhash) tx = VersionedTransaction(msg, [self.signer]) else: raise NotImplementedError("unknown tx version", self.tx_version) @@ -268,7 +285,9 @@ async def get_remaining_accounts( accounts = [] for pk, id in zip(authority, user_id): - user_public_key = get_user_account_public_key(self.program.program_id, pk, id) + user_public_key = get_user_account_public_key( + self.program.program_id, pk, id + ) user_account = await get_user_account(self.program, user_public_key) accounts.append(user_account) @@ -1528,9 +1547,7 @@ async def get_cancel_request_remove_insurance_fund_stake_ix( "insurance_fund_vault": get_insurance_fund_vault_public_key( self.program_id, spot_market_index ), - "drift_signer": get_drift_client_signer_public_key( - self.program_id - ), + "drift_signer": get_drift_client_signer_public_key(self.program_id), "user_token_account": self.spot_market_atas[spot_market_index], "token_program": TOKEN_PROGRAM_ID, }, @@ -1567,9 +1584,7 @@ async def get_remove_insurance_fund_stake_ix(self, spot_market_index: int): "insurance_fund_vault": get_insurance_fund_vault_public_key( self.program_id, spot_market_index ), - "drift_signer": get_drift_client_signer_public_key( - self.program_id - ), + "drift_signer": get_drift_client_signer_public_key(self.program_id), "user_token_account": self.spot_market_atas[spot_market_index], "token_program": TOKEN_PROGRAM_ID, }, @@ -1617,9 +1632,7 @@ async def get_add_insurance_fund_stake_ix( "insurance_fund_vault": get_insurance_fund_vault_public_key( self.program_id, spot_market_index ), - "drift_signer": get_drift_client_signer_public_key( - self.program_id - ), + "drift_signer": get_drift_client_signer_public_key(self.program_id), "user_token_account": self.spot_market_atas[spot_market_index], "token_program": TOKEN_PROGRAM_ID, }, @@ -1715,9 +1728,7 @@ async def settle_revenue_to_insurance_fund(self, spot_market_index: int): "spot_market_vault": get_spot_market_vault_public_key( self.program_id, spot_market_index ), - "drift_signer": get_drift_client_signer_public_key( - self.program_id - ), + "drift_signer": get_drift_client_signer_public_key(self.program_id), "insurance_fund_vault": get_insurance_fund_vault_public_key( self.program_id, spot_market_index, diff --git a/src/driftpy/drift_user.py b/src/driftpy/drift_user.py index c9934636..041a64f2 100644 --- a/src/driftpy/drift_user.py +++ b/src/driftpy/drift_user.py @@ -35,18 +35,25 @@ def __init__( self.connection = self.program.provider.connection self.subaccount_id = subaccount_id - self.user_public_key = get_user_account_public_key(self.program.program_id, self.authority, self.subaccount_id) + self.user_public_key = get_user_account_public_key( + self.program.program_id, self.authority, self.subaccount_id + ) if account_subscriber is None: - account_subscriber = CachedUserAccountSubscriber(self.user_public_key, self.program) + account_subscriber = CachedUserAccountSubscriber( + self.user_public_key, self.program + ) self.account_subscriber = account_subscriber - - async def get_spot_oracle_data(self, spot_market: SpotMarket) -> Optional[OraclePriceData]: + async def get_spot_oracle_data( + self, spot_market: SpotMarket + ) -> Optional[OraclePriceData]: return await self.drift_client.get_oracle_price_data(spot_market.oracle) - async def get_perp_oracle_data(self, perp_market: PerpMarket) -> Optional[OraclePriceData]: + async def get_perp_oracle_data( + self, perp_market: PerpMarket + ) -> Optional[OraclePriceData]: return await self.drift_client.get_oracle_price_data(perp_market.amm.oracle) async def get_state(self) -> State: @@ -356,11 +363,11 @@ async def get_spot_market_asset_value( if not include_open_orders: token_amount = get_token_amount( - position.scaled_balance, spot_market, position.balance_type - ) + position.scaled_balance, spot_market, position.balance_type + ) spot_token_value = get_spot_asset_value( - token_amount, oracle_data, spot_market, margin_category - ) + token_amount, oracle_data, spot_market, margin_category + ) match str(position.balance_type): case "SpotBalanceType.Deposit()": spot_token_value *= 1 diff --git a/src/driftpy/math/amm.py b/src/driftpy/math/amm.py index 7470143f..ccb6be43 100644 --- a/src/driftpy/math/amm.py +++ b/src/driftpy/math/amm.py @@ -169,7 +169,7 @@ def get_swap_direction( def calculate_budgeted_repeg(amm, cost, target_px=None, pay_only=False): - if target_px == None: + if target_px is None: target_px = amm.last_oracle_price # / 1e10 assert amm.last_oracle_price != 0 diff --git a/src/driftpy/math/market.py b/src/driftpy/math/market.py index c793bbe9..410d800d 100644 --- a/src/driftpy/math/market.py +++ b/src/driftpy/math/market.py @@ -90,7 +90,8 @@ def calculate_candidate_amm(market, oracle_price=None): base_scale = 1 quote_scale = 1 - budget_cost = None # max(0, (market.amm.total_fee_minus_distributions/1e6)/2) + # max(0, (market.amm.total_fee_minus_distributions/1e6)/2) + budget_cost = None fee_pool = (market.amm.total_fee_minus_distributions / QUOTE_PRECISION) - ( market.amm.total_fee / QUOTE_PRECISION ) / 2 diff --git a/src/driftpy/math/repeg.py b/src/driftpy/math/repeg.py index 1539017c..dbf8c4ce 100644 --- a/src/driftpy/math/repeg.py +++ b/src/driftpy/math/repeg.py @@ -1,3 +1,4 @@ +from driftpy.constants.numeric_constants import * from driftpy.math.amm import calculate_terminal_price, calculate_budgeted_repeg from driftpy.math.positions import calculate_base_asset_value, calculate_position_pnl from driftpy.types import PerpPosition @@ -216,10 +217,6 @@ def calculate_buyout_cost(market, market_index, new_peg, sqrt_k): return cost / 1e6, marketNewK -from driftpy.types import AMM -from driftpy.constants.numeric_constants import * - - def calculate_repeg_cost(amm: AMM, new_peg: int) -> int: dqar = amm.quote_asset_reserve - amm.terminal_quote_asset_reserve cost = ( diff --git a/src/driftpy/setup/helpers.py b/src/driftpy/setup/helpers.py index 6d9c07a4..b66402db 100644 --- a/src/driftpy/setup/helpers.py +++ b/src/driftpy/setup/helpers.py @@ -1,3 +1,4 @@ +from solana.rpc.async_api import AsyncClient from base64 import b64decode from dataclasses import dataclass from typing import Optional @@ -63,9 +64,7 @@ async def _airdrop_user( ) -> tuple[Keypair, Signature]: if user is None: user = Keypair() - resp = await provider.connection.request_airdrop( - user.pubkey(), 100_0 * 1000000000 - ) + resp = await provider.connection.request_airdrop(user.pubkey(), 100_0 * 1000000000) tx_sig = resp.value return user, tx_sig @@ -94,7 +93,9 @@ async def _create_mint(provider: Provider) -> Keypair: fake_tx = Transaction( instructions=[create_create_mint_account_ix, init_collateral_mint_ix], - recent_blockhash=(await provider.connection.get_latest_blockhash()).value.blockhash, + recent_blockhash=( + await provider.connection.get_latest_blockhash() + ).value.blockhash, fee_payer=provider.wallet.public_key, ) @@ -195,7 +196,9 @@ async def _create_and_mint_user_usdc( for ix in mint_tx.instructions: ata_tx.add(ix) - ata_tx.recent_blockhash = (await provider.connection.get_latest_blockhash()).value.blockhash + ata_tx.recent_blockhash = ( + await provider.connection.get_latest_blockhash() + ).value.blockhash ata_tx.fee_payer = provider.wallet.payer.pubkey() ata_tx.sign_partial(usdc_account) @@ -307,9 +310,6 @@ async def get_feed_data(oracle_program: Program, price_feed: Pubkey) -> PriceDat return parse_price_data(info_resp.value.data) -from solana.rpc.async_api import AsyncClient - - async def get_oracle_data( connection: AsyncClient, oracle_addr: Pubkey, diff --git a/src/driftpy/types.py b/src/driftpy/types.py index 8b037de7..210bb822 100644 --- a/src/driftpy/types.py +++ b/src/driftpy/types.py @@ -4,48 +4,57 @@ from sumtypes import constructor from typing import Optional + @_rust_enum class SwapDirection: ADD = constructor() REMOVE = constructor() - + + @_rust_enum class ModifyOrderId: USER_ORDER_ID = constructor() ORDER_ID = constructor() - + + @_rust_enum class PositionDirection: LONG = constructor() SHORT = constructor() - + + @_rust_enum class SpotFulfillmentType: SERUM_V3 = constructor() MATCH = constructor() PHOENIX_V1 = constructor() - + + @_rust_enum class SwapReduceOnly: IN = constructor() OUT = constructor() - + + @_rust_enum class TwapPeriod: FUNDING_PERIOD = constructor() FIVE_MIN = constructor() - + + @_rust_enum class LiquidationMultiplierType: DISCOUNT = constructor() PREMIUM = constructor() - + + @_rust_enum class MarginRequirementType: INITIAL = constructor() FILL = constructor() MAINTENANCE = constructor() - + + @_rust_enum class OracleValidity: INVALID = constructor() @@ -55,7 +64,8 @@ class OracleValidity: INSUFFICIENT_DATA_POINTS = constructor() STALE_FOR_A_M_M = constructor() VALID = constructor() - + + @_rust_enum class DriftAction: UPDATE_FUNDING = constructor() @@ -67,7 +77,8 @@ class DriftAction: MARGIN_CALC = constructor() UPDATE_TWAP = constructor() UPDATE_A_M_M_CURVE = constructor() - + + @_rust_enum class PositionUpdateType: OPEN = constructor() @@ -75,17 +86,20 @@ class PositionUpdateType: REDUCE = constructor() CLOSE = constructor() FLIP = constructor() - + + @_rust_enum class DepositExplanation: NONE = constructor() TRANSFER = constructor() - + + @_rust_enum class DepositDirection: DEPOSIT = constructor() WITHDRAW = constructor() - + + @_rust_enum class OrderAction: PLACE = constructor() @@ -93,7 +107,8 @@ class OrderAction: FILL = constructor() TRIGGER = constructor() EXPIRE = constructor() - + + @_rust_enum class OrderActionExplanation: NONE = constructor() @@ -114,13 +129,15 @@ class OrderActionExplanation: ORDER_FILL_WITH_PHOENIX = constructor() ORDER_FILLED_WITH_A_M_M_JIT_L_P_SPLIT = constructor() ORDER_FILLED_WITH_L_P_JIT = constructor() - + + @_rust_enum class LPAction: ADD_LIQUIDITY = constructor() REMOVE_LIQUIDITY = constructor() SETTLE_LIQUIDITY = constructor() - + + @_rust_enum class LiquidationType: LIQUIDATE_PERP = constructor() @@ -129,12 +146,14 @@ class LiquidationType: LIQUIDATE_PERP_PNL_FOR_DEPOSIT = constructor() PERP_BANKRUPTCY = constructor() SPOT_BANKRUPTCY = constructor() - + + @_rust_enum class SettlePnlExplanation: NONE = constructor() EXPIRED_POSITION = constructor() - + + @_rust_enum class StakeAction: STAKE = constructor() @@ -143,28 +162,33 @@ class StakeAction: UNSTAKE = constructor() UNSTAKE_TRANSFER = constructor() STAKE_TRANSFER = constructor() - + + @_rust_enum class FillMode: FILL = constructor() PLACE_AND_MAKE = constructor() PLACE_AND_TAKE = constructor() - + + @_rust_enum class PerpFulfillmentMethod: A_M_M = constructor() MATCH = constructor() - + + @_rust_enum class SpotFulfillmentMethod: EXTERNAL_MARKET = constructor() MATCH = constructor() - + + @_rust_enum class MarginCalculationMode: STANDARD = constructor() LIQUIDATION = constructor() - + + @_rust_enum class OracleSource: PYTH = constructor() @@ -173,19 +197,22 @@ class OracleSource: PYTH1_K = constructor() PYTH1_M = constructor() PYTH_STABLE_COIN = constructor() - + + @_rust_enum class PostOnlyParam: NONE = constructor() MUST_POST_ONLY = constructor() TRY_POST_ONLY = constructor() SLIDE = constructor() - + + @_rust_enum class ModifyOrderPolicy: TRY_MODIFY = constructor() MUST_MODIFY = constructor() - + + @_rust_enum class MarketStatus: INITIALIZED = constructor() @@ -197,12 +224,14 @@ class MarketStatus: REDUCE_ONLY = constructor() SETTLEMENT = constructor() DELISTED = constructor() - + + @_rust_enum class ContractType: PERPETUAL = constructor() FUTURE = constructor() - + + @_rust_enum class ContractTier: A = constructor() @@ -210,23 +239,27 @@ class ContractTier: C = constructor() SPECULATIVE = constructor() ISOLATED = constructor() - + + @_rust_enum class AMMLiquiditySplit: PROTOCOL_OWNED = constructor() L_P_OWNED = constructor() SHARED = constructor() - + + @_rust_enum class SpotBalanceType: DEPOSIT = constructor() BORROW = constructor() - + + @_rust_enum class SpotFulfillmentConfigStatus: ENABLED = constructor() DISABLED = constructor() - + + @_rust_enum class AssetTier: COLLATERAL = constructor() @@ -234,7 +267,8 @@ class AssetTier: CROSS = constructor() ISOLATED = constructor() UNLISTED = constructor() - + + @_rust_enum class ExchangeStatus: DEPOSIT_PAUSED = constructor() @@ -244,25 +278,29 @@ class ExchangeStatus: LIQ_PAUSED = constructor() FUNDING_PAUSED = constructor() SETTLE_PNL_PAUSED = constructor() - + + @_rust_enum class UserStatus: BEING_LIQUIDATED = constructor() BANKRUPT = constructor() REDUCE_ONLY = constructor() - + + @_rust_enum class AssetType: BASE = constructor() QUOTE = constructor() - + + @_rust_enum class OrderStatus: INIT = constructor() OPEN = constructor() FILLED = constructor() CANCELED = constructor() - + + @_rust_enum class OrderType: MARKET = constructor() @@ -270,24 +308,28 @@ class OrderType: TRIGGER_MARKET = constructor() TRIGGER_LIMIT = constructor() ORACLE = constructor() - + + @_rust_enum class OrderTriggerCondition: ABOVE = constructor() BELOW = constructor() TRIGGERED_ABOVE = constructor() TRIGGERED_BELOW = constructor() - + + @_rust_enum class MarketType: SPOT = constructor() PERP = constructor() - + + @dataclass class MarketIdentifier: market_type: MarketType market_index: int - + + @dataclass class OrderParams: order_type: OrderType @@ -307,7 +349,8 @@ class OrderParams: auction_duration: Optional[int] auction_start_price: Optional[int] auction_end_price: Optional[int] - + + @dataclass class ModifyOrderParams: direction: Optional[PositionDirection] @@ -324,7 +367,8 @@ class ModifyOrderParams: auction_start_price: Optional[int] auction_end_price: Optional[int] policy: Optional[ModifyOrderPolicy] - + + @dataclass class HistoricalOracleData: last_oracle_price: int @@ -333,13 +377,15 @@ class HistoricalOracleData: last_oracle_price_twap: int last_oracle_price_twap5min: int last_oracle_price_twap_ts: int - + + @dataclass class PoolBalance: scaled_balance: int market_index: int padding: list[int] - + + @dataclass class AMM: oracle: Pubkey @@ -423,24 +469,28 @@ class AMM: padding2: int total_fee_earned_per_lp: int padding: list[int] - + + @dataclass class PriceDivergenceGuardRails: mark_oracle_percent_divergence: int oracle_twap5min_percent_divergence: int - + + @dataclass class ValidityGuardRails: slots_before_stale_for_amm: int slots_before_stale_for_margin: int confidence_interval_max_size: int too_volatile_ratio: int - + + @dataclass class OracleGuardRails: price_divergence: PriceDivergenceGuardRails validity: ValidityGuardRails - + + @dataclass class FeeTier: fee_numerator: int @@ -451,20 +501,23 @@ class FeeTier: referrer_reward_denominator: int referee_fee_numerator: int referee_fee_denominator: int - + + @dataclass class OrderFillerRewardStructure: reward_numerator: int reward_denominator: int time_based_reward_lower_bound: int - + + @dataclass class FeeStructure: fee_tiers: list[FeeTier] filler_reward_structure: OrderFillerRewardStructure referrer_reward_epoch_upper_bound: int flat_filler_fee: int - + + @dataclass class SpotPosition: scaled_balance: int @@ -475,7 +528,8 @@ class SpotPosition: balance_type: SpotBalanceType open_orders: int padding: list[int] - + + @dataclass class Order: slot: int @@ -502,7 +556,8 @@ class Order: trigger_condition: OrderTriggerCondition auction_duration: int padding: list[int] - + + @dataclass class PhoenixV1FulfillmentConfig: pubkey: Pubkey @@ -515,7 +570,8 @@ class PhoenixV1FulfillmentConfig: fulfillment_type: SpotFulfillmentType status: SpotFulfillmentConfigStatus padding: list[int] - + + @dataclass class SerumV3FulfillmentConfig: pubkey: Pubkey @@ -533,7 +589,8 @@ class SerumV3FulfillmentConfig: fulfillment_type: SpotFulfillmentType status: SpotFulfillmentConfigStatus padding: list[int] - + + @dataclass class InsuranceClaim: revenue_withdraw_since_last_settle: int @@ -541,7 +598,8 @@ class InsuranceClaim: quote_max_insurance: int quote_settled_insurance: int last_revenue_withdraw_ts: int - + + @dataclass class PerpMarket: pubkey: Pubkey @@ -573,7 +631,8 @@ class PerpMarket: quote_spot_market_index: int fee_adjustment: int padding: list[int] - + + @dataclass class HistoricalIndexData: last_index_bid_price: int @@ -581,7 +640,8 @@ class HistoricalIndexData: last_index_price_twap: int last_index_price_twap5min: int last_index_price_twap_ts: int - + + @dataclass class InsuranceFund: vault: Pubkey @@ -593,7 +653,8 @@ class InsuranceFund: revenue_settle_period: int total_factor: int user_factor: int - + + @dataclass class SpotMarket: pubkey: Pubkey @@ -649,7 +710,8 @@ class SpotMarket: total_swap_fee: int scale_initial_asset_weight_start: int padding: list[int] - + + @dataclass class State: admin: Pubkey @@ -675,7 +737,8 @@ class State: liquidation_duration: int initial_pct_to_liquidate: int padding: list[int] - + + @dataclass class PerpPosition: last_cumulative_funding_rate: int @@ -693,7 +756,8 @@ class PerpPosition: market_index: int open_orders: int per_lp_base: int - + + @dataclass class User: authority: Pubkey @@ -723,7 +787,8 @@ class User: open_auctions: int has_open_auction: bool padding: list[int] - + + @dataclass class UserFees: total_fee_paid: int @@ -732,7 +797,8 @@ class UserFees: total_referee_discount: int total_referrer_reward: int current_epoch_referrer_reward: int - + + @dataclass class UserStats: authority: Pubkey @@ -751,7 +817,8 @@ class UserStats: is_referrer: bool disable_update_perp_bid_ask_twap: bool padding: list[int] - + + @dataclass class LiquidatePerpRecord: market_index: int @@ -764,7 +831,8 @@ class LiquidatePerpRecord: liquidator_order_id: int liquidator_fee: int if_fee: int - + + @dataclass class LiquidateSpotRecord: asset_market_index: int @@ -774,7 +842,8 @@ class LiquidateSpotRecord: liability_price: int liability_transfer: int if_fee: int - + + @dataclass class LiquidateBorrowForPerpPnlRecord: perp_market_index: int @@ -783,7 +852,8 @@ class LiquidateBorrowForPerpPnlRecord: liability_market_index: int liability_price: int liability_transfer: int - + + @dataclass class LiquidatePerpPnlForDepositRecord: perp_market_index: int @@ -792,7 +862,8 @@ class LiquidatePerpPnlForDepositRecord: asset_market_index: int asset_price: int asset_transfer: int - + + @dataclass class PerpBankruptcyRecord: market_index: int @@ -801,14 +872,16 @@ class PerpBankruptcyRecord: clawback_user: Optional[Pubkey] clawback_user_payment: Optional[int] cumulative_funding_rate_delta: int - + + @dataclass class SpotBankruptcyRecord: market_index: int borrow_amount: int if_payment: int cumulative_deposit_interest_delta: int - + + @dataclass class InsuranceFundStake: authority: Pubkey @@ -821,7 +894,8 @@ class InsuranceFundStake: cost_basis: int market_index: int padding: list[int] - + + @dataclass class ProtocolIfSharesTransferConfig: whitelisted_signers: list[Pubkey] @@ -829,7 +903,8 @@ class ProtocolIfSharesTransferConfig: current_epoch_transfer: int next_epoch_ts: int padding: list[int] - + + @dataclass class ReferrerName: authority: Pubkey @@ -837,6 +912,7 @@ class ReferrerName: user_stats: Pubkey name: list[int] + @dataclass class OraclePriceData: price: int @@ -846,7 +922,8 @@ class OraclePriceData: twap_confidence: int has_sufficient_number_of_datapoints: bool + @dataclass class TxParams: compute_units: Optional[int] - compute_units_price: Optional[int] \ No newline at end of file + compute_units_price: Optional[int] diff --git a/tests/test.py b/tests/test.py index 58c9bfee..da38453c 100644 --- a/tests/test.py +++ b/tests/test.py @@ -132,9 +132,7 @@ async def test_initialized_spot_market_2( @async_fixture(scope="session") -async def initialized_market( - drift_client: Admin, workspace: WorkspaceType -) -> Pubkey: +async def initialized_market(drift_client: Admin, workspace: WorkspaceType) -> Pubkey: pyth_program = workspace["pyth"] sol_usd = await mock_oracle(pyth_program=pyth_program, price=1) perp_market_index = 0 @@ -176,10 +174,10 @@ async def test_init_user( drift_client: Admin, ): await drift_client.intialize_user() - user_public_key = get_user_account_public_key(drift_client.program.program_id, drift_client.authority, 0) - user: User = await get_user_account( - drift_client.program, user_public_key + user_public_key = get_user_account_public_key( + drift_client.program.program_id, drift_client.authority, 0 ) + user: User = await get_user_account(drift_client.program, user_public_key) assert user.authority == drift_client.authority @@ -189,7 +187,7 @@ async def test_usdc_deposit( user_usdc_account: Keypair, ): usdc_spot_market = await get_spot_market_account(drift_client.program, 0) - assert(usdc_spot_market.market_index == 0) + assert usdc_spot_market.market_index == 0 drift_client.spot_market_atas[0] = user_usdc_account.pubkey() await drift_client.deposit( USDC_AMOUNT, 0, user_usdc_account.pubkey(), user_initialized=True @@ -311,21 +309,23 @@ async def test_stake_if( await drift_client.update_update_insurance_fund_unstaking_period(0, 0) await drift_client.initialize_insurance_fund_stake(0) - if_acc = await get_if_stake_account( - drift_client.program, drift_client.authority, 0 - ) + if_acc = await get_if_stake_account(drift_client.program, drift_client.authority, 0) assert if_acc.market_index == 0 await drift_client.add_insurance_fund_stake(0, 1 * QUOTE_PRECISION) - user_stats = await get_user_stats_account(drift_client.program, drift_client.authority) + user_stats = await get_user_stats_account( + drift_client.program, drift_client.authority + ) assert user_stats.if_staked_quote_asset_amount == 1 * QUOTE_PRECISION await drift_client.request_remove_insurance_fund_stake(0, 1 * QUOTE_PRECISION) await drift_client.remove_insurance_fund_stake(0) - user_stats = await get_user_stats_account(drift_client.program, drift_client.authority) + user_stats = await get_user_stats_account( + drift_client.program, drift_client.authority + ) assert user_stats.if_staked_quote_asset_amount == 0