Skip to content

Commit

Permalink
Merge pull request #2581 from opentensor/feat/rao-alpha-shares
Browse files Browse the repository at this point in the history
Feat/rao alpha shares
  • Loading branch information
ibraheem-opentensor authored Jan 22, 2025
2 parents c0cd372 + d145389 commit b76efac
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 82 deletions.
6 changes: 3 additions & 3 deletions bittensor/core/extrinsics/staking.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def add_stake_extrinsic(
netuid=netuid,
)
if old_stake is not None:
old_stake = old_stake.stake
old_stake = old_stake
else:
old_stake = Balance.from_tao(0)

Expand Down Expand Up @@ -243,7 +243,7 @@ def add_stake_extrinsic(
netuid=netuid,
)
if new_stake is not None:
new_stake = new_stake.stake
new_stake = new_stake
else:
new_stake = Balance.from_tao(0)

Expand Down Expand Up @@ -439,7 +439,7 @@ def add_stake_multiple_extrinsic(
netuid=netuid,
)
if new_stake is not None:
new_stake = new_stake.stake
new_stake = new_stake
else:
new_stake = Balance.from_tao(0)

Expand Down
6 changes: 3 additions & 3 deletions bittensor/core/extrinsics/unstaking.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def unstake_extrinsic(
netuid=netuid,
)
if old_stake is not None:
old_stake = old_stake.stake
old_stake = old_stake
else:
old_stake = Balance.from_tao(0)

Expand Down Expand Up @@ -205,7 +205,7 @@ def unstake_extrinsic(
netuid=netuid,
)
if new_stake is not None:
new_stake = new_stake.stake
new_stake = new_stake
else:
new_stake = Balance.from_tao(0)
logging.info(
Expand Down Expand Up @@ -377,7 +377,7 @@ def unstake_multiple_extrinsic(
netuid=netuid,
)
if new_stake is not None:
new_stake = new_stake.stake
new_stake = new_stake
else:
new_stake = Balance.from_tao(0)

Expand Down
46 changes: 30 additions & 16 deletions bittensor/core/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
hex_to_bytes,
Certificate,
)
from bittensor.utils.balance import Balance
from bittensor.utils.balance import Balance, fixed_to_float, FixedPoint
from bittensor.utils.btlogging import logging
from bittensor.utils.registration import legacy_torch_api_compat
from bittensor.utils.weight_utils import generate_weight_hash
Expand Down Expand Up @@ -1798,7 +1798,7 @@ def get_stake_for_coldkey_and_hotkey(
coldkey_ss58: str,
netuid: Optional[int] = None,
block: Optional[int] = None,
) -> Optional[Union["StakeInfo", list["StakeInfo"]]]:
) -> Balance:
"""
Returns the stake under a coldkey - hotkey pairing.
Expand All @@ -1811,20 +1811,34 @@ def get_stake_for_coldkey_and_hotkey(
Returns:
Optional[StakeInfo]: The StakeInfo object/s under the coldkey - hotkey pairing, or ``None`` if the pairing does not exist or the stake is not found.
"""
all_stakes = self.get_stake_for_coldkey(coldkey_ss58, block)
stakes = [
stake
for stake in all_stakes # type: ignore
if stake.hotkey_ss58 == hotkey_ss58
and (netuid is None or stake.netuid == netuid)
and stake.stake > 0
]
if not stakes:
return None
elif len(stakes) == 1:
return stakes[0]
else:
return stakes
alpha_shares: FixedPoint = self.query_module(
module="SubtensorModule",
name="Alpha",
block=block,
params=[hotkey_ss58, coldkey_ss58, netuid],
).value
hotkey_alpha: int = self.query_module(
module="SubtensorModule",
name="TotalHotkeyAlpha",
block=block,
params=[hotkey_ss58, netuid],
).value
hotkey_shares: FixedPoint = self.query_module(
module="SubtensorModule",
name="TotalHotkeyShares",
block=block,
params=[hotkey_ss58, netuid],
).value

