Skip to content

Commit

Permalink
need the highres
Browse files Browse the repository at this point in the history
  • Loading branch information
tjlane committed Aug 21, 2024
1 parent 9856edd commit 4542b00
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 78 deletions.
35 changes: 17 additions & 18 deletions meteor/tv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import gemmi as gm
import gemmi

Check failure on line 2 in meteor/tv.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F401)

meteor/tv.py:2:8: F401 `gemmi` imported but unused
import reciprocalspaceship as rs

from skimage.restoration import denoise_tv_chambolle
Expand All @@ -12,6 +12,7 @@
compute_map_from_coefficients,
compute_coefficients_from_map,
resolution_limits,
numpy_array_to_map,
)
from .settings import (
TV_LAMBDA_RANGE,
Expand All @@ -23,12 +24,12 @@
)


def _tv_denoise_ccp4_map(*, map: gm.Ccp4Map, weight: float) -> np.ndarray:
def _tv_denoise_array(*, map_as_array: np.ndarray, weight: float) -> np.ndarray:
"""
Closure convienence function to generate more readable code.
"""
denoised_map = denoise_tv_chambolle(
np.array(map.grid),
map_as_array,
weight=weight,
eps=TV_STOP_TOLERANCE,
max_num_iter=TV_MAX_NUM_ITER,
Expand Down Expand Up @@ -59,12 +60,13 @@ def tv_denoise_difference_map(
map_sampling=TV_MAP_SAMPLING,
)

difference_map_as_array = np.array(difference_map.grid)

def negentropy_objective(tv_lambda: float):
denoised_map = _tv_denoise_ccp4_map(map=difference_map, weight=tv_lambda)
denoised_map = _tv_denoise_array(map_as_array=difference_map_as_array, weight=tv_lambda)
return negentropy(denoised_map.flatten())

optimal_lambda: float

if lambda_values_to_scan:
highest_negentropy = -1e8
for tv_lambda in lambda_values_to_scan:
Expand All @@ -81,25 +83,22 @@ def negentropy_objective(tv_lambda: float):
), "Golden minimization failed to find optimal TV lambda"
optimal_lambda = optimizer_result.x

final_map_array = _tv_denoise_ccp4_map(map=difference_map, weight=optimal_lambda)

# TODO: verify correctness
final_map = _tv_denoise_array(map_as_array=difference_map_as_array, weight=optimal_lambda)
final_map = numpy_array_to_map(
final_map,
spacegroup=difference_map_coefficients.spacegroup,
cell=difference_map_coefficients.cell
)

_, high_resolution_limit = resolution_limits(difference_map_coefficients)
_, dmin = resolution_limits(difference_map_coefficients)
final_map_coefficients = compute_coefficients_from_map(
map=final_map_array,
high_resolution_limit=high_resolution_limit,
ccp4_map=final_map,
high_resolution_limit=dmin,
amplitude_label=TV_AMPLITUDE_LABEL,
phase_label=TV_PHASE_LABEL,
)

# TODO: need to be sure HKLs line up
difference_map_coefficients[[TV_AMPLITUDE_LABEL]] = np.abs(final_map_coefficients)
difference_map_coefficients[[TV_PHASE_LABEL]] = np.angle(
final_map_coefficients, deg=True
)

return difference_map_coefficients
return final_map_coefficients


def iterative_tv_phase_retrieval(): ...
59 changes: 17 additions & 42 deletions meteor/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import gemmi as gm
import gemmi
import reciprocalspaceship as rs
from typing import overload, Literal

