Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizing Validator Processing #234

Open
wants to merge 3 commits into
base: staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions folding/validators/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from folding.utils.openmm_forcefields import FORCEFIELD_REGISTRY
from folding.validators.hyperparameters import HyperParameters
from folding.validators.reward import RewardPipeline
from folding.utils.ops import (
load_and_sample_random_pdb_ids,
get_response_info,
Expand Down Expand Up @@ -108,12 +109,12 @@ def run_step(
)
return event

energies, energy_event = get_energies(
protein=protein, responses=responses_serving, uids=active_uids
)
RP = RewardPipeline(protein=protein, responses=responses_serving, uids=active_uids)
RP.process_energies()
RP.check_run_validities()

# Log the step event.
event.update({"energies": energies.tolist(), **energy_event})
event.update({"energies": RP.energies.tolist(), **RP.event})

if len(protein.md_inputs) > 0:
event["md_inputs"] = list(protein.md_inputs.keys())
Expand Down
141 changes: 71 additions & 70 deletions folding/validators/protein.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@
import base64
import random
import shutil
from pathlib import Path

from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Literal
from typing import Dict, List, Literal, Tuple

import bittensor as bt
import numpy as np
import pandas as pd
import plotly.express as px

from openmm import app, unit
from pdbfixer import PDBFixer

from folding.base.simulation import OpenMMSimulation
from folding.store import Job
from folding.base.simulation import OpenMMSimulation
from folding.utils.opemm_simulation_config import SimulationConfig
from folding.utils.ops import (
OpenMMException,
Expand Down Expand Up @@ -146,6 +148,14 @@ def _get_pdb_complexity(pdb_path):
pdb_complexity[key] += 1
return pdb_complexity

@staticmethod
def save_pdb(output_path: str, simulation: app.Simulation):
    """Write the simulation's current atomic coordinates to *output_path* in PDB format.

    Args:
        output_path: Destination path for the PDB file (overwritten if present).
        simulation: The OpenMM simulation whose topology and current positions are dumped.
    """
    # Pull the latest positions from the simulation context before serializing.
    current_state = simulation.context.getState(getPositions=True)
    with open(output_path, "w") as pdb_file:
        app.PDBFile.writeFile(simulation.topology, current_state.getPositions(), pdb_file)

def gather_pdb_id(self):
if self.pdb_id is None:
self.pdb_id = load_and_sample_random_pdb_ids(
Expand Down Expand Up @@ -365,17 +375,16 @@ def delete_files(self, directory: str):
# os.rmdir(output_directory)

def get_miner_data_directory(self, hotkey: str):
self.miner_data_directory = os.path.join(self.validator_directory, hotkey[:8])
return os.path.join(self.validator_directory, hotkey[:8])

def process_md_output(
self, md_output: dict, seed: int, state: str, hotkey: str
) -> bool:
) -> Tuple[bool, Dict]:
MIN_LOGGING_ENTRIES = 500
MIN_SIMULATION_STEPS = 5000

required_files_extensions = ["cpt", "log"]
hotkey_alias = hotkey[:8]
self.current_state = state

# This is just a mapper from the file extension to the name of the file stored in the dict.
self.md_outputs_exts = {
Expand All @@ -386,69 +395,64 @@ def process_md_output(
bt.logging.warning(
f"Miner {hotkey_alias} returned empty md_output... Skipping!"
)
return False
return False, None

for ext in required_files_extensions:
if ext not in self.md_outputs_exts:
bt.logging.error(f"Missing file with extension {ext} in md_output")
return False
return False, None

self.get_miner_data_directory(hotkey=hotkey)
miner_data_directory = self.get_miner_data_directory(hotkey=hotkey)

# Save files so we can check the hash later.
self.save_files(
files=md_output,
output_directory=self.miner_data_directory,
output_directory=miner_data_directory,
)

try:
# NOTE: The seed written in the self.system_config is not used here
# because the miner could have used something different and we want to
# make sure that we are using the correct seed.

bt.logging.info(
f"Recreating miner {hotkey_alias} simulation in state: {self.current_state}"
f"Recreating miner {hotkey_alias} simulation in state: {state}"
)
self.simulation, self.system_config = self.create_simulation(
simulation, local_system_config = self.create_simulation(
pdb=self.load_pdb_file(pdb_file=self.pdb_location),
system_config=self.system_config.get_config(),
seed=seed,
state=state,
)

checkpoint_path = os.path.join(
self.miner_data_directory, f"{self.current_state}.cpt"
)
checkpoint_path = os.path.join(miner_data_directory, f"{state}.cpt")

log_file_path = os.path.join(
self.miner_data_directory, self.md_outputs_exts["log"]
miner_data_directory, self.md_outputs_exts["log"]
)

self.simulation.loadCheckpoint(checkpoint_path)
self.log_file = pd.read_csv(log_file_path)
self.log_step = self.log_file['#"Step"'].iloc[-1]
simulation.loadCheckpoint(checkpoint_path)
log_file = pd.read_csv(log_file_path)
log_step = log_file['#"Step"'].iloc[-1]

# Checks to see if we have enough steps in the log file to start validation
if len(self.log_file) < MIN_LOGGING_ENTRIES:
if len(log_file) < MIN_LOGGING_ENTRIES:
raise ValidationError(
f"Miner {hotkey_alias} did not run enough steps in the simulation... Skipping!"
)

# Make sure the log file is far enough ahead of the checkpoint file in simulation steps.
# Checks if log_file is MIN_STEPS steps ahead of checkpoint
if (self.log_step - self.simulation.currentStep) < MIN_SIMULATION_STEPS:
if (log_step - simulation.currentStep) < MIN_SIMULATION_STEPS:
# If the miner did not run enough steps, we will load the old checkpoint
checkpoint_path = os.path.join(
self.miner_data_directory, f"{self.current_state}_old.cpt"
)
checkpoint_path = os.path.join(miner_data_directory, f"{state}_old.cpt")
if os.path.exists(checkpoint_path):
bt.logging.warning(
f"Miner {hotkey_alias} did not run enough steps since last checkpoint... Loading old checkpoint"
)
self.simulation.loadCheckpoint(checkpoint_path)
simulation.loadCheckpoint(checkpoint_path)
# Checking to see if the old checkpoint has enough steps to validate
if (
self.log_step - self.simulation.currentStep
) < MIN_SIMULATION_STEPS:
if (log_step - simulation.currentStep) < MIN_SIMULATION_STEPS:
raise ValidationError(
f"Miner {hotkey_alias} did not run enough steps in the simulation... Skipping!"
)
Expand All @@ -457,31 +461,39 @@ def process_md_output(
f"Miner {hotkey_alias} did not run enough steps and no old checkpoint found... Skipping!"
)

self.cpt_step = self.simulation.currentStep
self.checkpoint_path = checkpoint_path

# Save the system config to the miner data directory
system_config_path = os.path.join(
self.miner_data_directory, f"miner_system_config_{seed}.pkl"
miner_data_directory, f"miner_system_config_{seed}.pkl"
)
if not os.path.exists(system_config_path):
write_pkl(
data=self.system_config,
data=local_system_config,
path=system_config_path,
write_mode="wb",
)

except ValidationError as E:
bt.logging.warning(f"{E}")
return False
return False, None

except Exception as e:
bt.logging.error(f"Failed to recreate simulation: {e}")
return False
return False, None

return True
return True, {
"simulation": simulation,
"log_file": log_file,
"log_step": log_step,
}

def is_run_valid(self):
def is_run_valid(
self,
simulation: app.Simulation,
state: str,
hotkey: str,
log_file: pd.DataFrame,
log_step: int,
) -> Tuple[bool, list, list]:
"""
Checks if the run is valid by comparing the potential energy values
between the current simulation and a reference log file.
Expand All @@ -493,41 +505,40 @@ def is_run_valid(self):

# The percentage that we allow the energy to differ from the miner to the validator.
ANOMALY_THRESHOLD = 0.5
miner_data_directory = self.get_miner_data_directory(hotkey=hotkey)

# Check that the log has a resolution of 10 steps or better; if not, the run is not valid.
if (self.log_file['#"Step"'][1] - self.log_file['#"Step"'][0]) > 10:
return False
if (log_file['#"Step"'][1] - log_file['#"Step"'][0]) > 10:
return False, [], []

# Run the simulation at most 3000 steps
steps_to_run = min(3000, self.log_step - self.cpt_step)
steps_to_run = min(3000, log_step - self.simulation.currentStep)
mccrindlebrian marked this conversation as resolved.
Show resolved Hide resolved

self.simulation.reporters.append(
simulation.reporters.append(
app.StateDataReporter(
os.path.join(
self.miner_data_directory, f"check_{self.current_state}.log"
),
os.path.join(miner_data_directory, f"check_{state}.log"),
10,
step=True,
potentialEnergy=True,
)
)

bt.logging.info(
f"Running {steps_to_run} steps. log_step: {self.log_step}, cpt_step: {self.cpt_step}"
f"Running {steps_to_run} steps. log_step: {log_step}, cpt_step: {simulation.currentStep}"
)

self.simulation.step(steps_to_run)
simulation.step(steps_to_run)

check_log_file = pd.read_csv(
os.path.join(self.miner_data_directory, f"check_{self.current_state}.log")
os.path.join(miner_data_directory, f"check_{state}.log")
)

max_step = self.cpt_step + steps_to_run
max_step = simulation.currentStep + steps_to_run

check_energies: np.ndarray = check_log_file["Potential Energy (kJ/mole)"].values
miner_energies: np.ndarray = self.log_file[
(self.log_file['#"Step"'] > self.cpt_step)
& (self.log_file['#"Step"'] <= max_step)
miner_energies: np.ndarray = log_file[
(log_file['#"Step"'] > simulation.currentStep)
& (log_file['#"Step"'] <= max_step)
]["Potential Energy (kJ/mole)"].values

# calculating absolute percent difference per step
Expand All @@ -539,25 +550,23 @@ def is_run_valid(self):

fig = px.scatter(
df,
title=f"Energy: {self.pdb_id} for state {self.current_state} starting at checkpoint step: {self.cpt_step}",
title=f"Energy: {self.pdb_id} for state {state} starting at checkpoint step: {simulation.currentStep}",
labels={"index": "Step", "value": "Energy (kJ/mole)"},
height=600,
width=1400,
)
filename = f"{self.pdb_id}_cpt_step_{self.cpt_step}_state_{self.current_state}"
fig.write_image(
os.path.join(self.miner_data_directory, filename + "_energy.png")
)
filename = f"{self.pdb_id}_cpt_step_{simulation.currentStep}_state_{state}"
fig.write_image(os.path.join(miner_data_directory, filename + "_energy.png"))

fig = px.scatter(
percent_diff,
title=f"Percent Diff: {self.pdb_id} for state {self.current_state} starting at checkpoint step: {self.cpt_step}",
title=f"Percent Diff: {self.pdb_id} for state {state} starting at checkpoint step: {simulation.currentStep}",
labels={"index": "Step", "value": "Percent Diff"},
height=600,
width=1400,
)
fig.write_image(
os.path.join(self.miner_data_directory, filename + "_percent_diff.png")
os.path.join(miner_data_directory, filename + "_percent_diff.png")
)

median_percent_diff = np.median(percent_diff)
Expand All @@ -566,23 +575,15 @@ def is_run_valid(self):

if median_percent_diff > ANOMALY_THRESHOLD:
return False, check_energies.tolist(), miner_energies.tolist()
self.save_pdb(
output_path=os.path.join(
self.miner_data_directory, f"{self.pdb_id}_folded.pdb"
)

Protein.save_pdb(
output_path=os.path.join(miner_data_directory, f"{self.pdb_id}_folded.pdb")
)
return True, check_energies.tolist(), miner_energies.tolist()

def save_pdb(self, output_path: str):
"""Save the pdb file to the output path."""
positions = self.simulation.context.getState(getPositions=True).getPositions()
topology = self.simulation.topology
with open(output_path, "w") as f:
app.PDBFile.writeFile(topology, positions, f)
return True, check_energies.tolist(), miner_energies.tolist()

def get_energy(self):
    """Return the simulation's current potential energy as a plain float in kJ/mol."""
    # Query only the energy from the context; dividing by the unit strips it to a number.
    energy_state = self.simulation.context.getState(getEnergy=True)
    potential = energy_state.getPotentialEnergy()
    return potential / unit.kilojoules_per_mole

def get_rmsd(self, output_path: str = None, xvg_command: str = "-xvg none"):
Expand Down
Loading