Refactored reward calculation methods #268
base: refactor-validator
Changes from 2 commits
folding/base/reward.py (new file)
@@ -0,0 +1,89 @@
from pydantic import BaseModel
from abc import ABC, abstractmethod
from typing import Optional
import time
import torch
from folding.store import Job


class RewardEvent(BaseModel):
    """Contains rewards for all the responses in a batch"""

    reward_name: str
    rewards: torch.Tensor
    batch_time: float

    extra_info: Optional[dict] = None

    class Config:
        arbitrary_types_allowed = True


class BatchRewardOutput(BaseModel):
    rewards: torch.Tensor
    extra_info: Optional[dict] = None

    class Config:
        arbitrary_types_allowed = True


class BatchRewardInput(BaseModel):
    energies: torch.Tensor
    top_reward: float
    job: Job

    class Config:
        arbitrary_types_allowed = True


class BaseReward(ABC):
    @abstractmethod
    def name(self) -> str:
        ...

    @abstractmethod
    def __init__(self, **kwargs):
        pass

    @abstractmethod
    async def get_rewards(
        self, data: BatchRewardInput, rewards: torch.Tensor
    ) -> BatchRewardOutput:
        pass

    async def setup_rewards(self, energies: torch.Tensor) -> torch.Tensor:
        """Set up a zeroed reward tensor, one entry per returned energy."""
        return torch.zeros(len(energies))

    async def apply(self, data: BatchRewardInput) -> RewardEvent:
        self.rewards: torch.Tensor = await self.setup_rewards(energies=data.energies)
        t0: float = time.time()
        batch_rewards_output: BatchRewardOutput = await self.get_rewards(
            data=data, rewards=self.rewards
        )
        batch_rewards_output.rewards = await self.calculate_final_reward(
            rewards=batch_rewards_output.rewards, job=data.job
        )
        batch_rewards_time: float = time.time() - t0

        return RewardEvent(
            reward_name=self.name(),
            rewards=batch_rewards_output.rewards,
            batch_time=batch_rewards_time,
            extra_info=batch_rewards_output.extra_info,
        )

    async def calculate_final_reward(
        self, rewards: torch.Tensor, job: Job
    ) -> torch.Tensor:
        # priority_multiplier = 1 + (job.priority - 1) * 0.1  # TODO: implement priority
Review comment: Gotta implement this once the global job pool is on the horizon.
        priority_multiplier = 1.0
        organic_multiplier = 1.0
        if "is_organic" in job.event.keys():
            if job.event["is_organic"]:
                organic_multiplier = 10.0
Review comment: We should make this a parameter set in the config, same for the priority multiplier.
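A minimal sketch of that suggestion, assuming a small Pydantic config model (the `RewardConfig` name and its fields are hypothetical, not part of this PR):

from pydantic import BaseModel


class RewardConfig(BaseModel):
    # Hypothetical knobs; defaults mirror the hard-coded values above.
    priority_weight: float = 0.1
    organic_multiplier: float = 10.0


def final_multiplier(config: RewardConfig, job_event: dict, priority: float = 1.0) -> float:
    # Mirrors the commented-out priority formula in calculate_final_reward.
    priority_multiplier = 1 + (priority - 1) * config.priority_weight
    organic_multiplier = config.organic_multiplier if job_event.get("is_organic") else 1.0
    return priority_multiplier * organic_multiplier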
        return rewards * priority_multiplier * organic_multiplier
Review comment: It feels like different reward pipelines should come with formulaic constructions of what their reward should be, based on priority and on whether the job is organic or not. Not sure, but this is what my gut is saying. In the case of different challenges this delineation is much clearer, but when it is just organic vs. non-organic, I guess it is more difficult. Is there any merit in constructing a dedicated pipeline per source? This way you can easily call what you need, and you are not restricted to looking for tags in the job event; you can just construct the correct pipeline based on the entry point of the query.
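A sketch of that entry-point idea (every name below is hypothetical; none of it is in this PR): each entry point registers and constructs its own pipeline, so calculate_final_reward never has to inspect tags.

from folding.base.reward import BaseReward

# Hypothetical registry mapping an entry point to a reward pipeline class.
REWARD_PIPELINES: dict[str, type[BaseReward]] = {}


def register_pipeline(entry_point: str):
    def wrap(cls: type[BaseReward]) -> type[BaseReward]:
        REWARD_PIPELINES[entry_point] = cls
        return cls
    return wrap


def build_reward_pipeline(entry_point: str, **kwargs) -> BaseReward:
    # e.g. "organic" at the organic API endpoint, "synthetic" in the validator loop.
    return REWARD_PIPELINES[entry_point](**kwargs)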
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(name={self.name()})"
@@ -0,0 +1,93 @@
from folding.base.reward import BaseReward, BatchRewardOutput, BatchRewardInput
import torch
from loguru import logger
from folding.store import Job
from folding.rewards.linear_reward import divide_decreasing


class FoldingReward(BaseReward):
    """Folding reward class"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def name(self) -> str:
        return "folding_reward"

    async def get_rewards(
        self, data: BatchRewardInput, rewards: torch.Tensor
    ) -> BatchRewardOutput:
        """
        A reward pipeline that determines how to distribute rewards over the miners sampled in the batch.
        Currently applies a linearly decreasing reward to all miners that do not hold the current best /
        previously best loss, using the function `divide_decreasing`.

        Args:
            data (BatchRewardInput): carries the returned energies (torch.Tensor),
                the upper-bound reward top_reward (float), and the Job.
            rewards (torch.Tensor): tensor of rewards, floats.
        """
        energies: torch.Tensor = data.energies
        top_reward: float = data.top_reward
        job: Job = data.job

        nonzero_energies: torch.Tensor = torch.nonzero(energies)

        info = {
            "name": self.name(),
            "top_reward": top_reward,
        }

        # If the best hotkey is not in the set of hotkeys in the job, the top miner has stopped replying.
        if job.best_hotkey not in job.hotkeys:
            logger.warning(
                f"Best hotkey {job.best_hotkey} not in hotkeys {job.hotkeys}. Assigning no reward."
            )
            return BatchRewardOutput(
                rewards=rewards, extra_info=info
            )  # rewards of all 0s.

        best_index: int = job.hotkeys.index(job.best_hotkey)

        # There are cases where the top miner stops replying; ensure it still gets its reward.
        rewards[best_index] = top_reward

        # If no miners reply, we want *all* reward to go to the top miner.
        if len(nonzero_energies) == 0:
            rewards[best_index] = 1
            return BatchRewardOutput(rewards=rewards, extra_info=info)

        if (len(nonzero_energies) == 1) and (nonzero_energies[0] == best_index):
            rewards[best_index] = 1
            return BatchRewardOutput(rewards=rewards, extra_info=info)

        # Find any indices whose energy matches the best value.
        remaining_miners = {}
        for index in nonzero_energies:
            # There could be multiple max energies.
            # The best energy could be the one that is saved in the store.
            if (energies[index] == job.best_loss) or (index == best_index):
                rewards[index] = top_reward
            else:
                remaining_miners[index] = energies[index]

        # The reward distributed to the remaining miners MUST be less than the reward given to the top miners.
        num_remaining_miners = len(remaining_miners)
        if num_remaining_miners > 1:
            sorted_remaining_miners = dict(
                sorted(remaining_miners.items(), key=lambda item: item[1])
            )  # sort smallest to largest

            # Apply a fixed decrease in reward to the remaining non-zero miners.
            rewards_per_miner = divide_decreasing(
                amount_to_distribute=1 - top_reward,
                number_of_elements=num_remaining_miners,
            )
            for index, r in zip(sorted_remaining_miners.keys(), rewards_per_miner):
                rewards[index] = r
        else:
            for index in remaining_miners.keys():
                rewards[index] = 1 - top_reward

        return BatchRewardOutput(rewards=rewards, extra_info=info)
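For reference, a usage sketch of the refactored flow. The Job construction and the module path for FoldingReward are assumptions; only the fields read by get_rewards are set.

import asyncio

import torch

from folding.base.reward import BatchRewardInput
from folding.store import Job
from folding.rewards.folding_reward import FoldingReward  # module path assumed


async def main():
    # Hypothetical job: four sampled hotkeys, hk1 holds the best loss so far.
    job = Job(
        hotkeys=["hk0", "hk1", "hk2", "hk3"],
        best_hotkey="hk1",
        best_loss=-1200.0,
        event={"is_organic": False},
    )
    event = await FoldingReward().apply(
        data=BatchRewardInput(
            energies=torch.tensor([0.0, -1200.0, -1100.0, -900.0]),
            top_reward=0.8,
            job=job,
        )
    )
    # hk1 (best) gets top_reward; hk2 and hk3 split 1 - top_reward in decreasing
    # shares via divide_decreasing; hk0 returned no energy and stays at 0.
    print(event.rewards)


asyncio.run(main())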
This file was deleted.

Review comment: I see your flow here, but I think it would make more sense to have

class OrganicReward(BaseReward)

so you don't have to have any of this baked-in logic for checking whether a set of data comes from a specific source. It would give us more flexibility, but this is a good start.
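A sketch of what that subclass could look like against the BaseReward interface above (the class body and multiplier value are illustrative, not a committed design):

import torch

from folding.base.reward import BaseReward, BatchRewardInput, BatchRewardOutput
from folding.store import Job


class OrganicReward(BaseReward):
    """Hypothetical: organic queries get their own pipeline, so no
    is_organic tag-checking is needed inside calculate_final_reward."""

    def __init__(self, organic_multiplier: float = 10.0, **kwargs):
        self.organic_multiplier = organic_multiplier

    def name(self) -> str:
        return "organic_reward"

    async def get_rewards(
        self, data: BatchRewardInput, rewards: torch.Tensor
    ) -> BatchRewardOutput:
        # Reuse whatever batch scoring fits organic jobs; identity here for brevity.
        return BatchRewardOutput(rewards=rewards)

    async def calculate_final_reward(
        self, rewards: torch.Tensor, job: Job
    ) -> torch.Tensor:
        # The organic boost lives in the organic pipeline itself.
        return rewards * self.organic_multiplier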