Skip to content

Commit

Permalink
Merge pull request #365 from macrocosm-os/dev
Browse files Browse the repository at this point in the history
API SECURITY, HF REFACTOR OF SCORING.
  • Loading branch information
Arrmlet authored Jan 25, 2025
2 parents a1bebfb + eaa1137 commit a1835e8
Show file tree
Hide file tree
Showing 11 changed files with 362 additions and 59 deletions.
2 changes: 1 addition & 1 deletion common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@
NO_TWITTER_URLS_DATE = datetime.datetime(2024, 12, 28, tzinfo=datetime.timezone.utc) # December 28, 2024 UTC

# HF reward activation date.
HF_REWARD_DATE = datetime.datetime(2025, 1, 27, tzinfo=datetime.timezone.utc) # January 27, 2025 UTC
HF_REWARD_DATE = datetime.datetime(2025, 1, 25, hour=16, tzinfo=datetime.timezone.utc) # January 25, 16:00 2025 UTC
6 changes: 3 additions & 3 deletions huggingface_utils/huggingface_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def get_data_for_huggingface_upload(self, source, last_upload):
FROM DataEntity
WHERE source = ?
ORDER BY datetime ASC
LIMIT 400000000
LIMIT 200000000
"""
params = [source]
else:
Expand Down Expand Up @@ -252,8 +252,8 @@ def upload_sql_to_huggingface(self) -> List[HuggingFaceMetadata]:
continue

bt.logging.info(f"Current total rows: {total_rows}")
if total_rows >= 400_000_000:
bt.logging.info(f"Reached 400 million rows limit for source {source}. Stopping upload.")
if total_rows >= 200_000_000: # TODO
bt.logging.info(f"Reached 200 million rows limit for source {source}. Stopping upload.")
break

last_upload = df['datetime'].max()
Expand Down
13 changes: 9 additions & 4 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,15 @@ def serve_axon(self):
self.axon = bt.axon(wallet=self.wallet, config=self.config)
self.axon.serve(netuid=self.config.netuid, subtensor=self.subtensor).start()
if self.config.neuron.api_on:
bt.logging.info("Starting Validator API...")
from vali_utils.api.server import ValidatorAPI
self.api = ValidatorAPI(self, port=self.config.neuron.api_port)
self.api.start()
try:
bt.logging.info("Starting Validator API...")
from vali_utils.api.server import ValidatorAPI
self.api = ValidatorAPI(self, port=self.config.neuron.api_port)
self.api.start()
except ValueError as e:
bt.logging.error(f"Failed to start API: {str(e)}")
bt.logging.info("Validator will continue running without API.")
self.config.neuron.api_on = False

bt.logging.info(
f"Serving validator axon {self.axon} on network: {self.config.subtensor.chain_endpoint} with netuid: {self.config.netuid}."
Expand Down
28 changes: 19 additions & 9 deletions rewards/miner_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class MinerScorer:
# Start new miners at a credibility of 0.
STARTING_CREDIBILITY = 0

# Start new miners' HF credibility at 0.375
STARTING_HF_CREDIBILITY = 0.375

# The exponent used to scale the miner's score by its credibility.
_CREDIBILITY_EXP = 2.5

Expand All @@ -27,6 +30,7 @@ def __init__(
num_neurons: int,
value_calculator: DataValueCalculator,
cred_alpha: float = 0.15,
hf_cred_alpha: float = 0.20
):
# Tracks the raw scores of each miner. i.e. not the weights that are set on the blockchain.
self.scores = torch.zeros(num_neurons, dtype=torch.float32)
Expand All @@ -39,7 +43,11 @@ def __init__(
self.cred_alpha = cred_alpha

# Keeps track of the miner's current HF boost based on the last HF evaluation.
self.hf_boost = 0.0
self.hf_boosts = torch.zeros(num_neurons, dtype=torch.float32)
self.hf_credibility = torch.full(
(num_neurons, 1), MinerScorer.STARTING_HF_CREDIBILITY, dtype=torch.float32
)
self.hf_cred_alpha = hf_cred_alpha

# Make this class thread safe because it'll eventually be accessed by multiple threads.
# One from the main validator evaluation loop and another from a background thread performing validation on user requests.
Expand All @@ -52,6 +60,8 @@ def save_state(self, filepath):
{
"scores": self.scores,
"credibility": self.miner_credibility,
"hf_boosts": self.hf_boosts,
"hf_credibility": self.hf_credibility,
"scorable_bytes": self.scorable_bytes,
},
filepath,
Expand Down Expand Up @@ -119,13 +129,13 @@ def resize(self, num_neurons: int) -> None:
[self.scorable_bytes, torch.zeros(to_add, dtype=torch.float32)]
)

def update_hf_boost(self, uid: int, hf_vali_percentage: float) -> None:
def update_hf_boost_and_cred(self, uid: int, hf_vali_percentage: float) -> None:
"""Applies a fixed boost to the scaled score if the miner has passed HF validation."""
bt.logging.info(f"Miner passed HF validation with a validation percentage of {hf_vali_percentage}.")
max_boost = 3 * 10**6
self.hf_boost = hf_vali_percentage/100 * max_boost
bt.logging.success(
f"Awarded Miner {uid} a hf_boost of {self.hf_boost} for passing HF validation."
max_boost = 10 * 10**6
self.hf_boosts[uid] = hf_vali_percentage/100 * max_boost
self.hf_credibility[uid] = hf_vali_percentage * self.hf_cred_alpha + (1-self.hf_cred_alpha) * self.hf_credibility[uid]
bt.logging.info(
f"After HF evaluation for miner {uid}: Raw HF Boost = {self.hf_boosts[uid]}. HF Credibility = {self.hf_credibility[uid]}."
)

def on_miner_evaluated(
Expand Down Expand Up @@ -179,8 +189,8 @@ def on_miner_evaluated(
# Hugging Face rewards are active from HF_REWARD_DATE (January 25, 2025, 16:00 UTC) onward.
if dt.datetime.now(dt.timezone.utc) >= HF_REWARD_DATE:
# Awarding the miner their HF boost based on their last HF evaluation.
score += self.hf_boost
bt.logging.info(f"Awarded Miner {uid} a HF boost of {self.hf_boost} based off of the lastest HF evaluation, adjusting the score to {score}.")
score += self.hf_boosts[uid] * self.hf_credibility[uid]
bt.logging.info(f"Awarded Miner {uid} a HF boost of {self.hf_boosts[uid] * self.hf_credibility[uid]} based off of the lastest HF evaluation, adjusting the score to {score}.")

# Now update the credibility again based on the current validation results.
self._update_credibility(uid, validation_results)
Expand Down
30 changes: 0 additions & 30 deletions vali_utils/api/auth.py

This file was deleted.

191 changes: 191 additions & 0 deletions vali_utils/api/auth/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
from fastapi import HTTPException, Security
from fastapi.security.api_key import APIKeyHeader
import os
import sqlite3
from typing import Dict, Optional, List
from dotenv import load_dotenv
import time
from datetime import datetime, timedelta
import threading
import secrets
import bittensor as bt

# Load variables from a .env file at import time; APIKeyManager reads
# MASTER_KEY via os.getenv and expects it to be set there.
load_dotenv()

# Name of the HTTP header that carries the caller's API key.
API_KEY_NAME = "X-API-Key"
# Header-based security scheme used by the dependencies below to extract the key.
api_key_header = APIKeyHeader(name=API_KEY_NAME)


class APIKeyManager:
    """Store API keys in SQLite and enforce a per-key hourly request limit.

    The master key is read from the ``MASTER_KEY`` environment variable and is
    never stored in the ``api_keys`` table; it is always accepted as valid.
    Regular keys live in ``api_keys`` and can be deactivated. Every accepted
    request is logged in ``rate_limits`` and requests from the last hour are
    counted against the key's limit.
    """

    def __init__(self, db_path: str = "api_keys.db"):
        """Read the master key, remember the DB path and create the tables.

        Args:
            db_path: Path of the SQLite database file.

        Raises:
            ValueError: If ``MASTER_KEY`` is unset or empty.
        """
        # Master key from environment
        self.master_key = os.getenv('MASTER_KEY')
        if not self.master_key:
            bt.logging.error("MASTER_KEY not found in environment. API will be disabled.")
            raise ValueError(
                "MASTER_KEY environment variable is required to enable API. "
                "Please set MASTER_KEY in your .env file."
            )
        self.db_path = db_path
        # Serializes the count-then-insert sequence in check_rate_limit.
        self.lock = threading.Lock()
        self._init_db()

    def _connect(self):
        """Return a context manager yielding a connection that is closed on exit.

        ``sqlite3.Connection`` used as a context manager only commits or rolls
        back; it does not close. Wrapping it in ``closing`` releases the file
        handle. Use as ``with self._connect() as conn, conn:`` so the inner
        ``conn`` still provides the commit/rollback behaviour.
        """
        from contextlib import closing

        return closing(sqlite3.connect(self.db_path))

    def _init_db(self):
        """Initialize SQLite database"""
        with self._connect() as conn, conn:
            # Create API keys table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS api_keys (
                    key TEXT PRIMARY KEY,
                    name TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    is_active BOOLEAN DEFAULT TRUE
                )
            """)
            # Create rate limiting table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS rate_limits (
                    key TEXT,
                    request_time TIMESTAMP,
                    FOREIGN KEY(key) REFERENCES api_keys(key)
                )
            """)
            # Every request deletes by request_time and counts by
            # (key, request_time); index the table so those are not full scans.
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_rate_limits_key_time
                ON rate_limits (key, request_time)
            """)

    def create_api_key(self, name: str) -> str:
        """Create a new API key labelled ``name`` and return it."""
        api_key = f"sk_live_{secrets.token_urlsafe(32)}"
        with self._connect() as conn, conn:
            conn.execute(
                "INSERT INTO api_keys (key, name) VALUES (?, ?)",
                (api_key, name)
            )
        return api_key

    def deactivate_api_key(self, key: str):
        """Deactivate an API key (a no-op if the key does not exist)."""
        with self._connect() as conn, conn:
            conn.execute(
                "UPDATE api_keys SET is_active = FALSE WHERE key = ?",
                (key,)
            )

    def list_api_keys(self) -> List[Dict]:
        """List all API keys as dicts with key, name, created_at and is_active."""
        with self._connect() as conn, conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(
                "SELECT key, name, created_at, is_active FROM api_keys"
            )
            return [dict(row) for row in cursor.fetchall()]

    def is_valid_key(self, api_key: str) -> bool:
        """Return True for the master key or an active key stored in the database."""
        if self.is_master_key(api_key):
            return True
        if not isinstance(api_key, str):
            # Nothing but a string can match a stored key.
            return False

        with self._connect() as conn, conn:
            cursor = conn.execute(
                "SELECT is_active FROM api_keys WHERE key = ?",
                (api_key,)
            )
            result = cursor.fetchone()
            return bool(result and result[0])

    def is_master_key(self, api_key: str) -> bool:
        """Check if key is the master key, using a constant-time comparison."""
        if not isinstance(api_key, str):
            return False
        # Compare as bytes: compare_digest rejects non-ASCII str arguments.
        return secrets.compare_digest(
            api_key.encode("utf-8"), self.master_key.encode("utf-8")
        )

    def _clean_old_requests(self):
        """Remove requests older than 1 hour"""
        with self._connect() as conn, conn:
            conn.execute(
                "DELETE FROM rate_limits WHERE request_time < datetime('now', '-1 hour')"
            )

    def check_rate_limit(self, api_key: str) -> tuple[bool, Dict]:
        """Check if request is within rate limits and, if so, record it.

        Args:
            api_key: The key making the request.

        Returns:
            ``(allowed, headers)``. ``headers`` always contains
            ``X-RateLimit-Limit`` and ``X-RateLimit-Reset``; when the request
            is allowed it also contains ``X-RateLimit-Remaining``.
        """
        with self.lock:
            self._clean_old_requests()

            is_master = self.is_master_key(api_key)
            rate_limit = 1000 if is_master else 100  # Master key gets higher limit

            with self._connect() as conn, conn:
                # Count recent requests
                cursor = conn.execute("""
                    SELECT COUNT(*) FROM rate_limits
                    WHERE key = ? AND request_time > datetime('now', '-1 hour')
                """, (api_key,))
                count = cursor.fetchone()[0]

                if count >= rate_limit:
                    # Get reset time: one hour after the oldest remaining request.
                    cursor = conn.execute("""
                        SELECT request_time FROM rate_limits
                        WHERE key = ?
                        ORDER BY request_time ASC
                        LIMIT 1
                    """, (api_key,))
                    oldest = cursor.fetchone()
                    reset_time = datetime.fromisoformat(oldest[0]) + timedelta(hours=1)

                    return False, {
                        "X-RateLimit-Limit": str(rate_limit),
                        "X-RateLimit-Reset": reset_time.isoformat()
                    }

                # Record new request
                conn.execute(
                    "INSERT INTO rate_limits (key, request_time) VALUES (?, datetime('now'))",
                    (api_key,)
                )

                return True, {
                    "X-RateLimit-Limit": str(rate_limit),
                    "X-RateLimit-Remaining": str(rate_limit - count - 1),
                    "X-RateLimit-Reset": (datetime.utcnow() + timedelta(hours=1)).isoformat()
                }


