From b7987f9389c159075bc6a253cf6f8e3b6f3deff1 Mon Sep 17 00:00:00 2001 From: "Leandro A. Bugnon" Date: Mon, 30 Dec 2024 17:07:30 -0300 Subject: [PATCH] Fixes to process cif files and use them in training --- scripts/process/mmcif.py | 3 ++- scripts/process/rcsb.py | 9 ++++++--- src/boltz/data/module/training.py | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/scripts/process/mmcif.py b/scripts/process/mmcif.py index 19bf147..7443358 100644 --- a/scripts/process/mmcif.py +++ b/scripts/process/mmcif.py @@ -966,6 +966,7 @@ def parse_mmcif( # noqa: C901, PLR0915, PLR0912 entity=entity.name, residues=residues, type=const.chain_type_ids["NONPOLYMER"], + sequence=None ) ) @@ -1119,4 +1120,4 @@ def parse_mmcif( # noqa: C901, PLR0915, PLR0912 mask=mask, ) - return ParsedStructure(data=data, info=info) + return ParsedStructure(data=data, info=info, covalents=[]) diff --git a/scripts/process/rcsb.py b/scripts/process/rcsb.py index 7443f3a..2886ec0 100644 --- a/scripts/process/rcsb.py +++ b/scripts/process/rcsb.py @@ -92,15 +92,18 @@ def finalize(outdir: Path) -> None: failed_count = 0 records = [] for record in records_dir.iterdir(): - path = records_dir / record + path = record try: with path.open("r") as f: records.append(json.load(f)) except: # noqa: E722 failed_count += 1 print(f"Failed to parse {record}") # noqa: T201 - print(f"Failed to parse {failed_count} entries)") # noqa: T201 - + if failed_count > 0: + print(f"Failed to parse {failed_count} entries.") # noqa: T201 + else: + print("All entries parsed successfully.") + # Save manifest outpath = outdir / "manifest.json" with outpath.open("w") as f: diff --git a/src/boltz/data/module/training.py b/src/boltz/data/module/training.py index 2616d2e..c96beee 100644 --- a/src/boltz/data/module/training.py +++ b/src/boltz/data/module/training.py @@ -115,7 +115,7 @@ def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input: for chain in record.chains: msa_id = chain.msa_id # Load the MSA for this chain, if any - if msa_id != -1: + if msa_id != -1 and msa_id != "": msa = np.load(msa_dir / f"{msa_id}.npz") msas[chain.chain_id] = MSA(**msa)