-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from alisiafadini/tjlane/tv-test
add simple tests to TV denoise implementation
- Loading branch information
Showing
12 changed files
with
646 additions
and
134 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
*.ccp4 | ||
|
||
# C extensions | ||
*.so | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,4 @@ | ||
TV_LAMBDA_RANGE: tuple[float, float] = (1e-8, 1.0) | ||
TV_LAMBDA_RANGE: tuple[float, float] = (1e-3, 1.0) | ||
TV_STOP_TOLERANCE: float = 0.00000005 | ||
TV_MAX_NUM_ITER: int = 50 | ||
TV_MAP_SAMPLING: int = 3 | ||
TV_AMPLITUDE_LABEL: str = "dFtv" | ||
TV_PHASE_LABEL: str = "dPHItv" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,85 +1,192 @@ | ||
from dataclasses import dataclass | ||
from typing import Literal, Sequence, overload | ||
|
||
import numpy as np | ||
import gemmi as gm | ||
import reciprocalspaceship as rs | ||
|
||
from skimage.restoration import denoise_tv_chambolle | ||
from scipy.optimize import minimize_scalar | ||
from skimage.restoration import denoise_tv_chambolle | ||
|
||
from .utils import compute_map_from_coefficients, compute_coefficients_from_map | ||
from .settings import ( | ||
TV_LAMBDA_RANGE, | ||
TV_STOP_TOLERANCE, | ||
TV_MAP_SAMPLING, | ||
TV_MAX_NUM_ITER, | ||
TV_AMPLITUDE_LABEL, | ||
TV_PHASE_LABEL, | ||
TV_STOP_TOLERANCE, | ||
) | ||
from .utils import ( | ||
compute_coefficients_from_map, | ||
compute_map_from_coefficients, | ||
numpy_array_to_map, | ||
resolution_limits, | ||
) | ||
from .validate import negentropy | ||
|
||
|
||
@dataclass | ||
class TvDenoiseResult: | ||
optimal_lambda: float | ||
optimal_negentropy: float | ||
map_sampling_used_for_tv: float | ||
|
||
|
||
def _tv_denoise_array(*, map_as_array: np.ndarray, weight: float) -> np.ndarray: | ||
"""Closure convienence function to generate more readable code.""" | ||
denoised_map = denoise_tv_chambolle( | ||
map_as_array, | ||
weight=weight, | ||
eps=TV_STOP_TOLERANCE, | ||
max_num_iter=TV_MAX_NUM_ITER, | ||
) | ||
return denoised_map | ||
|
||
|
||
@overload | ||
def tv_denoise_difference_map( | ||
difference_map_coefficients: rs.DataSet, | ||
full_output: Literal[False], | ||
difference_map_amplitude_column: str = "DF", | ||
difference_map_phase_column: str = "PHIC", | ||
) -> tuple[rs.DataSet, float]: | ||
""" | ||
Returns: | ||
rs.Dataset: denoised dataset with new columns `DFtv`, `DPHItv` | ||
""" | ||
lambda_values_to_scan: Sequence[float] | None = None, | ||
) -> rs.DataSet: ... | ||
|
||
|
||
@overload | ||
def tv_denoise_difference_map( | ||
difference_map_coefficients: rs.DataSet, | ||
full_output: Literal[True], | ||
difference_map_amplitude_column: str = "DF", | ||
difference_map_phase_column: str = "PHIC", | ||
lambda_values_to_scan: Sequence[float] | None = None, | ||
) -> tuple[rs.DataSet, TvDenoiseResult]: ... | ||
|
||
|
||
def tv_denoise_difference_map( | ||
difference_map_coefficients: rs.DataSet, | ||
full_output: bool = False, | ||
difference_map_amplitude_column: str = "DF", | ||
difference_map_phase_column: str = "PHIC", | ||
lambda_values_to_scan: Sequence[float] | np.ndarray | None = None, | ||
) -> rs.DataSet | tuple[rs.DataSet, TvDenoiseResult]: | ||
"""Single-pass TV denoising of a difference map. | ||
Automatically selects the optimal level of regularization (the TV lambda parameter) by | ||
maximizing the negentropy of the denoised map. Two modes can be used to dictate which | ||
candidate values of lambda are assessed: | ||
1. By default (`lambda_values_to_scan=None`), the golden-section search algorithm selects | ||
a lambda value according to the bounds and convergence criteria set in meteor.settings. | ||
2. Alternatively, an explicit list of lambda values to assess can be provided using | ||
`lambda_values_to_scan`. | ||
Parameters | ||
---------- | ||
difference_map_coefficients : rs.DataSet | ||
The input dataset containing the difference map coefficients (amplitude and phase) | ||
that will be used to compute the difference map. | ||
full_output : bool, optional | ||
If `True`, the function returns both the denoised map coefficients and a `TvDenoiseResult` | ||
object containing the optimal lambda and the associated negentropy. If `False`, only | ||
the denoised map coefficients are returned. Default is `False`. | ||
difference_map_amplitude_column : str, optional | ||
The column name in `difference_map_coefficients` that contains the amplitude values for | ||
the difference map. Default is "DF". | ||
difference_map_phase_column : str, optional | ||
The column name in `difference_map_coefficients` that contains the phase values for the | ||
difference map. Default is "PHIC". | ||
lambda_values_to_scan : Sequence[float] | None, optional | ||
A sequence of lambda values to explicitly scan for determining the optimal value. If | ||
`None`, the function uses the golden-section search method to determine the optimal | ||
lambda. Default is `None`. | ||
Returns | ||
------- | ||
rs.DataSet | tuple[rs.DataSet, TvDenoiseResult] | ||
If `full_output` is `False`, returns a `rs.DataSet`, the denoised map coefficients. | ||
If `full_output` is `True`, returns a tuple containing: | ||
- `rs.DataSet`: The denoised map coefficients. | ||
- `TvDenoiseResult`: An object w/ the optimal lambda and the corresponding negentropy. | ||
Raises | ||
------ | ||
AssertionError | ||
If the golden-section search fails to find an optimal lambda. | ||
Notes | ||
----- | ||
- The function is designed to maximize the negentropy of the denoised map, which is a | ||
measure of the map's "randomness." | ||
Higher negentropy generally corresponds to a more informative and less noisy map. | ||
- The golden-section search is a robust method for optimizing unimodal functions, | ||
particularly suited for scenarios where | ||
an explicit list of candidate values is not provided. | ||
Example | ||
------- | ||
>>> coefficients = rs.read_mtz("./path/to/difference_map.mtz") # load dataset | ||
>>> denoised_map, result = tv_denoise_difference_map(coefficients, full_output=True) | ||
>>> print(f"Optimal Lambda: {result.optimal_lambda}, Negentropy: {result.optimal_negentropy}") | ||
""" | ||
difference_map = compute_map_from_coefficients( | ||
map_coefficients=difference_map_coefficients, | ||
amplitude_label=difference_map_amplitude_column, | ||
phase_label=difference_map_phase_column, | ||
map_sampling=TV_MAP_SAMPLING, | ||
) | ||
difference_map_as_array = np.array(difference_map.grid) | ||
|
||
def negentropy_objective(tv_lambda: float): | ||
return denoise_tv_chambolle( | ||
np.array(difference_map.grid), | ||
eps=TV_STOP_TOLERANCE, | ||
weight=tv_lambda, | ||
max_num_iter=TV_MAX_NUM_ITER, | ||
denoised_map = _tv_denoise_array(map_as_array=difference_map_as_array, weight=tv_lambda) | ||
return -1.0 * negentropy(denoised_map.flatten()) | ||
|
||
optimal_lambda: float | ||
|
||
# scan a specific set of lambda values and find the best one | ||
if lambda_values_to_scan is not None: | ||
# use no denoising as the default to beat | ||
optimal_lambda = 0.0 # initialization | ||
highest_negentropy = negentropy(difference_map_as_array.flatten()) | ||
|
||
for tv_lambda in lambda_values_to_scan: | ||
trial_negentropy = -1.0 * negentropy_objective(tv_lambda) | ||
if trial_negentropy > highest_negentropy: | ||
optimal_lambda = tv_lambda | ||
highest_negentropy = trial_negentropy | ||
|
||
# use golden ratio optimization to pick an optimal lambda | ||
else: | ||
optimizer_result = minimize_scalar( | ||
negentropy_objective, bracket=TV_LAMBDA_RANGE, method="golden" | ||
) | ||
|
||
optimizer_result = minimize_scalar( | ||
negentropy_objective, bracket=TV_LAMBDA_RANGE, method="golden" | ||
assert optimizer_result.success, "Golden minimization failed to find optimal TV lambda" | ||
optimal_lambda = optimizer_result.x | ||
highest_negentropy = negentropy_objective(optimal_lambda) | ||
|
||
# denoise using the optimized parameters and convert to an rs.DataSet | ||
final_map = _tv_denoise_array(map_as_array=difference_map_as_array, weight=optimal_lambda) | ||
final_map_as_ccp4 = numpy_array_to_map( | ||
final_map, | ||
spacegroup=difference_map_coefficients.spacegroup, | ||
cell=difference_map_coefficients.cell, | ||
) | ||
assert optimizer_result.success | ||
|
||
optimal_lambda: float = optimizer_result.x | ||
|
||
final_map_array: np.ndarray = denoise_tv_chambolle( | ||
np.array(difference_map.grid), | ||
eps=TV_STOP_TOLERANCE, | ||
weight=optimal_lambda, | ||
max_num_iter=TV_MAX_NUM_ITER, | ||
) | ||
|
||
# TODO: we may be able to simplify the code by going directly from a numpy | ||
# array to rs.DataSet here -- right now, we go through gemmi format | ||
|
||
ccp4_map = gm.Ccp4Map() | ||
|
||
ccp4_map.grid = gm.FloatGrid(final_map_array) | ||
ccp4_map.grid.set_unit_cell(gm.UnitCell(*difference_map_coefficients.cell)) | ||
ccp4_map.grid.set_size(difference_map_coefficients.get_reciprocal_grid_size()) | ||
ccp4_map.grid.spacegroup = gm.find_spacegroup_by_name( | ||
difference_map_coefficients.space_group | ||
) | ||
ccp4_map.grid.symmetrize_max() | ||
ccp4_map.update_ccp4_header() | ||
|
||
high_resolution_limit = np.min(difference_map_coefficients.compute_dHKL()) | ||
denoised_dataset = compute_coefficients_from_map( | ||
ccp4_map=ccp4_map, | ||
high_resolution_limit=high_resolution_limit, | ||
amplitude_label=TV_AMPLITUDE_LABEL, | ||
phase_label=TV_PHASE_LABEL, | ||
_, dmin = resolution_limits(difference_map_coefficients) | ||
final_map_coefficients = compute_coefficients_from_map( | ||
ccp4_map=final_map_as_ccp4, | ||
high_resolution_limit=dmin, | ||
amplitude_label=difference_map_amplitude_column, | ||
phase_label=difference_map_phase_column, | ||
) | ||
|
||
# ^^^ replace this with something better! | ||
|
||
return denoised_dataset | ||
if full_output: | ||
tv_result = TvDenoiseResult( | ||
optimal_lambda=optimal_lambda, | ||
optimal_negentropy=highest_negentropy, | ||
map_sampling_used_for_tv=TV_MAP_SAMPLING, | ||
) | ||
return final_map_coefficients, tv_result | ||
else: | ||
return final_map_coefficients | ||
|
||
|
||
def iterative_tv_phase_retrieval(): ... | ||
def iterative_tv_phase_retrieval(): | ||
raise NotImplementedError() |
Oops, something went wrong.