Expand Down Expand Up @@ -64,15 +64,23 @@ def canonicalize_amplitudes(
return None


def numpy_array_to_map(array: np.ndarray, *, spacegroup: str | int, cell: tuple[float, float, float, float, float, float]) -> gemmi.Ccp4Map:
ccp4_map = gemmi.Ccp4Map()
ccp4_map.grid = gemmi.FloatGrid(array, dtype=array.dtype)
ccp4_map.grid.unit_cell.set(*cell)
ccp4_map.grid.spacegroup = gemmi.SpaceGroup(spacegroup)
return ccp4_map


def compute_map_from_coefficients(
*,
map_coefficients: rs.DataSet,
amplitude_label: str,
phase_label: str,
map_sampling: int,
) -> gm.Ccp4Map:
) -> gemmi.Ccp4Map:
map_coefficients_gemmi_format = map_coefficients.to_gemmi()
ccp4_map = gm.Ccp4Map()
ccp4_map = gemmi.Ccp4Map()
ccp4_map.grid = map_coefficients_gemmi_format.transform_f_phi_to_map(
amplitude_label, phase_label, sample_rate=map_sampling
)
Expand All @@ -83,59 +91,26 @@ def compute_map_from_coefficients(

def compute_coefficients_from_map(
*,
map: np.ndarray | gm.Ccp4Map,
high_resolution_limit: float,
amplitude_label: str,
phase_label: str,
) -> rs.DataSet:
if isinstance(map, np.ndarray):
return _compute_coefficients_from_numpy_array(
map_array=map,
high_resolution_limit=high_resolution_limit,
amplitude_label=amplitude_label,
phase_label=phase_label,
)
elif isinstance(map, gm.Ccp4Map):
return _compute_coefficients_from_ccp4_map(
ccp4_map=map,
high_resolution_limit=high_resolution_limit,
amplitude_label=amplitude_label,
phase_label=phase_label,
)
else:
raise TypeError(f"invalid type {type(map)} for `map`")


def _compute_coefficients_from_numpy_array(
*,
map_array: np.ndarray,
high_resolution_limit: float,
amplitude_label: str,
phase_label: str,
) -> rs.DataSet: ...


def _compute_coefficients_from_ccp4_map(
*,
ccp4_map: gm.Ccp4Map,
ccp4_map: gemmi.Ccp4Map,
high_resolution_limit: float,
amplitude_label: str,
phase_label: str,
) -> rs.DataSet:
# to ensure we include the final shell of reflections, add a small buffer to the resolution
high_resolution_buffer = 0.05

gemmi_structure_factors = gm.transform_map_to_f_phi(ccp4_map.grid, half_l=False)

high_resolution_buffer = 1e-8
gemmi_structure_factors = gemmi.transform_map_to_f_phi(ccp4_map.grid, half_l=False)
data = gemmi_structure_factors.prepare_asu_data(
dmin=high_resolution_limit - high_resolution_buffer, with_sys_abs=True
)

mtz = gm.Mtz(with_base=True)
mtz = gemmi.Mtz(with_base=True)
mtz.spacegroup = gemmi_structure_factors.spacegroup
mtz.set_cell_for_all(gemmi_structure_factors.unit_cell)
mtz.add_dataset("FromMap")
mtz.add_column(amplitude_label, "F")
mtz.add_column(phase_label, "P")

mtz.set_data(data)
mtz.switch_to_asu_hkl()

Expand Down
18 changes: 13 additions & 5 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pytest import fixture
import reciprocalspaceship as rs
import numpy as np
import gemmi as gm
import gemmi
from meteor.utils import canonicalize_amplitudes


@fixture
Expand All @@ -11,8 +12,8 @@ def random_intensities() -> rs.DataSet:
"""

params = (10.0, 10.0, 10.0, 90.0, 90.0, 90.0)
cell = gm.UnitCell(*params)
sg_1 = gm.SpaceGroup(1)
cell = gemmi.UnitCell(*params)
sg_1 = gemmi.SpaceGroup(1)
Hall = rs.utils.generate_reciprocal_asu(cell, sg_1, 1.0, anomalous=False)

H, K, L = Hall.T
Expand All @@ -38,8 +39,8 @@ def flat_difference_map() -> rs.DataSet:
"""

params = (10.0, 10.0, 10.0, 90.0, 90.0, 90.0)
cell = gm.UnitCell(*params)
sg_1 = gm.SpaceGroup(1)
cell = gemmi.UnitCell(*params)
sg_1 = gemmi.SpaceGroup(1)
Hall = rs.utils.generate_reciprocal_asu(cell, sg_1, 5.0, anomalous=False)

H, K, L = Hall.T
Expand All @@ -58,4 +59,11 @@ def flat_difference_map() -> rs.DataSet:
ds.set_index(["H", "K", "L"], inplace=True)
ds["DF"] = ds["DF"].astype("SFAmplitude")

canonicalize_amplitudes(
ds,
amplitude_label="DF",
phase_label="PHIC",
inplace=True,
)

return ds
19 changes: 6 additions & 13 deletions test/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from meteor import utils
import reciprocalspaceship as rs
import pytest
import gemmi as gm
import gemmi
import pandas as pd
import numpy as np

Expand Down Expand Up @@ -85,24 +85,17 @@ def test_compute_map_from_coefficients(flat_difference_map: rs.DataSet) -> None:
phase_label="PHIC",
map_sampling=1,
)
assert isinstance(map, gm.Ccp4Map)
assert map.grid.shape == (6, 6, 6)
assert isinstance(map, gemmi.Ccp4Map)
assert map.grid.shape == (6,6,6)


@pytest.mark.parametrize("map_sampling", [1, 2, 3, 5])
def test_map_round_trip_ccp4_format(
@pytest.mark.parametrize("map_sampling", [1, 2, 2.25, 3, 5])
def test_map_to_coefficients_round_trip(
map_sampling: int, flat_difference_map: rs.DataSet
) -> None:
amplitude_label = "DF"
phase_label = "PHIC"

utils.canonicalize_amplitudes(
flat_difference_map,
amplitude_label=amplitude_label,
phase_label=phase_label,
inplace=True,
)

map = utils.compute_map_from_coefficients(
map_coefficients=flat_difference_map,
amplitude_label=amplitude_label,
Expand All @@ -113,7 +106,7 @@ def test_map_round_trip_ccp4_format(
_, dmin = utils.resolution_limits(flat_difference_map)

output_coefficients = utils.compute_coefficients_from_map(
map=map,
ccp4_map=map,
high_resolution_limit=dmin,
amplitude_label=amplitude_label,
phase_label=phase_label,
Expand Down

0 comments on commit 4542b00

Please sign in to comment.