# Create global instance
# NOTE: APIKeyManager() raises ValueError when MASTER_KEY is not set, so
# importing this module fails in that case.
key_manager = APIKeyManager()


async def verify_api_key(api_key_header: str = Security(api_key_header)):
    """Verify API key and check rate limits.

    Returns the presented key when it is valid and within its rate limit.

    Raises:
        HTTPException: 403 when the key is not valid; 429 (with the
            rate-limit headers) when the key has exhausted its limit.
    """
    presented_key = api_key_header

    if key_manager.is_valid_key(presented_key):
        # Only valid keys are counted against the rate limit.
        allowed, limit_headers = key_manager.check_rate_limit(presented_key)
        if allowed:
            return presented_key
        raise HTTPException(
            status_code=429,
            detail="Rate limit exceeded",
            headers=limit_headers,
        )

    raise HTTPException(status_code=403, detail="Invalid API key")


async def require_master_key(api_key_header: str = Security(api_key_header)):
    """Verify master API key.

    Returns True when the presented key is the master key and it is within
    its rate limit.

    Raises:
        HTTPException: 403 when the key is not the master key; 429 (with the
            rate-limit headers) when the master key has exhausted its limit.
    """
    presented_key = api_key_header

    if key_manager.is_master_key(presented_key):
        # Check rate limits even for master key
        allowed, limit_headers = key_manager.check_rate_limit(presented_key)
        if allowed:
            return True
        raise HTTPException(
            status_code=429,
            detail="Rate limit exceeded",
            headers=limit_headers,
        )

    raise HTTPException(status_code=403, detail="Invalid master key")
