Skip to content

Commit

Permalink
Continue refactor for subplots
Browse files Browse the repository at this point in the history
  • Loading branch information
jimmymathews committed Apr 5, 2024
1 parent 0d981ac commit 1811a97
Showing 1 changed file with 90 additions and 101 deletions.
191 changes: 90 additions & 101 deletions analysis_replication/gnn_figure/graph_plugin_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from pandas import concat
from pandas import Series
import matplotlib.pyplot as plt
from matplotlib.pyplot import Axes
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize
from matplotlib.colors import SymLogNorm
from scipy.stats import fisher_exact # type: ignore
from attr import define
from cattrs import structure as cattrs_structure
Expand Down Expand Up @@ -250,100 +250,96 @@ def _get_attribute_order(specification: PlotSpecification) -> list[str]:
return attribute_order


def plot_scatter_heatmap(df: DataFrame,
ax: plt.Axes,
title: str | None,
title_side: str,
label_phenotypes: bool,
label_cohorts: bool,
cmap: mcolors.ListedColormap,
norm: Normalize | None = None,
s_factor: float = 200,
):
df = df.sort_values('cohort')
df = df.sort_index()

# Extract the 'cohort' column
s_cohort = df['cohort']

# Add a row of NaN values between each cohort
dfs = []
cohorts = df['cohort'].unique()
for i, cohort in enumerate(cohorts):
df_cohort = df[df['cohort'] == cohort]
# Skip adding a NaN row before the first cohort
if i != 0:
df_cohort = concat([DataFrame([np.repeat(np.nan, df_cohort.shape[1])],
columns=df_cohort.columns, index=['']), df_cohort])
dfs.append(df_cohort)
df = concat(dfs)
df.index.name = 'Specimen by cohort'
df.drop('cohort', axis=1, level=0, inplace=True)

# Transpose the DataFrame
df = df.transpose().astype(float)

# Separate the p_value and p_important
df_p_value = df.xs('p_value', axis=0, level=1)
df_p_important = df.xs('important_fraction', axis=0, level=1)

# Clip the p_value to the range [0, 0.05] and normalize it to the range [0, 1]
df_p_value_clipped = df_p_value.clip(upper=0.05)
df_p_value_normalized = 1 - df_p_value_clipped / 0.05

# Use SymLogNorm for the colormap
if norm is None:
norm = mcolors.SymLogNorm(linthresh=0.001, linscale=0.001,
vmin=df_p_important.min().min(),
vmax=df_p_important.max().max(),
)

# Create a meshgrid for the cell centers
x = np.arange(df_p_important.shape[1]) + 0.5
y = np.arange(df_p_important.shape[0]) + 0.5
X, Y = np.meshgrid(x, y)

# Flatten the data and the sizes
c = df_p_important.values.flatten()
s = df_p_value_normalized.values.flatten() * s_factor # Scale up the sizes for visibility

# Plot the scatter plot
ax.scatter(X.flatten(), Y.flatten(), c=c, s=s, cmap=cmap, norm=norm, edgecolor='black')
ax.set_aspect('equal')
# ax.set_aspect((df_p_important.shape[0] + len(cohorts) - 1) / df_p_important.shape[1]*1.5)

# Invert the y-axis to match the heatmap orientation
ax.set_xlim(0, df_p_important.shape[1])
ax.set_ylim(0, df_p_important.shape[0])
ax.invert_yaxis()

# Turn off x-tick labels
ax.tick_params(axis='x', length=0)

# Add text annotations to label the cohorts
if label_cohorts:
start = 0
for cohort in cohorts:
df_cohort = s_cohort[s_cohort == cohort]
ax.text(start, -0.25, cohort if isinstance(cohort, str) else f'Cohort {cohort}',
ha='left', va='center')
start += df_cohort.shape[0] + 1 # Add 1 to account for the NaN row
ax.set_xticks([])

# Add the phenotype labels
if label_phenotypes:
ax.set_yticks(np.arange(df_p_important.shape[0]) + 0.5)
ax.set_yticklabels(df_p_important.index, rotation=0)
else:
ax.set_yticks([])
ax.yaxis.set_ticks_position('none')

