Commit

update
TieuLongPhan committed Jul 19, 2024
1 parent ac472c3 commit d5f917c
Showing 15 changed files with 28,910 additions and 532 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -13,3 +13,4 @@ __pycache__
*txt
Data/DPO/*
_test.py
*pkl
Empty file added Docs/Analysis/__init__.py
Empty file.
288 changes: 288 additions & 0 deletions Docs/Analysis/_analysis.py
@@ -0,0 +1,288 @@
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Optional
from matplotlib.axes import Axes

def create_pie_chart(data, column, ax=None, title=None, color_pallet="pastel"):
    """
    Generates a donut-style pie chart for the specified column from a list of
    dictionaries. Percentage labels are drawn inside the slices and category names
    are listed in an external legend without percentages. The plot title may be a
    LaTeX-formatted string.

    Parameters:
    data (list of dict): Data to plot.
    column (str): Column name to plot percentages for.
    ax (matplotlib.axes.Axes, optional): Matplotlib axis object to plot on.
    title (str, optional): Title for the pie chart; supports LaTeX-formatted strings.
    color_pallet (str, optional): Seaborn color palette name. Defaults to "pastel".

    Returns:
    matplotlib.axes.Axes: The axis with the pie chart.
    """
    # Enable LaTeX formatting for better quality text rendering
    plt.rc("text", usetex=True)
    plt.rc("font", family="serif")

    # Convert list of dictionaries to DataFrame
    df = pd.DataFrame(data)

    # Calculate the percentage of each unique value
    percentage = df[column].value_counts(normalize=True) * 100

    # Define a color palette using Seaborn
    colors = sns.color_palette(color_pallet, len(percentage))

    # Create pie plot
    if ax is None:
        fig, ax = plt.subplots()

    wedges, texts, autotexts = ax.pie(
        percentage,
        startangle=90,
        colors=colors,
        autopct=r"%1.1f\%%",  # percent sign escaped because usetex is enabled
        pctdistance=0.85,
        explode=[0.05] * len(percentage),
    )

    # Draw a circle at the center of the pie to make it look like a donut
    centre_circle = plt.Circle((0, 0), 0.70, fc="white")
    ax.add_artist(centre_circle)

    # Equal aspect ratio ensures that the pie is drawn as a circle
    ax.axis("equal")

    # Add legend with category names only
    ax.legend(
        wedges,
        [f"{label}" for label in percentage.index],
        title=column,
        loc="upper right",
        bbox_to_anchor=(0.6, 0.1, 0.6, 1),
        prop={"size": 16},
        title_fontsize=16,
    )

    # Set title using LaTeX if provided, else default to a generic title
    if title:
        ax.set_title(title, fontsize=24)
    else:
        ax.set_title(f"Pie Chart of {column}", fontsize=32)

    # Enhance the font size and color of the percentage labels
    for autotext in autotexts:
        autotext.set_color("black")
        autotext.set_fontsize(18)

    return ax
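
An illustrative usage sketch (not part of the committed module; the sample records are assumed, and a working LaTeX installation is required because the function enables usetex):

sample_records = [
    {"Reaction Step": 1},
    {"Reaction Step": 1},
    {"Reaction Step": 2},
    {"Reaction Step": 3},
]
fig, ax = plt.subplots(figsize=(6, 6))
create_pie_chart(sample_records, "Reaction Step", ax=ax, title="Reaction Step Share")
plt.show()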

def count_column_values(data, column):
    """
    Count the occurrences of each unique value in the specified column from a list of
    dictionaries. Treats all data types, including lists, as single entities by
    converting lists to strings.

    Parameters:
    data (list of dict): The data to process.
    column (str): The column to count values from.

    Returns:
    dict: A dictionary with keys as unique values (strings if lists) and values as
    the count of occurrences.
    """
    # Convert the list of dictionaries to a DataFrame
    df = pd.DataFrame(data)

    # If every entry in the column is a list, convert the lists to strings so they
    # become hashable and can be counted
    if df[column].dtype == object and df[column].apply(lambda x: isinstance(x, list)).all():
        df[column] = df[column].apply(str)

    # Count occurrences of each unique value
    return df[column].value_counts().to_dict()
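
A quick illustrative call (sample records and category names assumed, not part of the commit) showing how list-valued entries are stringified before counting:

records = [
    {"Topo Type": ["Single Cyclic"]},
    {"Topo Type": ["Single Cyclic"]},
    {"Topo Type": ["Complex Cyclic"]},
]
print(count_column_values(records, "Topo Type"))
# {"['Single Cyclic']": 2, "['Complex Cyclic']": 1}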



def plot_rules_distribution(
    rules: Dict[str, int],
    rule_type: str = 'single',
    ax: Optional[Axes] = None,
    title: Optional[str] = None,
    refinement: bool = False,
    threshold: float = 1,
    remove: bool = True,
    color_pallet: str = 'pastel'
) -> None:
    """
    Plots the distribution of rules as a bar chart. If `refinement` is True, all
    entries whose share falls below `threshold` are combined into a single
    'Under 1%' category.

    Parameters:
    rules (Dict[str, int]): Dictionary keyed by rule name, with counts as values.
    rule_type (str, optional): Type of rules to plot ('single' or 'complex').
        Default is 'single'.
    ax (matplotlib.axes.Axes, optional): Matplotlib axis object to plot on.
        If None, a new figure is created.
    title (str, optional): Title for the chart. If None, a default title based on
        `rule_type` is used.
    refinement (bool, optional): If True, combines all percentages under the
        threshold into one 'Under 1%' category. Default is False.
    threshold (float, optional): Percentage threshold below which categories are
        combined when `refinement` is True. Default is 1.
    remove (bool, optional): If True, removes the last category from the plot.
        Default is True.
    color_pallet (str, optional): Color palette to use for the plot.
        Default is 'pastel'.

    Returns:
    None: The function modifies the provided `ax` or creates and shows a new plot.
    """
    # Calculate total counts for the rules
    total_rules = sum(rules.values())

    # Convert counts to percentages and optionally combine small values
    if refinement:
        refined_rules = {}
        small_value_aggregate = 0
        for key, value in rules.items():
            percentage = value / total_rules * 100
            if percentage < threshold:
                small_value_aggregate += percentage
            else:
                refined_rules[key] = percentage
        if small_value_aggregate > 0:
            refined_rules['Under 1%'] = small_value_aggregate
        percentages = list(refined_rules.values())
        types_of_rules = list(refined_rules.keys())
        if remove:
            percentages = percentages[:-1]
            types_of_rules = types_of_rules[:-1]
    else:
        percentages = [value / total_rules * 100 for value in rules.values()]
        types_of_rules = list(rules.keys())

    # Set style
    sns.set(style="whitegrid")

    # Enable LaTeX rendering in matplotlib
    plt.rc('text', usetex=True)
    plt.rc('text.latex', preamble=r'\usepackage{amsmath}')  # Ensure amsmath is loaded

    # Create figure and axis if not provided
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=(10, 6), dpi=120)

    # Plot the data
    sns.barplot(ax=ax, x=types_of_rules, y=percentages, palette=color_pallet)
    if title:
        ax.set_title(title, fontsize=24)
    else:
        ax.set_title(f'Distribution of {rule_type.capitalize()} Rules', fontsize=16)
    ax.set_xlabel('Type of Rings', fontsize=18)
    ax.set_ylabel(r'Percentage (\%)', fontsize=18)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

    # Set font size for x-tick and y-tick labels
    ax.tick_params(axis="x", labelsize=16)
    ax.tick_params(axis="y", labelsize=16)

    # Add text labels above the bars (percent sign escaped because usetex is enabled)
    for index, value in enumerate(percentages):
        ax.text(index, value + 0.5, rf'{value:.1f}\%', ha='center', va='bottom', fontsize=18)

    # Only show the plot if the figure was created here
    if created_fig:
        plt.show()
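
An illustrative call (rule names and counts are assumed, not part of the commit; LaTeX is required because usetex is enabled inside the function):

rule_counts = {"R1": 620, "R2": 240, "R3": 95, "R4": 30, "R5": 15}
fig, ax = plt.subplots(figsize=(10, 6), dpi=120)
plot_rules_distribution(rule_counts, rule_type="single", ax=ax, refinement=True, remove=False)
plt.show()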



def plot_heatmap(
    data,
    title="Heatmap of Test Counts by Topo Type and Reaction Step",
    color_palette="coolwarm",
    title_fontsize=24,
    label_fontsize=20,
    annot_fontsize=18,
    cbar_label_fontsize=18,
    legend_fontsize=24,
    xtick_fontsize=18,
    ytick_fontsize=18,
    ax=None,
):
    """
    Plots a heatmap of the provided dataset with options for customization, a
    percentage-based aggregation, and an enhanced legend.

    Parameters:
    data (list of dict): Data to be visualized.
    title (str, optional): Title for the heatmap. Defaults to a generic title.
    color_palette (str, optional): Color palette for the heatmap.
        Defaults to 'coolwarm'.
    title_fontsize (int, optional): Font size for the title. Defaults to 24.
    label_fontsize (int, optional): Font size for the axis labels. Defaults to 20.
    annot_fontsize (int, optional): Font size for the annotations. Defaults to 18.
    cbar_label_fontsize (int, optional): Font size for the color bar label.
        Defaults to 18.
    legend_fontsize (int, optional): Font size for the legend. Defaults to 24.
    xtick_fontsize (int, optional): Font size for the x-tick labels. Defaults to 18.
    ytick_fontsize (int, optional): Font size for the y-tick labels. Defaults to 18.
    ax (matplotlib.axes.Axes, optional): Matplotlib axis object to plot on.
        If None, a new figure is created.
    """
    # Convert input data to DataFrame and add a counter column
    df = pd.DataFrame(data)
    df["Test"] = 1

    # Custom aggregation: express each cell as a percentage of all data points
    def custom_agg(series):
        total = series.sum()
        return (total / len(data)) * 100

    # Create pivot table for the heatmap using the custom aggregation function
    pivot_table = df.pivot_table(
        index="Topo Type", columns="Reaction Step", values="Test", aggfunc=custom_agg
    )

    # Check if an axis is provided; if not, create a new figure and axis
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=(10, 8))

    # Plot heatmap on the provided or created axis
    heatmap = sns.heatmap(
        pivot_table,
        annot=True,
        cmap=color_palette,
        fmt=".1f",
        ax=ax,
        cbar_kws={"label": r"Percentage (\%)"},
    )

    # Customize the title and axis label font sizes
    ax.set_title(title, fontsize=title_fontsize)
    ax.set_ylabel("Topo Type", fontsize=label_fontsize)
    ax.set_xlabel("Reaction Step", fontsize=label_fontsize)

    # Customize the font size of the annotations
    for text in heatmap.texts:
        text.set_fontsize(annot_fontsize)

    # Set font size for x-tick and y-tick labels
    ax.tick_params(axis="x", labelsize=xtick_fontsize)
    ax.tick_params(axis="y", labelsize=ytick_fontsize)

    # Customize the font size of the color bar label
    heatmap.figure.axes[-1].yaxis.label.set_size(cbar_label_fontsize)

    # Create a legend with the specified font size (only if any handles exist)
    handles, labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend(
            handles,
            labels,
            title="Legend",
            loc="upper right",
            bbox_to_anchor=(1.05, 1),
            fontsize=legend_fontsize,
        )

    # Only show the plot if the figure was created here
    if created_fig:
        plt.show()
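
An illustrative call (sample records assumed, not part of the commit); the function pivots on the 'Topo Type' and 'Reaction Step' columns, so both must be present in the data:

records = [
    {"Topo Type": "Single Cyclic", "Reaction Step": 1},
    {"Topo Type": "Single Cyclic", "Reaction Step": 2},
    {"Topo Type": "Complex Cyclic", "Reaction Step": 1},
    {"Topo Type": "Complex Cyclic", "Reaction Step": 1},
]
fig, ax = plt.subplots(figsize=(10, 8))
plot_heatmap(records, ax=ax)
plt.show()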
16,040 changes: 15,541 additions & 499 deletions Docs/Analysis/_templates_analysis.ipynb

Large diffs are not rendered by default.

53 changes: 53 additions & 0 deletions Docs/Analysis/_test_cluster_time.py
@@ -0,0 +1,53 @@
import pathlib
import logging
import time
import sys

# Setup the root directory based on the script's location
root_dir = pathlib.Path(__file__).resolve().parents[2]
sys.path.append(str(root_dir))
# Configure logging
logging.basicConfig(
    filename=f"{root_dir}/Docs/Analysis/cluster_time.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# Importing necessary functions and classes
from SynTemp.SynUtils.utils import load_from_pickle
from SynTemp.SynRule.rule_cluster import RuleCluster
from SynTemp.SynRule.rules_extraction import RuleExtraction

# Load data
data = load_from_pickle(f"{root_dir}/Data/Temp/_its_correct.pkl.gz")
its_graphs = [value['ITSGraph'] for value in data]
cluster = RuleCluster()

# Process the data for different values of k and log the processing time
for k in range(4):
    start_time = time.time()  # Start time measurement

    logging.info(f"Processing templates with k={k}")

    if k > 0:
        # Extract reaction rules with extension and k-nearest neighbors if k > 0
        rc_graphs = [
            RuleExtraction.extract_reaction_rules(*value, extend=True, n_knn=k)
            for value in its_graphs
        ]
    else:
        # Extract reaction rules without extension if k = 0
        rc_graphs = [
            RuleExtraction.extract_reaction_rules(*value, extend=False)
            for value in its_graphs
        ]

    # Fit the rule clusters with the extracted graphs
    cluster_indices, templates = cluster.fit(
        rc_graphs, templates=None, update_template=True
    )

    end_time = time.time()  # End time measurement
    processing_time = end_time - start_time  # Calculate processing time

    # Log the processing time
    logging.info(f"Finished processing for k={k} in {processing_time:.2f} seconds")

27 changes: 27 additions & 0 deletions Docs/Analysis/_test_hier_cluster_time.py
@@ -0,0 +1,27 @@
import pathlib
import logging
import sys

# Setup the root directory based on the script's location
root_dir = pathlib.Path(__file__).resolve().parents[2]
sys.path.append(str(root_dir))
# Configure logging
logging.basicConfig(
    filename=f"{root_dir}/Docs/Analysis/hier_cluster_time.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# Importing necessary functions and classes
from SynTemp.SynUtils.utils import load_from_pickle

from SynTemp.SynRule.hierarchical_clustering import HierarchicalClustering
from SynTemp.SynRule.rules_extraction import RuleExtraction

# Load data
data = load_from_pickle(f"{root_dir}/Data/Temp/_its_correct.pkl.gz")
its_graphs = [value['ITSGraph'] for value in data]
cluster = HierarchicalClustering()
logging.info("Processing templates")
cluster.fit(data)

8 changes: 8 additions & 0 deletions Docs/Analysis/cluster_time.log
@@ -0,0 +1,8 @@
2024-07-18 09:36:48,099 - INFO - Processing templates with k=0
2024-07-18 09:37:45,916 - INFO - Finished processing for k=0 in 57.82 seconds
2024-07-18 09:37:45,916 - INFO - Processing templates with k=1
2024-07-18 09:42:20,938 - INFO - Finished processing for k=1 in 275.02 seconds
2024-07-18 09:42:20,939 - INFO - Processing templates with k=2
2024-07-18 10:12:28,514 - INFO - Finished processing for k=2 in 1807.58 seconds
2024-07-18 10:12:28,514 - INFO - Processing templates with k=3
2024-07-18 11:47:04,055 - INFO - Finished processing for k=3 in 5675.54 seconds
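
The timings above grow sharply with k, from roughly 58 seconds at k=0 to about 5676 seconds at k=3. A small sketch for pulling these numbers out of the log file (the path comes from the script above; the regex is an assumption matched to the format shown):

import re

pattern = re.compile(r"Finished processing for k=(\d+) in ([\d.]+) seconds")
timings = {}
with open("Docs/Analysis/cluster_time.log") as fh:
    for line in fh:
        match = pattern.search(line)
        if match:
            timings[int(match.group(1))] = float(match.group(2))
print(timings)  # {0: 57.82, 1: 275.02, 2: 1807.58, 3: 5675.54}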