Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
callebtc committed Dec 14, 2024
1 parent 7be5f9f commit 275a4f3
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 160 deletions.
5 changes: 3 additions & 2 deletions cashu/wallet/protocols.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Dict, Protocol
from typing import Dict, List, Protocol

import httpx

from ..core.base import Unit, WalletKeyset
from ..core.base import Proof, Unit, WalletKeyset
from ..core.crypto.secp import PrivateKey
from ..core.db import Database

Expand All @@ -13,6 +13,7 @@ class SupportsPrivateKey(Protocol):

class SupportsDb(Protocol):
db: Database
proofs: List[Proof]


class SupportsKeysets(Protocol):
Expand Down
44 changes: 24 additions & 20 deletions cashu/wallet/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,26 +156,30 @@ async def generate_n_secrets(
"""
if n < 1:
return [], [], []

secret_counters_start = await bump_secret_derivation(
db=self.db, keyset_id=self.keyset_id, by=n, skip=skip_bump
)
logger.trace(f"secret_counters_start: {secret_counters_start}")
secret_counters = list(range(secret_counters_start, secret_counters_start + n))
logger.trace(
f"Generating secret nr {secret_counters[0]} to {secret_counters[-1]}."
)
secrets_rs_derivationpaths = [
await self.generate_determinstic_secret(s) for s in secret_counters
]
# secrets are supplied as str
secrets = [s[0].hex() for s in secrets_rs_derivationpaths]
# rs are supplied as PrivateKey
rs = [PrivateKey(privkey=s[1], raw=True) for s in secrets_rs_derivationpaths]

derivation_paths = [s[2] for s in secrets_rs_derivationpaths]

return secrets, rs, derivation_paths
async with self.db.get_connection(lock_table="keysets") as conn:
secret_counters_start = await bump_secret_derivation(
db=self.db, keyset_id=self.keyset_id, by=n, skip=skip_bump, conn=conn
)
logger.trace(f"secret_counters_start: {secret_counters_start}")
secret_counters = list(
range(secret_counters_start, secret_counters_start + n)
)
logger.trace(
f"Generating secret nr {secret_counters[0]} to {secret_counters[-1]}."
)
secrets_rs_derivationpaths = [
await self.generate_determinstic_secret(s) for s in secret_counters
]
# secrets are supplied as str
secrets = [s[0].hex() for s in secrets_rs_derivationpaths]
# rs are supplied as PrivateKey
rs = [
PrivateKey(privkey=s[1], raw=True) for s in secrets_rs_derivationpaths
]

derivation_paths = [s[2] for s in secrets_rs_derivationpaths]

return secrets, rs, derivation_paths

async def generate_secrets_from_to(
self, from_counter: int, to_counter: int, keyset_id: Optional[str] = None
Expand Down
100 changes: 99 additions & 1 deletion cashu/wallet/transactions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import uuid
from typing import Dict, List, Union
from typing import Dict, List, Optional, Tuple, Union

from loguru import logger

Expand All @@ -10,6 +10,8 @@
)
from ..core.db import Database
from ..core.helpers import amount_summary, sum_proofs
from ..core.settings import settings
from ..core.split import amount_split
from ..wallet.crud import (
update_proof,
)
Expand Down Expand Up @@ -109,6 +111,102 @@ def coinselect_fee(self, proofs: List[Proof], amount: int) -> int:
proofs_send = self.coinselect(proofs, amount, include_fees=True)
return self.get_fees_for_proofs(proofs_send)

def split_wallet_state(self, amount: int) -> List[int]:
"""This function produces an amount split for outputs based on the current state of the wallet.
Its objective is to fill up the wallet so that it reaches `n_target` coins of each amount.
Args:
amount (int): Amount to split
Returns:
List[int]: List of amounts to mint
"""
# read the target count for each amount from settings
n_target = settings.wallet_target_amount_count
amounts_we_have = [p.amount for p in self.proofs if p.reserved is not True]
amounts_we_have.sort()
# NOTE: Do not assume 2^n here
all_possible_amounts: list[int] = [2**i for i in range(settings.max_order)]
amounts_we_want_ll = [
[a] * max(0, n_target - amounts_we_have.count(a))
for a in all_possible_amounts
]
# flatten list of lists to list
amounts_we_want = [item for sublist in amounts_we_want_ll for item in sublist]
# sort by increasing amount
amounts_we_want.sort()

logger.trace(
f"Amounts we have: {[(a, amounts_we_have.count(a)) for a in set(amounts_we_have)]}"
)
amounts: list[int] = []
while sum(amounts) < amount and amounts_we_want:
if sum(amounts) + amounts_we_want[0] > amount:
break
amounts.append(amounts_we_want.pop(0))

remaining_amount = amount - sum(amounts)
if remaining_amount > 0:
amounts += amount_split(remaining_amount)
amounts.sort()

logger.trace(f"Amounts we want: {amounts}")
if sum(amounts) != amount:
raise Exception(f"Amounts do not sum to {amount}.")

return amounts

def determine_output_amounts(
self,
proofs: List[Proof],
amount: int,
include_fees: bool = False,
keyset_id_outputs: Optional[str] = None,
) -> Tuple[List[int], List[int]]:
"""This function generates a suitable amount split for the outputs to keep and the outputs to send. It
calculates the amount to keep based on the wallet state and the amount to send based on the amount
provided.
Amount to keep is based on the proofs we have in the wallet
Amount to send is optimally split based on the amount provided plus optionally the fees required to receive them.
Args:
proofs (List[Proof]): Proofs to be split.
amount (int): Amount to be sent.
include_fees (bool, optional): If True, the fees are included in the amount to send (output of
this method, to be sent in the future). This is not the fee that is required to swap the
`proofs` (input to this method). Defaults to False.
keyset_id_outputs (str, optional): The keyset ID of the outputs to be produced, used to determine the
fee if `include_fees` is set.
Returns:
Tuple[List[int], List[int]]: Two lists of amounts, one for keeping and one for sending.
"""
# create a suitable amount split based on the proofs provided
total = sum_proofs(proofs)
keep_amt, send_amt = total - amount, amount

if include_fees:
keyset_id = keyset_id_outputs or self.keyset_id
tmp_proofs = [Proof(id=keyset_id) for _ in amount_split(send_amt)]
fee = self.get_fees_for_proofs(tmp_proofs)
keep_amt -= fee
send_amt += fee

logger.trace(f"Keep amount: {keep_amt}, send amount: {send_amt}")
logger.trace(f"Total input: {sum_proofs(proofs)}")
# generate optimal split for outputs to send
send_amounts = amount_split(send_amt)

# we subtract the input fee for the entire transaction from the amount to keep
keep_amt -= self.get_fees_for_proofs(proofs)
logger.trace(f"Keep amount: {keep_amt}")

# we determine the amounts to keep based on the wallet state
keep_amounts = self.split_wallet_state(keep_amt)

return keep_amounts, send_amounts

async def set_reserved(self, proofs: List[Proof], reserved: bool) -> None:
"""Mark a proof as reserved or reset it in the wallet db to avoid reuse when it is sent.
Expand Down
150 changes: 13 additions & 137 deletions cashu/wallet/wallet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import base64
import copy
import threading
import time
Expand All @@ -20,7 +19,6 @@
WalletKeyset,
)
from ..core.crypto import b_dhke
from ..core.crypto.keys import derive_keyset_id
from ..core.crypto.secp import PrivateKey, PublicKey
from ..core.db import Database
from ..core.errors import KeysetNotFoundError
Expand All @@ -39,8 +37,8 @@
from ..core.nuts import nut20
from ..core.p2pk import Secret
from ..core.settings import settings
from ..core.split import amount_split
from . import migrations
from .compat import WalletCompat
from .crud import (
bump_secret_derivation,
get_bolt11_mint_quote,
Expand Down Expand Up @@ -70,7 +68,13 @@


class Wallet(
LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets, WalletTransactions, WalletProofs
LedgerAPI,
WalletP2PK,
WalletHTLC,
WalletSecrets,
WalletTransactions,
WalletProofs,
WalletCompat,
):
"""
Nutshell wallet class.
Expand Down Expand Up @@ -250,39 +254,6 @@ async def load_mint_keysets(self, force_old_keysets=False):

await self.load_keysets_from_db()

async def inactivate_base64_keysets(self, force_old_keysets: bool) -> None:
# BEGIN backwards compatibility: phase out keysets with base64 ID by treating them as inactive
if settings.wallet_inactivate_base64_keysets and not force_old_keysets:
keysets_in_db = await get_keysets(mint_url=self.url, db=self.db)
for keyset in keysets_in_db:
if not keyset.active:
continue
# test if the keyset id is a hex string, if not it's base64
try:
int(keyset.id, 16)
except ValueError:
# verify that it's base64
try:
_ = base64.b64decode(keyset.id)
except ValueError:
logger.error("Unexpected: keyset id is neither hex nor base64.")
continue

# verify that we have a hex version of the same keyset by comparing public keys
hex_keyset_id = derive_keyset_id(keys=keyset.public_keys)
if hex_keyset_id not in [k.id for k in keysets_in_db]:
logger.warning(
f"Keyset {keyset.id} is base64 but we don't have a hex version. Ignoring."
)
continue

logger.warning(
f"Keyset {keyset.id} is base64 and has a hex counterpart, setting inactive."
)
keyset.active = False
await update_keyset(keyset=keyset, db=self.db)
# END backwards compatibility

async def activate_keyset(self, keyset_id: Optional[str] = None) -> None:
"""Activates a keyset by setting self.keyset_id. Either activates a specific keyset
of chooses one of the active keysets of the mint with the same unit as the wallet.
Expand Down Expand Up @@ -453,51 +424,6 @@ async def request_mint(
await store_bolt11_mint_quote(db=self.db, quote=quote)
return quote

def split_wallet_state(self, amount: int) -> List[int]:
"""This function produces an amount split for outputs based on the current state of the wallet.
Its objective is to fill up the wallet so that it reaches `n_target` coins of each amount.
Args:
amount (int): Amount to split
Returns:
List[int]: List of amounts to mint
"""
# read the target count for each amount from settings
n_target = settings.wallet_target_amount_count
amounts_we_have = [p.amount for p in self.proofs if p.reserved is not True]
amounts_we_have.sort()
# NOTE: Do not assume 2^n here
all_possible_amounts: list[int] = [2**i for i in range(settings.max_order)]
amounts_we_want_ll = [
[a] * max(0, n_target - amounts_we_have.count(a))
for a in all_possible_amounts
]
# flatten list of lists to list
amounts_we_want = [item for sublist in amounts_we_want_ll for item in sublist]
# sort by increasing amount
amounts_we_want.sort()

logger.trace(
f"Amounts we have: {[(a, amounts_we_have.count(a)) for a in set(amounts_we_have)]}"
)
amounts: list[int] = []
while sum(amounts) < amount and amounts_we_want:
if sum(amounts) + amounts_we_want[0] > amount:
break
amounts.append(amounts_we_want.pop(0))

remaining_amount = amount - sum(amounts)
if remaining_amount > 0:
amounts += amount_split(remaining_amount)
amounts.sort()

logger.trace(f"Amounts we want: {amounts}")
if sum(amounts) != amount:
raise Exception(f"Amounts do not sum to {amount}.")

return amounts

async def mint(
self,
amount: int,
Expand Down Expand Up @@ -574,68 +500,14 @@ async def redeem(
self,
proofs: List[Proof],
) -> Tuple[List[Proof], List[Proof]]:
"""Redeem proofs by sending them to yourself (by calling a split).)
Calls `add_witnesses_to_proofs` which parses all proofs and checks whether their
secrets corresponds to any locks that we have the unlock conditions for. If so,
it adds the unlock conditions to the proofs.
"""Redeem proofs by sending them to yourself by calling a split.
Args:
proofs (List[Proof]): Proofs to be redeemed.
"""
# verify DLEQ of incoming proofs
self.verify_proofs_dleq(proofs)
return await self.split(proofs=proofs, amount=0)

def determine_output_amounts(
self,
proofs: List[Proof],
amount: int,
include_fees: bool = False,
keyset_id_outputs: Optional[str] = None,
) -> Tuple[List[int], List[int]]:
"""This function generates a suitable amount split for the outputs to keep and the outputs to send. It
calculates the amount to keep based on the wallet state and the amount to send based on the amount
provided.
Amount to keep is based on the proofs we have in the wallet
Amount to send is optimally split based on the amount provided plus optionally the fees required to receive them.
Args:
proofs (List[Proof]): Proofs to be split.
amount (int): Amount to be sent.
include_fees (bool, optional): If True, the fees are included in the amount to send (output of
this method, to be sent in the future). This is not the fee that is required to swap the
`proofs` (input to this method). Defaults to False.
keyset_id_outputs (str, optional): The keyset ID of the outputs to be produced, used to determine the
fee if `include_fees` is set.
Returns:
Tuple[List[int], List[int]]: Two lists of amounts, one for keeping and one for sending.
"""
# create a suitable amount split based on the proofs provided
total = sum_proofs(proofs)
keep_amt, send_amt = total - amount, amount

if include_fees:
keyset_id = keyset_id_outputs or self.keyset_id
tmp_proofs = [Proof(id=keyset_id) for _ in amount_split(send_amt)]
fee = self.get_fees_for_proofs(tmp_proofs)
keep_amt -= fee
send_amt += fee

logger.trace(f"Keep amount: {keep_amt}, send amount: {send_amt}")
logger.trace(f"Total input: {sum_proofs(proofs)}")
# generate optimal split for outputs to send
send_amounts = amount_split(send_amt)

# we subtract the input fee for the entire transaction from the amount to keep
keep_amt -= self.get_fees_for_proofs(proofs)
logger.trace(f"Keep amount: {keep_amt}")

# we determine the amounts to keep based on the wallet state
keep_amounts = self.split_wallet_state(keep_amt)

return keep_amounts, send_amounts

async def split(
self,
proofs: List[Proof],
Expand All @@ -649,6 +521,10 @@ async def split(
and the promises to send (send_outputs). If secret_lock is provided, the wallet will create
blinded secrets with those to attach a predefined spending condition to the tokens they want to send.
Calls `add_witnesses_to_proofs` which parses all proofs and checks whether their
secrets corresponds to any locks that we have the unlock conditions for. If so,
it adds the unlock conditions to the proofs.
Args:
proofs (List[Proof]): Proofs to be split.
amount (int): Amount to be sent.
Expand Down

0 comments on commit 275a4f3

Please sign in to comment.