diff --git a/alphafold/model/all_atom.py b/alphafold/model/all_atom.py index c8ebe8b08..f2d1eca2d 100644 --- a/alphafold/model/all_atom.py +++ b/alphafold/model/all_atom.py @@ -694,7 +694,7 @@ def between_residue_bond_loss( ca_c_n_cos_angle = jnp.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1) gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0] - gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0] + gt_stddev = residue_constants.between_res_cos_angles_ca_c_n[1] ca_c_n_cos_angle_error = jnp.sqrt( 1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle)) ca_c_n_loss_per_residue = jax.nn.relu(