Skip to content

Commit

Permalink
Merge pull request #217 from drift-labs/sina/add-direct-route-to-jup
Browse files Browse the repository at this point in the history
Update jup ix to have direct routes
  • Loading branch information
SinaKhalili authored Jan 7, 2025
2 parents 926f401 + 6f10adc commit 3c73509
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 21 deletions.
93 changes: 93 additions & 0 deletions examples/spot_market_trade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import asyncio
import logging
import os

from anchorpy.provider import Provider, Wallet
from dotenv import load_dotenv
from solana.rpc.async_api import AsyncClient

from driftpy.constants.spot_markets import mainnet_spot_market_configs
from driftpy.drift_client import DriftClient
from driftpy.keypair import load_keypair
from driftpy.types import TxParams

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_dotenv()


def get_market_by_symbol(symbol: str):
for market in mainnet_spot_market_configs:
if market.symbol == symbol:
return market
raise Exception(f"Market {symbol} not found")


async def make_spot_trade():
rpc = os.environ.get("RPC_TRITON")
secret = os.environ.get("PRIVATE_KEY")
kp = load_keypair(secret)
wallet = Wallet(kp)
logger.info(f"Using wallet: {wallet.public_key}")

connection = AsyncClient(rpc)
provider = Provider(connection, wallet)
drift_client = DriftClient(
provider.connection,
provider.wallet,
"mainnet",
tx_params=TxParams(
compute_units_price=85_000,
compute_units=1_000_000,
),
)
await drift_client.subscribe()
logger.info("Drift client subscribed")

in_decimals_result = drift_client.get_spot_market_account(
get_market_by_symbol("USDS").market_index
)
if not in_decimals_result:
logger.error("USDS market not found")
raise Exception("Market not found")

in_decimals = in_decimals_result.decimals
logger.info(f"USDS decimals: {in_decimals}")

swap_amount = int(1 * 10**in_decimals)
logger.info(f"Swapping {swap_amount} USDS to USDC")

try:
swap_ixs, swap_lookups = await drift_client.get_jupiter_swap_ix_v6(
out_market_idx=get_market_by_symbol("USDC").market_index,
in_market_idx=get_market_by_symbol("USDS").market_index,
amount=swap_amount,
swap_mode="ExactIn",
only_direct_routes=True,
)
logger.info("Got swap instructions")
print("[DEBUG] Got swap instructions of length", len(swap_ixs))

await drift_client.send_ixs(
ixs=swap_ixs,
lookup_tables=swap_lookups,
)
logger.info("Swap complete")
except Exception as e:
logger.error(f"Error during swap: {e}")
raise e
finally:
await drift_client.unsubscribe()
await connection.close()


if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(make_spot_trade())
finally:
pending = asyncio.all_tasks(loop)
loop.run_until_complete(asyncio.gather(*pending))
loop.close()
56 changes: 35 additions & 21 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import base64
import json
import os
import random
import string
Expand Down Expand Up @@ -3155,37 +3154,53 @@ async def get_jupiter_swap_ix_v6(
amount: int,
out_ata: Optional[Pubkey] = None,
in_ata: Optional[Pubkey] = None,
slippage_bps: Optional[int] = None,
quote=None,
slippage_bps: int = 50,
quote: Optional[dict] = None,
reduce_only: Optional[SwapReduceOnly] = None,
user_account_public_key: Optional[Pubkey] = None,
swap_mode: str = "ExactIn",
fee_account: Optional[Pubkey] = None,
platform_fee_bps: Optional[int] = None,
only_direct_routes: bool = False,
) -> Tuple[list[Instruction], list[AddressLookupTableAccount]]:
pre_instructions: list[Instruction] = []
JUPITER_URL = os.getenv("JUPITER_URL", "https://quote-api.jup.ag/v6")

out_market = self.get_spot_market_account(out_market_idx)
in_market = self.get_spot_market_account(in_market_idx)

if slippage_bps is None:
slippage_bps = 10
if not out_market or not in_market:
raise Exception("Invalid market indexes")

if quote is None:
url = f"{JUPITER_URL}/quote?inputMint={str(in_market.mint)}&outputMint={str(out_market.mint)}&amount={amount}&slippageBps={slippage_bps}"

params = {
"inputMint": str(in_market.mint),
"outputMint": str(out_market.mint),
"amount": str(amount),
"slippageBps": slippage_bps,
"swapMode": swap_mode,
"maxAccounts": 50,
}
if only_direct_routes:
params["onlyDirectRoutes"] = "true"
if platform_fee_bps:
params["platformFeeBps"] = platform_fee_bps

url = f"{JUPITER_URL}/quote?" + "&".join(
f"{k}={v}" for k, v in params.items()
)
quote_resp = requests.get(url)

if quote_resp.status_code != 200:
raise Exception("Couldn't get a Jupiter quote")
raise Exception(f"Jupiter quote failed: {quote_resp.text}")

quote = quote_resp.json()

if out_ata is None:
out_ata: Pubkey = self.get_associated_token_account_public_key(
out_ata = self.get_associated_token_account_public_key(
out_market.market_index
)

ai = await self.connection.get_account_info(out_ata)

if not ai.value:
pre_instructions.append(
self.create_associated_token_account_idempotent_instruction(
Expand All @@ -3197,12 +3212,10 @@ async def get_jupiter_swap_ix_v6(
)

if in_ata is None:
in_ata: Pubkey = self.get_associated_token_account_public_key(
in_ata = self.get_associated_token_account_public_key(
in_market.market_index
)

ai = await self.connection.get_account_info(in_ata)

if not ai.value:
pre_instructions.append(
self.create_associated_token_account_idempotent_instruction(
Expand All @@ -3213,23 +3226,24 @@ async def get_jupiter_swap_ix_v6(
)
)

data = {
swap_data = {
"quoteResponse": quote,
"userPublicKey": str(self.wallet.public_key),
"destinationTokenAccount": str(out_ata),
}
if fee_account:
swap_data["feeAccount"] = str(fee_account)

swap_ix_resp = requests.post(
f"{JUPITER_URL}/swap-instructions",
headers={"Accept": "application/json", "Content-Type": "application/json"},
data=json.dumps(data),
json=swap_data,
)

if swap_ix_resp.status_code != 200:
raise Exception("Couldn't get Jupiter swap ix")
raise Exception(f"Jupiter swap instructions failed: {swap_ix_resp.text}")

swap_ix_json = swap_ix_resp.json()

swap_ix = swap_ix_json.get("swapInstruction")
address_table_lookups = swap_ix_json.get("addressLookupTableAddresses")

Expand Down Expand Up @@ -3258,11 +3272,11 @@ async def get_jupiter_swap_ix_v6(
cleansed_ixs: list[Instruction] = []

for ix in ixs:
if type(ix) == list:
if isinstance(ix, list):
for i in ix:
if type(i) == dict:
if isinstance(i, dict):
cleansed_ixs.append(self._dict_to_instructions(i))
elif type(ix) == dict:
elif isinstance(ix, dict):
cleansed_ixs.append(self._dict_to_instructions(ix))
else:
cleansed_ixs.append(ix)
Expand Down
2 changes: 2 additions & 0 deletions src/driftpy/drift_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,8 @@ def get_unrealized_funding_pnl(
perp_market = self.drift_client.get_perp_market_account(
position.market_index
)
if not perp_market:
raise Exception("Perp market account not found")

unrealized_pnl += calculate_position_funding_pnl(perp_market, position)

Expand Down

0 comments on commit 3c73509

Please sign in to comment.