if title is not None:
class SubplotGenerator:
    """Draw one cohort-grouped "scatter heatmap" onto a matplotlib axes.

    Every cell of the grid is a single marker: its color encodes the
    ``important_fraction`` value and its size grows as the ``p_value`` drops
    below ``_P_VALUE_CEILING``.
    """

    # p-values at or above this ceiling are clipped to it and therefore get
    # marker size zero; a p-value of 0 gets the full ``s_factor`` size.
    _P_VALUE_CEILING = 0.05

    @classmethod
    def plot(
        cls,
        df: DataFrame,
        ax: Axes,
        title: str,
        title_side: str,
        label_vertically: bool,
        label_horizontally: bool,
        norm: Normalize,
        s_factor: float = 200,
    ) -> None:
        """Render `df` as a scatter heatmap on `ax`.

        Args:
            df: Frame with two-level columns. It must contain a ``'cohort'``
                column (level 0); the remaining columns are selected by their
                level-1 labels ``'p_value'`` and ``'important_fraction'``.
                Rows are grouped by cohort in order of each cohort's first
                appearance; no sorting is done here.
                Presumably rows are specimens and level-0 labels are
                phenotypes -- confirm against the caller.
            ax: Axes to draw on.
            title: Text for the subplot title.
            title_side: ``'bottom'`` puts the title in the x-axis label; any
                other value writes it rotated along the right edge.
            label_vertically: Whether to show the level-0 column labels as
                y tick labels.
            label_horizontally: Whether to write a text label at the start of
                each cohort's run of columns.
            norm: Normalization applied to the ``important_fraction`` colors.
            s_factor: Marker area given to a cell whose p-value is 0.
        """
        cohort_column = df['cohort']
        cohorts = cohort_column.unique()

        grid = cls._insert_cohort_spacers(df, cohorts)
        grid.index.name = 'Specimen by cohort'
        grid = grid.drop('cohort', axis=1, level=0)
        # After the transpose the original rows run along the x-axis.
        grid = grid.transpose().astype(float)

        p_values = grid.xs('p_value', axis=0, level=1)
        important = grid.xs('important_fraction', axis=0, level=1)

        # Clip the p-values to at most the ceiling, then map them so that the
        # ceiling becomes 0 and a p-value of 0 becomes 1.
        ceiling = cls._P_VALUE_CEILING
        size_weight = 1 - p_values.clip(upper=ceiling) / ceiling

        # One marker at the center of every grid cell.
        n_rows, n_columns = important.shape
        X, Y = np.meshgrid(np.arange(n_columns) + 0.5, np.arange(n_rows) + 0.5)

        ax.scatter(
            X.flatten(),
            Y.flatten(),
            c=important.values.flatten(),
            s=size_weight.values.flatten() * s_factor,
            cmap=cls._get_main_heatmap_colormap(),
            norm=norm,
            edgecolor='black',
        )
        ax.set_aspect('equal')

        # Invert the y-axis so the first row is drawn at the top, as in a heatmap.
        ax.set_xlim(0, n_columns)
        ax.set_ylim(0, n_rows)
        ax.invert_yaxis()

        ax.tick_params(axis='x', length=0)

        if label_horizontally:
            cls._label_cohorts(ax, cohort_column, cohorts)
        # NOTE(review): the x ticks are taken to be cleared unconditionally,
        # and likewise the y tick marks below -- confirm against the original
        # indentation.
        ax.set_xticks([])

        if label_vertically:
            ax.set_yticks(np.arange(n_rows) + 0.5)
            ax.set_yticklabels(important.index, rotation=0)
        else:
            ax.set_yticks([])
        ax.yaxis.set_ticks_position('none')

        if title_side == 'bottom':
            ax.set_xlabel(title)
        else:
            ax.text(1.021, .5, title, rotation=-90, ha='right', va='center', transform=ax.transAxes)

    @staticmethod
    def _insert_cohort_spacers(df: DataFrame, cohorts) -> DataFrame:
        """Stack the rows of `df` cohort by cohort, separated by all-NaN rows.

        Every cohort except the first is preceded by one spacer row (index
        label ``''``) so that cohorts appear visually separated in the plot.
        """
        pieces = []
        for position, cohort in enumerate(cohorts):
            members = df[df['cohort'] == cohort]
            if position != 0:
                spacer = DataFrame(
                    [np.repeat(np.nan, members.shape[1])],
                    columns=members.columns,
                    index=[''],
                )
                members = concat([spacer, members])
            pieces.append(members)
        return concat(pieces)

    @staticmethod
    def _label_cohorts(ax: Axes, cohort_column, cohorts) -> None:
        """Write one text label at the left edge of each cohort's columns."""
        start = 0
        for cohort in cohorts:
            label = cohort if isinstance(cohort, str) else f'Cohort {cohort}'
            ax.text(start, -0.25, label, ha='left', va='center')
            # Advance past this cohort's columns plus the NaN spacer column.
            start += cohort_column[cohort_column == cohort].shape[0] + 1

    @staticmethod
    def _get_main_heatmap_colormap():
        """Return the white-to-red colormap used for the marker colors."""
        colors = ['white', 'red']
        cmap = mcolors.LinearSegmentedColormap.from_list('', colors)
        # Values below the normalization range are drawn white as well.
        cmap.set_under(color='white')
        return cmap


