Skip to content

Commit

Permalink
Merge pull request #196 from thorstenwagner/2d_heatmap_selection
Browse files Browse the repository at this point in the history
Use 2D Histogram for visualization
  • Loading branch information
lazigu authored Mar 21, 2023
2 parents bb9e370 + bb661a7 commit e2ac783
Show file tree
Hide file tree
Showing 4 changed files with 372 additions and 57 deletions.
84 changes: 83 additions & 1 deletion napari_clusters_plotter/_Qt_code.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
import typing
from pathlib import Path as PathL

import numpy as np
import numpy.typing
import pandas as pd
from magicgui.widgets import create_widget
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
Expand Down Expand Up @@ -351,6 +354,31 @@ def algorithm_choice(name: str, value, options: dict, label: str):
return container, choice_list


class SelectFrom2DHistogram:
def __init__(self, parent, ax, full_data):
self.parent = parent
self.ax = ax
self.canvas = ax.figure.canvas
self.xys = full_data

self.lasso = LassoSelector(ax, onselect=self.onselect)
self.ind = []
self.ind_mask = []

def onselect(self, verts):
path = Path(verts)

self.ind_mask = path.contains_points(self.xys)
self.ind = np.nonzero(self.ind_mask)[0]

if self.parent.manual_clustering_method is not None:
self.parent.manual_clustering_method(self.ind_mask)

def disconnect(self):
self.lasso.disconnect_events()
self.canvas.draw_idle()


# Class below was based upon matplotlib lasso selection example:
# https://matplotlib.org/stable/gallery/widgets/lasso_selector_demo_sgskip.html
class SelectFromCollection:
Expand Down Expand Up @@ -398,8 +426,9 @@ def __init__(self, parent, ax, collection, alpha_other=0.3):

def onselect(self, verts):
path = Path(verts)
self.ind = np.nonzero(path.contains_points(self.xys))[0]
self.ind_mask = path.contains_points(self.xys)
self.ind = np.nonzero(self.ind_mask)[0]

self.fc[:, -1] = self.alpha_other
self.fc[self.ind, -1] = 1
self.collection.set_facecolors(self.fc)
Expand All @@ -422,6 +451,7 @@ def __init__(self, parent=None, width=7, height=4, manual_clustering_method=None
self.manual_clustering_method = manual_clustering_method

self.axes = self.fig.add_subplot(111)
self.histogram = None

self.match_napari_layout()

Expand Down Expand Up @@ -462,6 +492,58 @@ def reset(self):
self.axes.clear()
self.is_pressed = None

def make_2d_histogram(
self,
data_x: "numpy.typing.ArrayLike",
data_y: "numpy.typing.ArrayLike",
colors: "typing.List[str]",
bin_number: int = 400,
log_scale: bool = False,
):
self.colors = colors
norm = None
if log_scale:
norm = "log"
h, xedges, yedges = np.histogram2d(data_x, data_y, bins=bin_number)
self.axes.imshow(
h.T,
extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
origin="lower",
cmap="magma",
aspect="auto",
norm=norm,
)
self.axes.set_xlim(xedges[0], xedges[-1])
self.axes.set_ylim(yedges[0], yedges[-1])
self.histogram = (h, xedges, yedges)

full_data = pd.concat([data_x, data_y], axis=1)
self.selector.disconnect()
self.selector = SelectFrom2DHistogram(self, self.axes, full_data)
self.axes.figure.canvas.draw_idle()

def make_scatter_plot(
self,
data_x: "numpy.typing.ArrayLike",
data_y: "numpy.typing.ArrayLike",
colors: "typing.List[str]",
sizes: "typing.List[float]",
alpha: "typing.List[float]",
):
self.pts = self.axes.scatter(
data_x,
data_y,
c=colors,
s=sizes,
alpha=alpha,
)
self.selector.disconnect()
self.selector = SelectFromCollection(
self,
self.axes,
self.pts,
)

def match_napari_layout(self):
"""Change background and axes colors to match napari layout"""
# changing color of axes background to napari main window color
Expand Down
Loading

0 comments on commit e2ac783

Please sign in to comment.