From 6f10adc42286abd8e890be5a32e073f4fee583dc Mon Sep 17 00:00:00 2001 From: sina <20732540+SinaKhalili@users.noreply.github.com> Date: Mon, 6 Jan 2025 16:33:27 -0800 Subject: [PATCH] Update jup ix to have direct routes --- examples/spot_market_trade.py | 93 +++++++++++++++++++++++++++++++++++ src/driftpy/drift_client.py | 56 +++++++++++++-------- src/driftpy/drift_user.py | 2 + 3 files changed, 130 insertions(+), 21 deletions(-) create mode 100644 examples/spot_market_trade.py diff --git a/examples/spot_market_trade.py b/examples/spot_market_trade.py new file mode 100644 index 00000000..9a3904a5 --- /dev/null +++ b/examples/spot_market_trade.py @@ -0,0 +1,93 @@ +import asyncio +import logging +import os + +from anchorpy.provider import Provider, Wallet +from dotenv import load_dotenv +from solana.rpc.async_api import AsyncClient + +from driftpy.constants.spot_markets import mainnet_spot_market_configs +from driftpy.drift_client import DriftClient +from driftpy.keypair import load_keypair +from driftpy.types import TxParams + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +load_dotenv() + + +def get_market_by_symbol(symbol: str): + for market in mainnet_spot_market_configs: + if market.symbol == symbol: + return market + raise Exception(f"Market {symbol} not found") + + +async def make_spot_trade(): + rpc = os.environ.get("RPC_TRITON") + secret = os.environ.get("PRIVATE_KEY") + kp = load_keypair(secret) + wallet = Wallet(kp) + logger.info(f"Using wallet: {wallet.public_key}") + + connection = AsyncClient(rpc) + provider = Provider(connection, wallet) + drift_client = DriftClient( + provider.connection, + provider.wallet, + "mainnet", + tx_params=TxParams( + compute_units_price=85_000, + compute_units=1_000_000, + ), + ) + await drift_client.subscribe() + logger.info("Drift client subscribed") + + in_decimals_result = drift_client.get_spot_market_account( + get_market_by_symbol("USDS").market_index + ) + if not in_decimals_result: + logger.error("USDS market not found") + raise Exception("Market not found") + + in_decimals = in_decimals_result.decimals + logger.info(f"USDS decimals: {in_decimals}") + + swap_amount = int(1 * 10**in_decimals) + logger.info(f"Swapping {swap_amount} USDS to USDC") + + try: + swap_ixs, swap_lookups = await drift_client.get_jupiter_swap_ix_v6( + out_market_idx=get_market_by_symbol("USDC").market_index, + in_market_idx=get_market_by_symbol("USDS").market_index, + amount=swap_amount, + swap_mode="ExactIn", + only_direct_routes=True, + ) + logger.info("Got swap instructions") + print("[DEBUG] Got swap instructions of length", len(swap_ixs)) + + await drift_client.send_ixs( + ixs=swap_ixs, + lookup_tables=swap_lookups, + ) + logger.info("Swap complete") + except Exception as e: + logger.error(f"Error during swap: {e}") + raise e + finally: + await drift_client.unsubscribe() + await connection.close() + + +if __name__ == "__main__": + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(make_spot_trade()) + finally: + pending = asyncio.all_tasks(loop) + loop.run_until_complete(asyncio.gather(*pending)) + loop.close() diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 53c956be..c49e748d 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -1,5 +1,4 @@ import base64 -import json import os import random import string @@ -3155,10 +3154,14 @@ async def get_jupiter_swap_ix_v6( amount: int, out_ata: Optional[Pubkey] = None, in_ata: Optional[Pubkey] = None, - slippage_bps: Optional[int] = None, - quote=None, + slippage_bps: int = 50, + quote: Optional[dict] = None, reduce_only: Optional[SwapReduceOnly] = None, user_account_public_key: Optional[Pubkey] = None, + swap_mode: str = "ExactIn", + fee_account: Optional[Pubkey] = None, + platform_fee_bps: Optional[int] = None, + only_direct_routes: bool = False, ) -> Tuple[list[Instruction], list[AddressLookupTableAccount]]: pre_instructions: list[Instruction] = [] JUPITER_URL = os.getenv("JUPITER_URL", "https://quote-api.jup.ag/v6") @@ -3166,26 +3169,38 @@ async def get_jupiter_swap_ix_v6( out_market = self.get_spot_market_account(out_market_idx) in_market = self.get_spot_market_account(in_market_idx) - if slippage_bps is None: - slippage_bps = 10 + if not out_market or not in_market: + raise Exception("Invalid market indexes") if quote is None: - url = f"{JUPITER_URL}/quote?inputMint={str(in_market.mint)}&outputMint={str(out_market.mint)}&amount={amount}&slippageBps={slippage_bps}" - + params = { + "inputMint": str(in_market.mint), + "outputMint": str(out_market.mint), + "amount": str(amount), + "slippageBps": slippage_bps, + "swapMode": swap_mode, + "maxAccounts": 50, + } + if only_direct_routes: + params["onlyDirectRoutes"] = "true" + if platform_fee_bps: + params["platformFeeBps"] = platform_fee_bps + + url = f"{JUPITER_URL}/quote?" + "&".join( + f"{k}={v}" for k, v in params.items() + ) quote_resp = requests.get(url) if quote_resp.status_code != 200: - raise Exception("Couldn't get a Jupiter quote") + raise Exception(f"Jupiter quote failed: {quote_resp.text}") quote = quote_resp.json() if out_ata is None: - out_ata: Pubkey = self.get_associated_token_account_public_key( + out_ata = self.get_associated_token_account_public_key( out_market.market_index ) - ai = await self.connection.get_account_info(out_ata) - if not ai.value: pre_instructions.append( self.create_associated_token_account_idempotent_instruction( @@ -3197,12 +3212,10 @@ async def get_jupiter_swap_ix_v6( ) if in_ata is None: - in_ata: Pubkey = self.get_associated_token_account_public_key( + in_ata = self.get_associated_token_account_public_key( in_market.market_index ) - ai = await self.connection.get_account_info(in_ata) - if not ai.value: pre_instructions.append( self.create_associated_token_account_idempotent_instruction( @@ -3213,23 +3226,24 @@ async def get_jupiter_swap_ix_v6( ) ) - data = { + swap_data = { "quoteResponse": quote, "userPublicKey": str(self.wallet.public_key), "destinationTokenAccount": str(out_ata), } + if fee_account: + swap_data["feeAccount"] = str(fee_account) swap_ix_resp = requests.post( f"{JUPITER_URL}/swap-instructions", headers={"Accept": "application/json", "Content-Type": "application/json"}, - data=json.dumps(data), + json=swap_data, ) if swap_ix_resp.status_code != 200: - raise Exception("Couldn't get Jupiter swap ix") + raise Exception(f"Jupiter swap instructions failed: {swap_ix_resp.text}") swap_ix_json = swap_ix_resp.json() - swap_ix = swap_ix_json.get("swapInstruction") address_table_lookups = swap_ix_json.get("addressLookupTableAddresses") @@ -3258,11 +3272,11 @@ async def get_jupiter_swap_ix_v6( cleansed_ixs: list[Instruction] = [] for ix in ixs: - if type(ix) == list: + if isinstance(ix, list): for i in ix: - if type(i) == dict: + if isinstance(i, dict): cleansed_ixs.append(self._dict_to_instructions(i)) - elif type(ix) == dict: + elif isinstance(ix, dict): cleansed_ixs.append(self._dict_to_instructions(ix)) else: cleansed_ixs.append(ix) diff --git a/src/driftpy/drift_user.py b/src/driftpy/drift_user.py index 02957fb9..ac6b91ee 100644 --- a/src/driftpy/drift_user.py +++ b/src/driftpy/drift_user.py @@ -489,6 +489,8 @@ def get_unrealized_funding_pnl( perp_market = self.drift_client.get_perp_market_account( position.market_index ) + if not perp_market: + raise Exception("Perp market account not found") unrealized_pnl += calculate_position_funding_pnl(perp_market, position)