@define
class SubplotSpecification:
Expand Down Expand Up @@ -428,6 +424,7 @@ def _update_cohorts(
dfs = tuple(df.copy() for df in dfs)
for _, df in enumerate(dfs):
df['cohort'] = df['cohort'].map(cohort_map)
dfs = tuple(df.sort_values('cohort').sort_index() for df in dfs)
return dfs

def _gather_subplot_cases(
Expand All @@ -439,17 +436,17 @@ def _gather_subplot_cases(
return subplot_specification, indicators, zip(dfs, self.get_specification().plugins)

def _generate_plot(self, dfs: tuple[DataFrame, ...]) -> None:
plt.rcParams['font.size'] = 14
norm = self._generate_normalization(dfs)
subplot_specification, indicators, cases = self._gather_subplot_cases(dfs)
fig, axs = plt.subplots(
*subplot_specification.grid_dimensions,
figsize=self.get_specification().figure_size,
)
title_location = subplot_specification.title_location
cmap = self._get_main_heatmap_colormap()
for i, ((df, plugin), ax) in enumerate(zip(cases, axs)):
plot_scatter_heatmap(
df, ax, plugin, title_location, indicators[0][i], indicators[1][i], cmap, norm,
SubplotGenerator.plot(
df, ax, plugin, title_location, indicators[0][i], indicators[1][i], norm,
)
fig.suptitle(self.get_specification().study)
plt.tight_layout()
Expand All @@ -460,13 +457,6 @@ def _export(self) -> None:
if self.show_also:
plt.show()

@staticmethod
def _get_main_heatmap_colormap():
    """Build the white-to-red colormap used for the main heatmap markers."""
    heatmap_cmap = mcolors.LinearSegmentedColormap.from_list('', ['white', 'red'])
    # Out-of-range low values are rendered white too.
    heatmap_cmap.set_under(color='white')
    return heatmap_cmap

@staticmethod
def _generate_normalization(dfs: tuple[DataFrame, ...]) -> Normalize:
dfs_values_only = tuple(df.drop('cohort', axis=1, level=0) for df in dfs)
Expand All @@ -482,6 +472,5 @@ def _generate_normalization(dfs: tuple[DataFrame, ...]) -> Normalize:
add('output_directory', nargs='?', type=str, default='.', help='Directory in which to save SVGs.')
add('--show', action='store_true', help='If set, will display figures in addition to saving.')
args = parser.parse_args()
plt.rcParams['font.size'] = 14
generator = PlotGenerator(args.host, args.output_directory, args.show)
generator.generate_plots()

0 comments on commit 1811a97

Please sign in to comment.