Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Script to isolate the internal MSA issue #366

Open
wants to merge 6 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,14 @@ def prediction_to_mmcif(pred_atom_pos: Union[np.ndarray, paddle.Tensor],
- maxit_binary: path to maxit_binary, use to convert pdb to cif
- mmcif_path: path to save *.cif
"""
assert maxit_binary is not None and os.path.exists(maxit_binary), (
f'maxit_binary: {maxit_binary} not exists. '
f'link: https://sw-tools.rcsb.org/apps/MAXIT/source.html')
# assert maxit_binary is not None and os.path.exists(maxit_binary), (
# f'maxit_binary: {maxit_binary} not exists. '
# f'link: https://sw-tools.rcsb.org/apps/MAXIT/source.html')
assert mmcif_path.endswith('.cif'), f'mmcif_path should endswith .cif; got {mmcif_path}'

pdb_path = mmcif_path.replace('.cif', '.pdb')
pdb_path = prediction_to_pdb(pred_atom_pos, FeatsDict, pdb_path)
msg = os.system(f'{maxit_binary} -i {pdb_path} -o 1 -output {mmcif_path}')
msg = os.system(f'structconvert -PDBx {pdb_path} {mmcif_path}')
if msg != 0:
print(f'convert pdb to cif failed, error message: {msg}')
return mmcif_path
55 changes: 29 additions & 26 deletions apps/protein_folding/helixfold3/helixfold/data/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(f'MSA {msa_index} must contain at least one sequence.')

print("MSA SEQUENCE LENGTH", len(msa.sequences))
for sequence_index, sequence in enumerate(msa.sequences):
if sequence in seen_sequences:
continue
Expand Down Expand Up @@ -239,40 +241,40 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
except Exception as exc:
print(f'Task {task} generated an exception : {exc}')

msa_for_templates = msa_results['uniref90']['sto']
msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(msa_for_templates)
# msa_for_templates = msa_results['uniref90']['sto']
# msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
# msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(msa_for_templates)

if self.template_searcher.input_format == 'sto':
pdb_templates_result = self.template_searcher.query(msa_for_templates)
elif self.template_searcher.input_format == 'a3m':
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
else:
raise ValueError('Unrecognized template input format: '
f'{self.template_searcher.input_format}')
# if self.template_searcher.input_format == 'sto':
# pdb_templates_result = self.template_searcher.query(msa_for_templates)
# elif self.template_searcher.input_format == 'a3m':
# uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
# pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
# else:
# raise ValueError('Unrecognized template input format: '
# f'{self.template_searcher.input_format}')

pdb_hits_out_path = os.path.join(
msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}')
with open(pdb_hits_out_path, 'w') as f:
f.write(pdb_templates_result)
# pdb_hits_out_path = os.path.join(
# msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}')
# with open(pdb_hits_out_path, 'w') as f:
# f.write(pdb_templates_result)

uniref90_msa = parsers.parse_stockholm(msa_results['uniref90']['sto'])
mgnify_msa = parsers.parse_stockholm(msa_results['mgnify']['sto'])

pdb_template_hits = self.template_searcher.get_template_hits(
output_string=pdb_templates_result, input_sequence=input_sequence)
# pdb_template_hits = self.template_searcher.get_template_hits(
# output_string=pdb_templates_result, input_sequence=input_sequence)

if self._use_small_bfd:
bfd_msa = parsers.parse_stockholm(msa_results['small_bfd']['sto'])
else:
raise ValueError("Doesn't support full BFD yet.")

templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
hits=pdb_template_hits,
query_pdb_code=None,
query_release_date=None)
# templates_result = self.template_featurizer.get_templates(
# query_sequence=input_sequence,
# hits=pdb_template_hits,
# query_pdb_code=None,
# query_release_date=None)

sequence_features = make_sequence_features(
sequence=input_sequence,
Expand All @@ -286,8 +288,9 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa))
logging.info('Final (deduplicated) MSA size: %d sequences.',
msa_features['num_alignments'][0])
logging.info('Total number of templates (NB: this can include bad '
'templates and is later filtered to top 4): %d.',
templates_result.features['template_domain_names'].shape[0])
# logging.info('Total number of templates (NB: this can include bad '
# 'templates and is later filtered to top 4): %d.',
# templates_result.features['template_domain_names'].shape[0])

return {**sequence_features, **msa_features, **templates_result.features}
# return {**sequence_features, **msa_features, **templates_result.features}
return {**sequence_features, **msa_features}
4 changes: 2 additions & 2 deletions apps/protein_folding/helixfold3/helixfold/data/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def _process_single_hit(
TemplateAtomMaskAllZerosError) as e:
# These 3 errors indicate missing mmCIF experimental data rather than a
# problem with the template search, so turn them into warnings.
warning = ('%s_%s (sum_probs: %s, rank: %s): feature extracting errors: '
warning = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: '
'%s, mmCIF parsing errors: %s'
% (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index,
str(e), parsing_result.errors))
Expand All @@ -826,7 +826,7 @@ def _process_single_hit(
else:
return SingleHitResult(features=None, error=None, warning=warning)
except Error as e:
error = ('%s_%s (sum_probs: %s, rank: %s): feature extracting errors: '
error = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: '
'%s, mmCIF parsing errors: %s'
% (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index,
str(e), parsing_result.errors))
Expand Down
38 changes: 21 additions & 17 deletions apps/protein_folding/helixfold3/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,24 +467,7 @@ def main(args):
msa_templ_data_pipeline_dict = get_msa_templates_pipeline(args)


### create model
model_config = config.model_config(args.model_name)
print(f'>>> model_config:\n{model_config}')

model = RunModel(model_config)

if (not args.init_model is None) and (not args.init_model == ""):
print(f"Load pretrain model from {args.init_model}")
pd_params = paddle.load(args.init_model)

has_opt = 'optimizer' in pd_params
if has_opt:
model.helixfold.set_state_dict(pd_params['model'])
else:
model.helixfold.set_state_dict(pd_params)

if args.precision == "bf16" and args.amp_level == "O2":
raise NotImplementedError("bf16 O2 is not supported yet.")

print(f"============ Data Loading ============")
job_base = pathlib.Path(args.input_json).stem
Expand All @@ -506,6 +489,27 @@ def main(args):
feature_dict['feat'] = batch_convert(feature_dict['feat'], add_batch=True)
feature_dict['label'] = batch_convert(feature_dict['label'], add_batch=True)

return
print(f"============ Model Loading ============")
### create model
model_config = config.model_config(args.model_name)
print(f'>>> model_config:\n{model_config}')

model = RunModel(model_config)

if (not args.init_model is None) and (not args.init_model == ""):
print(f"Load pretrain model from {args.init_model}")
pd_params = paddle.load(args.init_model)

has_opt = 'optimizer' in pd_params
if has_opt:
model.helixfold.set_state_dict(pd_params['model'])
else:
model.helixfold.set_state_dict(pd_params)

if args.precision == "bf16" and args.amp_level == "O2":
raise NotImplementedError("bf16 O2 is not supported yet.")

print(f"============ Start Inference ============")

infer_times = args.infer_times
Expand Down
20 changes: 20 additions & 0 deletions apps/protein_folding/helixfold3/isolate_msa_issue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from helixfold.data.pipeline_parallel import make_msa_features
from helixfold.data import parsers
import os

def get_text(p):
    """Return the entire contents of the text file at path *p* as a string."""
    with open(p, 'r') as handle:
        contents = handle.read()
    return contents

def try_creating_msa_features(folder):
    """Build MSA features from the Stockholm hit files found in *folder*.

    Parses the three expected alignment files (``uniref90_hits.sto``,
    ``mgnify_hits.sto``, ``small_bfd_hits.sto``) and feeds them to
    ``make_msa_features`` in the same order the pipeline uses
    (uniref90, small_bfd, mgnify).

    Returns:
        The feature dict produced by ``make_msa_features``.  (Previously the
        result was computed and discarded, so the caller's ``msa_features``
        was always ``None``.)

    Raises:
        Whatever ``parsers.parse_stockholm`` / ``make_msa_features`` raise on
        a malformed MSA — raising here is the point of this repro script.
    """
    uniref90_msa = parsers.parse_stockholm(get_text(os.path.join(folder, 'uniref90_hits.sto')))
    mgnify_msa = parsers.parse_stockholm(get_text(os.path.join(folder, 'mgnify_hits.sto')))
    bfd_msa = parsers.parse_stockholm(get_text(os.path.join(folder, 'small_bfd_hits.sto')))
    return make_msa_features((uniref90_msa, bfd_msa, mgnify_msa))

# Debug entry point: reproduce the internal-MSA failure (issue #366) by
# running feature creation on a known-failing set of Stockholm files.
# `working_posebuster_folder` is a known-good control set — swap it into the
# call below to confirm the pipeline succeeds on it.
if __name__ == '__main__':
    # NOTE(review): hard-coded absolute paths — these only resolve on the
    # author's machine; point them at local copies of the .sto folders.
    failed_posebuster_folder = "/home/anindit/deep-affinity/experimental/users/anindit/posebuster_5SAK_ZRY/sto"
    working_posebuster_folder = "/home/anindit/paddle-helix-fork/apps/protein_folding/helixfold3/output/posebuster_5SAK_ZRY/msas/protein_A/A-HF3"
    # The call exercises make_msa_features on the failing folder; an
    # exception raised here reproduces the issue being isolated.
    msa_features = try_creating_msa_features(failed_posebuster_folder)


16 changes: 16 additions & 0 deletions apps/protein_folding/helixfold3/run_all_internal_xtal.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash

# Run HelixFold3 inference over every regular file in the internal xtal
# input directory by invoking ./run_infer.sh once per input file.

DIRECTORY="data/internal_xtal_inputs"

# Abort early if the expected input directory is missing.
if [ ! -d "$DIRECTORY" ]; then
    echo "Error: $DIRECTORY is not a valid directory."
    exit 1
fi

# Fixed: previous message said "Files in ...:" but nothing was listed —
# the loop below runs inference, it does not print a listing.
echo "Running inference for each file in $DIRECTORY:"
for FILE in "$DIRECTORY"/*; do
    # Skip subdirectories and anything that is not a regular file.
    if [ -f "$FILE" ]; then
        ./run_infer.sh "$FILE"
    fi
done
18 changes: 6 additions & 12 deletions apps/protein_folding/helixfold3/run_infer.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
#!/bin/bash

PYTHON_BIN="/usr/bin/python3" # changes to your python
ENV_BIN="/root/miniconda3/bin" # change to your env
MAXIT_SRC="PATH/TO/MAXIT/SRC" # changes to your MAXIT
export OBABEL_BIN="PATH/TO/OBABEL/BIN" # changes to your openbabel
PYTHON_BIN="/home/anindit/.conda/envs/helixfold/bin/python" # changes to your python
ENV_BIN="/home/anindit/.conda/envs/helixfold/bin" # change to your env
DATA_DIR="./data"
export PATH="$MAXIT_SRC/bin:$PATH"
export OBABEL_BIN="/opt/schrodinger2024-3/utilities/obabel"

CUDA_VISIBLE_DEVICES=0 "$PYTHON_BIN" inference.py \
--maxit_binary "$MAXIT_SRC/bin/maxit" \
--jackhmmer_binary_path "$ENV_BIN/jackhmmer" \
--hhblits_binary_path "$ENV_BIN/hhblits" \
--hhsearch_binary_path "$ENV_BIN/hhsearch" \
Expand All @@ -17,10 +14,7 @@ CUDA_VISIBLE_DEVICES=0 "$PYTHON_BIN" inference.py \
--hmmbuild_binary_path "$ENV_BIN/hmmbuild" \
--nhmmer_binary_path "$ENV_BIN/nhmmer" \
--preset='reduced_dbs' \
--bfd_database_path "$DATA_DIR/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt" \
--small_bfd_database_path "$DATA_DIR/small_bfd/bfd-first_non_consensus_sequences.fasta" \
--bfd_database_path "$DATA_DIR/small_bfd/bfd-first_non_consensus_sequences.fasta" \
--uniclust30_database_path "$DATA_DIR/uniclust30/uniclust30_2018_08/uniclust30_2018_08" \
--uniprot_database_path "$DATA_DIR/uniprot/uniprot.fasta" \
--pdb_seqres_database_path "$DATA_DIR/pdb_seqres/pdb_seqres.txt" \
--uniref90_database_path "$DATA_DIR/uniref90/uniref90.fasta" \
Expand All @@ -30,10 +24,10 @@ CUDA_VISIBLE_DEVICES=0 "$PYTHON_BIN" inference.py \
--ccd_preprocessed_path "$DATA_DIR/ccd_preprocessed_etkdg.pkl.gz" \
--rfam_database_path "$DATA_DIR/Rfam-14.9_rep_seq.fasta" \
--max_template_date=2020-05-14 \
--input_json data/demo_6zcy.json \
--input_json $1\
--output_dir ./output \
--model_name allatom_demo \
--init_model init_models/HelixFold3-240814.pdparams \
--infer_times 1 \
--infer_times 5 \
--diff_batch_size 1 \
--precision "fp32"
--precision "fp32"