45 changes: 45 additions & 0 deletions vali_utils/api/auth/key_routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from fastapi import APIRouter, Depends
from .auth import require_master_key, key_manager
from pydantic import BaseModel
from typing import List
from vali_utils.api.utils import endpoint_error_handler

# Request body for creating an API key.
class APIKeyCreate(BaseModel):
    # Label stored alongside the generated key.
    name: str


# Response body returned after an API key has been created.
class APIKeyResponse(BaseModel):
    # The newly generated API key.
    key: str
    # Label that was supplied in the create request.
    name: str
name: str


# Router holding the key-management endpoints; every route below depends on require_master_key.
router = APIRouter(tags=["key management"])


@router.post("", response_model=APIKeyResponse)
@endpoint_error_handler
async def create_api_key(
    request: APIKeyCreate,
    _: bool = Depends(require_master_key),
):
    """Create new API key (requires master key)"""
    # Generate and persist a key under the requested label, then echo both back.
    key_name = request.name
    issued_key = key_manager.create_api_key(key_name)
    return {"key": issued_key, "name": key_name}


@router.get("")
@endpoint_error_handler
async def list_api_keys(_: bool = Depends(require_master_key)):
    """List all API keys (requires master key)"""
    # Wrap the manager's rows in a single "keys" field.
    all_keys = key_manager.list_api_keys()
    return {"keys": all_keys}


@router.post("/{key}/deactivate")
@endpoint_error_handler
async def deactivate_api_key(
    key: str,
    _: bool = Depends(require_master_key),
):
    """Deactivate an API key (requires master key)"""
    # The manager call returns nothing, so success is reported unconditionally.
    key_manager.deactivate_api_key(key)
    outcome = {"status": "success"}
    return outcome
Loading

0 comments on commit a1835e8

Please sign in to comment.