From 0cef75cda2da9d52700cdd2ed106f1249ddd9320 Mon Sep 17 00:00:00 2001 From: Guillaume Bouvignies Date: Wed, 29 May 2024 18:27:30 +0200 Subject: [PATCH] Freezing observed when changing the parameter `sw` in the module `pick_cest` (#237) Fixes #236 --- chemex/tools/pick_cest/buttons.py | 116 ++++++++++++++++++++---------- 1 file changed, 78 insertions(+), 38 deletions(-) diff --git a/chemex/tools/pick_cest/buttons.py b/chemex/tools/pick_cest/buttons.py index c4ea8859..b38f5b0f 100644 --- a/chemex/tools/pick_cest/buttons.py +++ b/chemex/tools/pick_cest/buttons.py @@ -26,54 +26,87 @@ class Buttons: + """A class to manage the interaction with Matplotlib plots for chemical shift analysis. + + Attributes: + fig (Figure): The Matplotlib figure object. + axis (Axes): The Matplotlib axes object. + data (dict): Data structure to hold curves for each spin system. + spin_systems (list): Sorted list of spin systems. + out (Path): Path to save output files. + spin_system (SpinSystem): Current spin system being analyzed. + curves (list): List of curves for the current spin system. + cs_a (dict): Dictionary to store chemical shift 'a' for each spin system. + cs_b (dict): Dictionary to store chemical shift 'b' for each spin system. + artists (list): List of Matplotlib artist objects. + sw (Optional[float]): Sweep width for the curves. + cursor (Cursor): Matplotlib cursor object for interaction. + index (int): Index to track the current spin system. + """ + def __init__( self, figure: Figure, axis: Axes, experiments: Experiments, path: Path, - sw: float | None, + sw: float | None = None, ) -> None: self.fig = figure self.axis = axis - self.data: dict[SpinSystem, list[Curve]] = {} - spin_systems: set[SpinSystem] = set() - - for experiment in experiments: - for profile in experiment: - spin_system = profile.spin_system - spin_systems.add(spin_system) - self.data.setdefault(spin_system, []).append(Curve(profile, sw)) - - self.spin_systems = sorted(spin_systems) + self.data: dict[SpinSystem, list[Curve]] = self._init_data(experiments, sw) + self.spin_systems = sorted(self.data.keys()) self.out = path - self.spin_system = SpinSystem(name="") self.curves: list[Curve] = [] self.cs_a: dict[SpinSystem, float | None] = {} self.cs_b: dict[SpinSystem, float | None] = {} self.artists: list[Artist] = [] self.sw = sw - self.cursor = Cursor(self.axis, horizOn=False, useblit=True) - self.index = -1 self.next() + @staticmethod + def _init_data( + experiments: Experiments, sw: float | None + ) -> dict[SpinSystem, list[Curve]]: + """Initialize data structure from experiments. + + Args: + experiments (Experiments): Container of experiments. + sw (Optional[float]): Sweep width for the curves. + + Returns: + dict: Data structure with spin systems and their corresponding curves. + """ + data = {} + for experiment in experiments: + for profile in experiment: + spin_system = profile.spin_system + if spin_system not in data: + data[spin_system] = [] + data[spin_system].append(Curve(profile, sw)) + return data + def _clear_artists(self) -> None: + """Remove all artist objects from the plot.""" while self.artists: self.artists.pop().remove() def _clear_axis(self) -> None: + """Clear the current axis and remove all artists.""" self._clear_artists() self.axis.clear() def _show_labels(self) -> None: + """Set the title and axis labels for the plot.""" self.axis.set_title(str(self.spin_system)) self.axis.set_xlabel(XLABELS[self.spin_system.nuclei["i"]]) self.axis.set_ylabel("$I/I_0$") def _plot_profiles(self) -> None: + """Plot the experimental profiles and their splines.""" if not self.curves: return xranges = np.concatenate([curve.get_xrange(self.sw) for curve in self.curves]) @@ -89,13 +122,20 @@ def _plot_profiles(self) -> None: self.fig.canvas.draw_idle() def _get_click_position(self, event: Event) -> float | None: - if not isinstance(event, LocationEvent): - return None - if event.inaxes != self.axis: - return None - return event.xdata + """Get the x-coordinate of a click event. + + Args: + event (Event): Matplotlib event object. + + Returns: + Optional[float]: The x-coordinate of the click, or None if invalid. + """ + if isinstance(event, LocationEvent) and event.inaxes == self.axis: + return event.xdata + return None def _add_line(self, position: float, state: Literal["a", "b"]) -> None: + """Add a vertical line and a text label to the plot.""" text_ = rf"$\varpi_{state}$ = {position:.3f} ppm" text = self.fig.text(0.82, TEXT_Y[state], text_) line = self.axis.axvline( @@ -107,13 +147,14 @@ def _add_line(self, position: float, state: Literal["a", "b"]) -> None: self.artists.extend([line, text]) def _add_text_dw(self, dw_ab: float) -> None: + """Add a text label for the chemical shift difference.""" text_ = rf"$\Delta\varpi_{{ab}}$ = {dw_ab:.3f} ppm" text = self.fig.text(0.82, 0.7, text_) self.artists.append(text) def _save(self) -> None: + """Save the chemical shift data to TOML files.""" self.out.mkdir(parents=True, exist_ok=True) - fname1 = self.out / "cs_a.toml" fname2 = self.out / "dw_ab.toml" @@ -137,7 +178,7 @@ def _save(self) -> None: file2.write(f"{str(name).upper():10s} = {dw_ab:8.3f}\n") def set_cs(self, event: Event) -> None: - """Set the chemical shift.""" + """Set the chemical shift based on a click event.""" xdata = self._get_click_position(event) if xdata is None: return @@ -153,6 +194,7 @@ def set_cs(self, event: Event) -> None: self._plot_lines() def _plot_lines(self) -> None: + """Plot the vertical lines for chemical shifts.""" key = self.spin_system cs_a = self.cs_a.get(key) cs_b = self.cs_b.get(key) @@ -169,47 +211,45 @@ def _plot_lines(self) -> None: self._add_text_dw(cs_b - cs_a) self._save() - self.fig.canvas.draw_idle() - self.fig.canvas.flush_events() def _plot(self, event: Event | None = None) -> None: + """Main plotting function to clear axis and plot profiles and lines.""" self._clear_axis() self._plot_lines() self._plot_profiles() self.fig.canvas.draw_idle() def _shift(self, step: int) -> None: - self.index += step - self.index %= len(self.spin_systems) + """Shift the current spin system index by a given step.""" + self.index = (self.index + step) % len(self.spin_systems) self.spin_system = self.spin_systems[self.index] self.curves = self.data[self.spin_system] - self._clear_axis() self._plot() def next(self, _event: Event | None = None) -> None: - """Go to next residue.""" - self._shift(+1) + """Go to the next residue.""" + self._shift(1) - def previous(self, _event: Event) -> None: - """Go to previous residue.""" + def previous(self, _event: Event | None = None) -> None: + """Go to the previous residue.""" self._shift(-1) def swap(self, event: Event) -> None: - """Swap peak peak positions for major/minor states.""" - name = self.spin_system - if self.cs_b[name] is not None: - self.cs_a[name], self.cs_b[name] = self.cs_b[name], self.cs_a[name] + """Swap peak positions for major/minor states.""" + key = self.spin_system + if self.cs_b[key] is not None: + self.cs_a[key], self.cs_b[key] = self.cs_b[key], self.cs_a[key] self._plot_lines() def clear(self, event: Event) -> None: - name = self.spin_system - self.cs_a[name], self.cs_b[name] = None, None + """Clear the chemical shifts for the current spin system.""" + key = self.spin_system + self.cs_a[key], self.cs_b[key] = None, None self._plot_lines() def set_sw(self, sw: float) -> None: + """Set the sweep width and update the plot.""" with contextlib.suppress(ValueError): self.sw = sw - - self._clear_axis() self._plot()