alpha_shares_as_float = fixed_to_float(alpha_shares)
hotkey_shares_as_float = fixed_to_float(hotkey_shares)

if hotkey_shares_as_float == 0:
return Balance.from_rao(0)

stake = alpha_shares_as_float / hotkey_shares_as_float * hotkey_alpha

return Balance.from_rao(int(stake)).set_unit(netuid=netuid)

def does_hotkey_exist(self, hotkey_ss58: str, block: Optional[int] = None) -> bool:
"""
Expand Down
30 changes: 29 additions & 1 deletion bittensor/utils/balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Union
from typing import Union, TypedDict

from bittensor.core import settings

Expand Down Expand Up @@ -284,3 +284,31 @@ def set_unit(self, netuid: int):
self.unit = Balance.get_unit(netuid)
self.rao_unit = Balance.get_unit(netuid)
return self


class FixedPoint(TypedDict):
"""
Represents a fixed point ``U64F64`` number.
Where ``bits`` is a U128 representation of the fixed point number.
This matches the type of the Alpha shares.
"""

bits: int


def fixed_to_float(fixed: FixedPoint) -> float:
# Currently this is stored as a U64F64
# which is 64 bits of integer and 64 bits of fractional
uint_bits = 64
frac_bits = 64

data: int = fixed["bits"]

# Shift bits to extract integer part (assuming 64 bits for integer part)
integer_part = data >> frac_bits
fractional_part = data & (2**frac_bits - 1)

frac_float = fractional_part / (2**frac_bits)

return integer_part + frac_float
109 changes: 50 additions & 59 deletions tests/unit_tests/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from bittensor.core.settings import version_as_int
from bittensor.core.subtensor import Subtensor, logging
from bittensor.utils import u16_normalized_float, u64_normalized_float, Certificate
from bittensor.utils.balance import Balance
from bittensor.utils.balance import Balance, fixed_to_float

U16_MAX = 65535
U64_MAX = 18446744073709551615
Expand Down Expand Up @@ -2197,63 +2197,28 @@ def test_networks_during_connection(mocker):


def test_get_stake_for_coldkey_and_hotkey_with_single_result(subtensor, mocker):
"""Test `get_stake_for_coldkey_and_hotkey` calls right method with correct arguments and get 1 stake info."""
"""Test get_stake_for_coldkey_and_hotkey calculation and network calls."""
# Preps
fake_hotkey_ss58 = "FAKE_H_SS58"
fake_coldkey_ss58 = "FAKE_C_SS58"
fake_netuid = 255
fake_block = 123

fake_stake_info_1 = mocker.Mock(hotkey_ss58="some")
fake_stake_info_2 = mocker.Mock(
hotkey_ss58=fake_hotkey_ss58, netuid=fake_netuid, stake=100
)

return_value = [
fake_stake_info_1,
fake_stake_info_2,
]

subtensor.get_stake_for_coldkey = mocker.patch.object(
subtensor, "get_stake_for_coldkey", return_value=return_value
)

# Call
result = subtensor.get_stake_for_coldkey_and_hotkey(
hotkey_ss58=fake_hotkey_ss58,
coldkey_ss58=fake_coldkey_ss58,
netuid=fake_netuid,
block=fake_block,
)

# Asserts
subtensor.get_stake_for_coldkey.assert_called_once_with(
fake_coldkey_ss58, fake_block
)
assert result == fake_stake_info_2


def test_get_stake_for_coldkey_and_hotkey_with_multiple_result(subtensor, mocker):
"""Test `get_stake_for_coldkey_and_hotkey` calls right method with correct arguments and get multiple stake info."""
# Preps
fake_hotkey_ss58 = "FAKE_H_SS58"
fake_coldkey_ss58 = "FAKE_C_SS58"
fake_netuid = 255
fake_block = 123
fake_hotkey_ss58 = "FAKE_HK_SS58"
fake_coldkey_ss58 = "FAKE_CK_SS58"
fake_netuid = 2
fake_block = None

fake_stake_info_1 = mocker.Mock(hotkey_ss58="some")
fake_stake_info_2 = mocker.Mock(
hotkey_ss58=fake_hotkey_ss58, netuid=fake_netuid, stake=100
)
fake_stake_info_3 = mocker.Mock(
hotkey_ss58=fake_hotkey_ss58, netuid=fake_netuid, stake=200
)
alpha_shares = {"bits": 177229957888291400329606044405}
hotkey_alpha = 96076552686
hotkey_shares = {"bits": 177229957888291400329606044405}

return_value = [fake_stake_info_1, fake_stake_info_2, fake_stake_info_3]
# Mock
def mock_query_module(module, name, block, params):
if name == "Alpha":
return mocker.Mock(value=alpha_shares)
elif name == "TotalHotkeyAlpha":
return mocker.Mock(value=hotkey_alpha)
elif name == "TotalHotkeyShares":
return mocker.Mock(value=hotkey_shares)
return None

subtensor.get_stake_for_coldkey = mocker.patch.object(
subtensor, "get_stake_for_coldkey", return_value=return_value
)
subtensor.query_module = mocker.MagicMock(side_effect=mock_query_module)

# Call
result = subtensor.get_stake_for_coldkey_and_hotkey(
Expand All @@ -2263,11 +2228,37 @@ def test_get_stake_for_coldkey_and_hotkey_with_multiple_result(subtensor, mocker
block=fake_block,
)

# Asserts
subtensor.get_stake_for_coldkey.assert_called_once_with(
fake_coldkey_ss58, fake_block
)
assert result == [fake_stake_info_2, fake_stake_info_3]
# Assertions
subtensor.query_module.assert_has_calls(
[
mocker.call(
module="SubtensorModule",
name="Alpha",
block=fake_block,
params=[fake_hotkey_ss58, fake_coldkey_ss58, fake_netuid],
),
mocker.call(
module="SubtensorModule",
name="TotalHotkeyAlpha",
block=fake_block,
params=[fake_hotkey_ss58, fake_netuid],
),
mocker.call(
module="SubtensorModule",
name="TotalHotkeyShares",
block=fake_block,
params=[fake_hotkey_ss58, fake_netuid],
),
]
)

alpha_shares_as_float = fixed_to_float(alpha_shares)
hotkey_shares_as_float = fixed_to_float(hotkey_shares)
expected_stake = int(
(alpha_shares_as_float / hotkey_shares_as_float) * hotkey_alpha
)

assert result == Balance.from_rao(expected_stake).set_unit(netuid=fake_netuid)


def test_does_hotkey_exist_true(mocker, subtensor):
Expand Down
31 changes: 31 additions & 0 deletions tests/unit_tests/utils/test_fixed_float.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from bittensor.utils.balance import fixed_to_float, FixedPoint

# Generated using the following gist: https://gist.github.com/camfairchild/8c6b6b9faa8aa1ae7ddc49ce177a27f2
examples: list[tuple[int, float]] = [
(22773757908449605611411210240, 1234567890),
(22773757910726980065558528000, 1234567890.1234567),
(22773757910726980065558528000, 1234567890.1234567),
(22773757910726980065558528000, 1234567890.1234567),
(4611686018427387904, 0.25),
(9223372036854775808, 0.5),
(13835058055282163712, 0.75),
(18446744073709551616, 1.0),
(23058430092136939520, 1.25),
(27670116110564327424, 1.5),
(32281802128991715328, 1.75),
(36893488147419103232, 2.0),
(6148914691236516864, 0.3333333333333333),
(2635249153387078656, 0.14285714285714285),
(4611686018427387904, 0.25),
(0, 0),
(0, 0.0),
]


@pytest.mark.parametrize("bits, float_value", examples)
def test_fixed_to_float(bits: int, float_value: float):
EPS = 1e-10
fp = FixedPoint(bits=bits)
assert abs(fixed_to_float(fp) - float_value) < EPS

0 comments on commit b76efac

Please sign in to comment.