diff --git a/.gitignore b/.gitignore index 928a95f..790df63 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ __pycache__ Data/DPO/* _test.py *pkl +Backup/* \ No newline at end of file diff --git a/Data/Temp/Backup/raw_results.json b/Data/Temp/Backup/raw_results.json new file mode 100644 index 0000000..ca512fc --- /dev/null +++ b/Data/Temp/Backup/raw_results.json @@ -0,0 +1,254 @@ +{ + "Valid": { + "fw": { + "Raw_0": [ + 4.04, + 7.86, + 94.87 + ], + "Raw_1": [ + 0.79, + 7.24, + 78.52 + ], + "Raw_2": [ + 0.33, + 6.5, + 69.74 + ], + "Raw_3": [ + 0.23, + 4.92, + 68.67 + ], + "Complete_0": [ + 46.88, + 90.62, + 92.56 + ], + "Complete_1": [ + 12.83, + 89.32, + 72.44 + ], + "Complete_2": [ + 3.26, + 85.25, + 51.6 + ], + "Complete_3": [ + 2.16, + 75.23, + 44.12 + ], + "Expand_0": [ + 64.93, + 90.82, + 93.25 + ], + "Expand_1": [ + 13.65, + 89.56, + 74.14 + ], + "Expand_2": [ + 3.61, + 85.75, + 53.33 + ], + "Expand_3": [ + 2.29, + 76.15, + 45.11 + ] + }, + "bw": { + "Raw_0": [ + 4.58, + 7.86, + 97.99 + ], + "Raw_1": [ + 0.61, + 7.24, + 84.84 + ], + "Raw_2": [ + 0.35, + 6.5, + 73.3 + ], + "Raw_3": [ + 0.2, + 4.92, + 64.07 + ], + "Complete_0": [ + 45.38, + 91.8, + 96.35 + ], + "Complete_1": [ + 15.88, + 90.04, + 89.78 + ], + "Complete_2": [ + 10.86, + 86.09, + 85.48 + ], + "Complete_3": [ + 7.93, + 76.09, + 82.03 + ], + "Expand_0": [ + 65.14, + 92.0, + 96.87 + ], + "Expand_1": [ + 16.23, + 90.28, + 89.87 + ], + "Expand_2": [ + 11.43, + 86.51, + 85.69 + ], + "Expand_3": [ + 8.34, + 76.93, + 82.24 + ] + } + }, + "Test": { + "fw": { + "Raw_0": [ + 4.04, + 7.92, + 95.21 + ], + "Raw_1": [ + 0.81, + 7.7, + 78.68 + ], + "Raw_2": [ + 0.35, + 6.62, + 71.35 + ], + "Raw_3": [ + 0.23, + 4.92, + 68.67 + ], + "Complete_0": [ + 45.68, + 90.38, + 92.54 + ], + "Complete_1": [ + 12.25, + 89.22, + 72.71 + ], + "Complete_2": [ + 3.22, + 84.45, + 51.74 + ], + "Complete_3": [ + 2.13, + 74.21, + 44.93 + ], + "Expand_0": [ + 63.16, + 90.58, + 93.19 + ], + "Expand_1": [ + 13.0, + 89.4, + 74.02 + ], + "Expand_2": [ + 3.55, + 84.95, + 53.56 + ], + "Expand_3": [ + 2.26, + 75.15, + 45.81 + ] + }, + "bw": { + "Raw_0": [ + 4.47, + 7.92, + 98.06 + ], + "Raw_1": [ + 0.59, + 7.7, + 84.28 + ], + "Raw_2": [ + 0.35, + 6.62, + 72.72 + ], + "Raw_3": [ + 0.2, + 4.92, + 64.07 + ], + "Complete_0": [ + 45.11, + 91.9, + 96.1 + ], + "Complete_1": [ + 15.83, + 90.24, + 89.07 + ], + "Complete_2": [ + 10.84, + 85.45, + 85.02 + ], + "Complete_3": [ + 7.89, + 75.15, + 81.74 + ], + "Expand_0": [ + 65.2, + 92.1, + 96.63 + ], + "Expand_1": [ + 16.17, + 90.42, + 89.15 + ], + "Expand_2": [ + 11.36, + 85.91, + 85.18 + ], + "Expand_3": [ + 8.29, + 76.07, + 81.79 + ] + } + } +} \ No newline at end of file diff --git a/Data/Temp/Benchmark/raw_results.json b/Data/Temp/Benchmark/raw_results.json index ca512fc..fdb48b6 100644 --- a/Data/Temp/Benchmark/raw_results.json +++ b/Data/Temp/Benchmark/raw_results.json @@ -22,44 +22,44 @@ 68.67 ], "Complete_0": [ - 46.88, - 90.62, - 92.56 + 71.13, + 94.5, + 97.18 ], "Complete_1": [ - 12.83, - 89.32, - 72.44 + 22.88, + 92.92, + 89.08 ], "Complete_2": [ - 3.26, - 85.25, - 51.6 + 5.3, + 88.6, + 67.5 ], "Complete_3": [ - 2.16, - 75.23, - 44.12 - ], - "Expand_0": [ - 64.93, - 90.82, - 93.25 - ], - "Expand_1": [ - 13.65, - 89.56, - 74.14 - ], - "Expand_2": [ - 3.61, - 85.75, - 53.33 - ], - "Expand_3": [ - 2.29, - 76.15, - 45.11 + 3.22, + 78.03, + 58.88 + ], + "Refine_0": [ + 84.29, + 94.7, + 97.66 + ], + "Refine_1": [ + 23.78, + 93.16, + 89.58 + ], + "Refine_2": [ + 5.73, + 89.12, + 68.68 + ], + "Refine_3": [ + 3.3, + 78.93, + 58.97 ] }, "bw": { @@ -84,170 +84,44 @@ 64.07 ], "Complete_0": [ - 45.38, - 91.8, - 96.35 - ], - "Complete_1": [ - 15.88, - 90.04, - 89.78 - ], - "Complete_2": [ - 10.86, - 86.09, - 85.48 - ], - "Complete_3": [ - 7.93, - 76.09, - 82.03 - ], - "Expand_0": [ - 65.14, - 92.0, - 96.87 - ], - "Expand_1": [ - 16.23, - 90.28, - 89.87 - ], - "Expand_2": [ - 11.43, - 86.51, - 85.69 - ], - "Expand_3": [ - 8.34, - 76.93, - 82.24 - ] - } - }, - "Test": { - "fw": { - "Raw_0": [ - 4.04, - 7.92, - 95.21 - ], - "Raw_1": [ - 0.81, - 7.7, - 78.68 - ], - "Raw_2": [ - 0.35, - 6.62, - 71.35 - ], - "Raw_3": [ - 0.23, - 4.92, - 68.67 - ], - "Complete_0": [ - 45.68, - 90.38, - 92.54 - ], - "Complete_1": [ - 12.25, - 89.22, - 72.71 - ], - "Complete_2": [ - 3.22, - 84.45, - 51.74 - ], - "Complete_3": [ - 2.13, - 74.21, - 44.93 - ], - "Expand_0": [ - 63.16, - 90.58, - 93.19 - ], - "Expand_1": [ - 13.0, - 89.4, - 74.02 - ], - "Expand_2": [ - 3.55, - 84.95, - 53.56 - ], - "Expand_3": [ - 2.26, - 75.15, - 45.81 - ] - }, - "bw": { - "Raw_0": [ - 4.47, - 7.92, - 98.06 - ], - "Raw_1": [ - 0.59, - 7.7, - 84.28 - ], - "Raw_2": [ - 0.35, - 6.62, - 72.72 - ], - "Raw_3": [ - 0.2, - 4.92, - 64.07 - ], - "Complete_0": [ - 45.11, - 91.9, - 96.1 + 73.26, + 93.46, + 97.86 ], "Complete_1": [ - 15.83, - 90.24, - 89.07 + 22.0, + 92.84, + 92.98 ], "Complete_2": [ - 10.84, - 85.45, - 85.02 + 13.92, + 88.5, + 89.29 ], "Complete_3": [ - 7.89, - 75.15, - 81.74 - ], - "Expand_0": [ - 65.2, - 92.1, - 96.63 - ], - "Expand_1": [ - 16.17, - 90.42, - 89.15 - ], - "Expand_2": [ - 11.36, - 85.91, - 85.18 - ], - "Expand_3": [ - 8.29, - 76.07, - 81.79 + 9.4, + 77.97, + 85.12 + ], + "Refine_0": [ + 93.4, + 93.66, + 98.11 + ], + "Refine_1": [ + 22.51, + 93.08, + 93.13 + ], + "Refine_2": [ + 14.55, + 89.02, + 89.43 + ], + "Refine_3": [ + 9.78, + 78.87, + 85.2 ] } } diff --git a/Docs/Analysis/_4_templates_analysis.ipynb b/Docs/Analysis/_4_templates_analysis.ipynb index d1e34f7..6d93188 100644 --- a/Docs/Analysis/_4_templates_analysis.ipynb +++ b/Docs/Analysis/_4_templates_analysis.ipynb @@ -25,22 +25,22 @@ "metadata": {}, "outputs": [], "source": [ - "import pandas as pd\n", - "from syntemp.SynUtils.utils import train_val_test_split_df, save_database\n", + "# import pandas as pd\n", + "# from syntemp.SynUtils.utils import train_val_test_split_df, save_database\n", "\n", - "original_data = load_database(\"../../Data/Temp/data_aam.json.gz\")\n", - "original_data = pd.DataFrame(original_data)\n", + "# original_data = load_database(\"../../Data/Temp/data_aam.json.gz\")\n", + "# original_data = pd.DataFrame(original_data)\n", "\n", - "train, test, valid = train_val_test_split_df(original_data, target=\"class\")\n", - "train, test, valid = (\n", - " train.to_dict(\"records\"),\n", - " test.to_dict(\"records\"),\n", - " valid.to_dict(\"records\"),\n", - ")\n", + "# train, test, valid = train_val_test_split_df(original_data, target=\"class\")\n", + "# train, test, valid = (\n", + "# train.to_dict(\"records\"),\n", + "# test.to_dict(\"records\"),\n", + "# valid.to_dict(\"records\"),\n", + "# )\n", "\n", - "save_database(train, \"../../Data/Temp/Benchmark/train.json.gz\")\n", - "save_database(test, \"../../Data/Temp/Benchmark/test.json.gz\")\n", - "save_database(valid, \"../../Data/Temp/Benchmark/valid.json.gz\")" + "# save_database(train, \"../../Data/Temp/Benchmark/train.json.gz\")\n", + "# save_database(test, \"../../Data/Temp/Benchmark/test.json.gz\")\n", + "# save_database(valid, \"../../Data/Temp/Benchmark/valid.json.gz\")" ] }, { @@ -58,7 +58,9 @@ "source": [ "raw = load_from_pickle(\"../../Data/Temp/Benchmark/Raw/templates.pkl.gz\")\n", "complete = load_from_pickle(\"../../Data/Temp/Benchmark/Complete/templates.pkl.gz\")\n", - "expand = load_from_pickle(\"../../Data/Temp/Benchmark/Expand/templates.pkl.gz\")" + "complete_expand = load_from_pickle(\"../../Data/Temp/Benchmark/Complete_expand/templates.pkl.gz\")\n", + "refine = load_from_pickle(\"../../Data/Temp/Benchmark/Refine/templates.pkl.gz\")\n", + "refine_expand = load_from_pickle(\"../../Data/Temp/Benchmark/Refine_expand/templates.pkl.gz\")" ] }, { @@ -76,11 +78,15 @@ "\n", "raw_result = calculate(raw)\n", "complete_result = calculate(complete)\n", - "expand_result = calculate(expand)\n", + "complete_expand_result = calculate(complete_expand)\n", + "refine_result = calculate(refine)\n", + "refine_expand_result = calculate(refine_expand)\n", "\n", "print(raw_result)\n", "print(complete_result)\n", - "print(expand_result)" + "print(complete_expand_result)\n", + "print(refine_result)\n", + "print(refine_expand_result)" ] }, { @@ -109,7 +115,16 @@ "metadata": {}, "outputs": [], "source": [ - "len(temp_0)" + "len(data_cluster)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "34395 / 40012" ] }, { @@ -118,7 +133,7 @@ "metadata": {}, "outputs": [], "source": [ - "33690 / 40012" + "1-0.8596171148655404" ] }, { @@ -154,6 +169,13 @@ "- 37: C-I + OH" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -162,11 +184,13 @@ "source": [ "import matplotlib.pyplot as plt\n", "import matplotlib.gridspec as gridspec\n", - "from _analysis._plot_analysis import plot_top_rules_with_seaborn, load_and_title_png\n", + "#from _analysis._plot_analysis import plot_top_rules_with_seaborn, load_and_title_png\n", "\n", "# Set up the figure and GridSpec layout\n", "fig = plt.figure(figsize=(16, 10))\n", "gs = gridspec.GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1], figure=fig)\n", + "plt.rc(\"text\", usetex=True)\n", + "plt.rc(\"text.latex\", preamble=r\"\\usepackage{amsmath}\")\n", "\n", "# Create a subplot that spans the first row across both columns\n", "ax1 = fig.add_subplot(\n", @@ -307,8 +331,82 @@ "metadata": {}, "outputs": [], "source": [ - "from _analysis._plot_analysis import create_pie_chart\n", + "#from _analysis._plot_analysis import create_pie_chart\n", "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "\n", + "def create_pie_chart(data, column, ax=None, title=None, color_pallet=\"pastel\"):\n", + " \"\"\"\n", + " Generates a pie chart for the specified column from a list of dictionaries.\n", + " Displays percentage labels inside the slices only and category names in an external\n", + " legend without percentages. Allows customization of the plot title, supporting LaTeX\n", + " formatted strings.\n", + "\n", + " Parameters:\n", + " - data (list of dict): Data to plot.\n", + " - column (str): Column name to plot percentages for.\n", + " - ax (matplotlib.axes.Axes, optional): Matplotlib axis object to plot on.\n", + " - title (str, optional): Title for the pie chart, supports LaTeX formatted strings.\n", + "\n", + " Returns:\n", + " - matplotlib.axes.Axes: The axis with the pie chart.\n", + " \"\"\"\n", + " # Enable LaTeX formatting for better quality text rendering\n", + " plt.rc(\"text\", usetex=True)\n", + " plt.rc(\"font\", family=\"serif\")\n", + "\n", + " # Convert list of dictionaries to DataFrame\n", + " df = pd.DataFrame(data)\n", + "\n", + " # Calculate percentage\n", + " percentage = df[column].value_counts(normalize=True) * 100\n", + "\n", + " # Define a color palette using Seaborn\n", + " colors = sns.color_palette(color_pallet, len(percentage))\n", + "\n", + " # Create pie plot\n", + " if ax is None:\n", + " fig, ax = plt.subplots()\n", + "\n", + " wedges, texts, autotexts = ax.pie(\n", + " percentage,\n", + " startangle=90,\n", + " colors=colors,\n", + " autopct=\"%1.1f%%\",\n", + " pctdistance=0.85,\n", + " explode=[0.05] * len(percentage),\n", + " )\n", + "\n", + " # Draw a circle at the center of pie to make it look like a donut\n", + " centre_circle = plt.Circle((0, 0), 0.70, fc=\"white\")\n", + " ax.add_artist(centre_circle)\n", + "\n", + " # Equal aspect ratio ensures that pie is drawn as a circle.\n", + " ax.axis(\"equal\")\n", + "\n", + " # Add legend with category names only\n", + " ax.legend(\n", + " wedges,\n", + " [f\"{label}\" for label in percentage.index],\n", + " title=column,\n", + " loc=\"upper right\",\n", + " bbox_to_anchor=(0.6, 0.1, 0.6, 1),\n", + " prop={\"size\": 16},\n", + " title_fontsize=20,\n", + " ) # Set label font size\n", + "\n", + " # Set title using LaTeX if provided, else default to a generic title\n", + " if title:\n", + " ax.set_title(title, fontsize=32)\n", + " else:\n", + " ax.set_title(f\"Pie Chart of {column}\", fontsize=32)\n", + "\n", + " # Enhance the font size and color of the autotexts\n", + " for autotext in autotexts:\n", + " autotext.set_color(\"black\")\n", + " autotext.set_fontsize(18)\n", + "\n", + " return ax\n", "\n", "fig, axs = plt.subplots(2, 2, figsize=(16, 10)) # Adjust size as needed\n", "\n", @@ -407,7 +505,13 @@ "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "sns.set_theme(style=\"whitegrid\")\n", "\n", + "# Enable LaTeX rendering in matplotlib\n", + "plt.rc(\"text\", usetex=True)\n", + "plt.rc(\"text.latex\", preamble=r\"\\usepackage{amsmath}\") # Ensure amsmath is loaded\n", "fig, axs = plt.subplots(2, 2, figsize=(16, 12))\n", "\n", "\n", @@ -619,6 +723,15 @@ "triple = [value for value in temp_0 if value[\"Reaction Step\"] == 3]" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "len(single)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -647,6 +760,15 @@ "write_gml([double], double_path)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "write_gml([single], single_path)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/Docs/Analysis/_5_rule_application.ipynb b/Docs/Analysis/_5_rule_application.ipynb index 7f30030..139e360 100644 --- a/Docs/Analysis/_5_rule_application.ipynb +++ b/Docs/Analysis/_5_rule_application.ipynb @@ -19,6 +19,60 @@ "from _analysis._rule_app_analysis import automatic_results, save_results_to_json" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import *\n", + "from _analysis._rule_app_analysis import load_database, coverage_rate\n", + "def automatic_results(\n", + " test_types: List[str],\n", + " temp_types: List[str],\n", + " predict_types: List[str],\n", + " radii: List[int],\n", + " base_path=\"../../Data/Temp/Benchmark\",\n", + ") -> Dict[str, Dict[str, Tuple[float, float, float]]]:\n", + " \"\"\"\n", + " Automatically computes coverage rates for combinations of test type, template type,\n", + " predict type, and radii. Iterates over the provided parameter lists, loads data,\n", + " and computes statistics.\n", + "\n", + " Parameters:\n", + " - test_types (List[str]): List of test types.\n", + " - temp_types (List[str]): List of template types.\n", + " - predict_types (List[str]): List of prediction types.\n", + " - radii (List[int]): List of radii values.\n", + " - base_path (str): path to data\n", + "\n", + " Returns:\n", + " - Dict[str, Dict[str, Tuple[float, float, float]]]: A dictionary where the key\n", + " is the test type and the value is another dictionary. The inner dictionary's keys are\n", + " combinations of parameters as strings, and its values are tuples with the results from\n", + " `coverage_rate` (average solutions, coverage rate, false positive rate).\n", + " \"\"\"\n", + " all_results = {}\n", + "\n", + " for test in test_types:\n", + " test_results = {}\n", + " for predict in predict_types:\n", + " predict_results = {}\n", + " for temp in temp_types:\n", + " for rad in radii:\n", + " path = f\"{base_path}/{temp}/Output/{test}/{predict}_{rad}.json.gz\"\n", + " name = f\"{temp}_{rad}\"\n", + " data = load_database(path)\n", + " if data:\n", + " predict_results[name] = coverage_rate(data)\n", + " else:\n", + " predict_results[name] = (0.0, 0.0, 0.0)\n", + " test_results[predict] = predict_results\n", + " all_results[test] = test_results\n", + "\n", + " return all_results" + ] + }, { "cell_type": "code", "execution_count": null, @@ -26,13 +80,55 @@ "outputs": [], "source": [ "base_path = \"../../Data/Temp/Benchmark/\"\n", - "test_types = [\"Valid\", \"Test\"]\n", - "temp_types = [\"Raw\", \"Complete\", \"Expand\"]\n", + "test_types = [\"Valid\"]\n", + "# temp_types = [\"Raw\", \"Complete\", \"Complete_expand\", \"Refine\", \"Refine_expand\"]\n", + "temp_types = [\"Raw\", \"Complete\", \"Refine\"]\n", "predict_types = [\"fw\", \"bw\"]\n", "radius = [0, 1, 2, 3]\n", + "# radius = [0, 1]\n", "results = automatic_results(test_types, temp_types, predict_types, radius, base_path)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "save_results_to_json(results, \"../../Data/Temp/Benchmark/raw_results.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "results_df = pd.DataFrame(results['Valid']['fw'])\n", + "results_df.T" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "results_df = pd.DataFrame(results['Valid']['bw'])\n", + "results_df" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -72,7 +168,7 @@ " 0: \"average_solution\",\n", " # 1: r'\\mathcal(C)',\n", " 1: \"C\",\n", - " 2: \"FPR\",\n", + " 2: \"NR\",\n", " },\n", " inplace=True,\n", ")\n", @@ -81,7 +177,7 @@ " 0: \"average_solution\",\n", " # 1: r'\\mathcal(C)',\n", " 1: \"C\",\n", - " 2: \"FPR\",\n", + " 2: \"NR\",\n", " },\n", " inplace=True,\n", ")" @@ -96,6 +192,122 @@ "from _analysis._rule_app_analysis import plot_percentage" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fw['Type'].unique()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_percentage(\n", + " df: pd.DataFrame,\n", + " ax: plt.Axes,\n", + " column: str,\n", + " title: str = \"A\",\n", + " color_map: Optional[List[str]] = None,\n", + " fontsettings: Optional[Dict[str, int]] = None,\n", + ") -> None:\n", + " \"\"\"\n", + " Plot a percentage bar chart for different categories and subcategories within the data.\n", + "\n", + " Parameters:\n", + " df (pd.DataFrame): DataFrame containing the data to plot. Index of the DataFrame\n", + " should be string labels in the format 'category_subcategory'.\n", + " ax (plt.Axes): Matplotlib Axes object where the chart will be drawn.\n", + " column (str): Column name in df that contains the percentage values to plot.\n", + " title (str, optional): Title of the plot. Default is 'A'.\n", + " color_map (List[str], optional): List of hex color strings for the bars. If None,\n", + " a default set of colors will be used.\n", + " fontsettings (Dict[str, int], optional): Dictionary containing font size settings\n", + " for various elements of the plot. If None,\n", + " default settings are applied.\n", + "\n", + " Returns:\n", + " None: This function does not return any value but modifies the ax object by drawing a bar chart.\n", + "\n", + " Example:\n", + " >>> fig, ax = plt.subplots()\n", + " >>> data = pd.DataFrame({'Value': [20, 30, 40, 50]}, index=['Type1_10', 'Type1_20', 'Type2_10', 'Type2_20'])\n", + " >>> plot_percentage(data, ax, 'Value')\n", + " >>> plt.show()\n", + " \"\"\"\n", + " if fontsettings is None:\n", + " fontsettings = {\n", + " \"title_size\": 18,\n", + " \"label_size\": 16,\n", + " \"ticks_size\": 16,\n", + " \"annotation_size\": 12,\n", + " }\n", + "\n", + " # Split the index into template type and radii\n", + " df[\"Type\"] = [i.split(\"_\")[0] for i in df.index]\n", + " df[\"Radii\"] = [int(i.split(\"_\")[1]) for i in df.index]\n", + "\n", + " # Sort data to group by type and then by radii\n", + " df = df.sort_values(by=[\"Radii\"])\n", + "\n", + " # Prepare color map for radii using coolwarm\n", + " if color_map is None:\n", + " color_map = [\"#3A8EBA\", \"#92C5DE\", \"#F4A582\", \"#D6604D\"]\n", + "\n", + " # Plotting logic with annotations\n", + " total_width = 3 # Total width for group\n", + " width = total_width / len(\n", + " df[\"Radii\"].unique()\n", + " ) # Width for each bar within each type group\n", + " type_positions = np.arange(len(df[\"Type\"].unique())) * (\n", + " len(df[\"Radii\"].unique()) + 1\n", + " )\n", + "\n", + " for i, t in enumerate(df[\"Type\"].unique()):\n", + " for j, r in enumerate(df[\"Radii\"].unique()):\n", + " #print(t)\n", + " bar_positions = type_positions[i] + j * width\n", + " heights = df[(df[\"Type\"] == t) & (df[\"Radii\"] == r)][column]\n", + " ax.bar(\n", + " bar_positions,\n", + " heights,\n", + " width=width,\n", + " label=f\"$R_{{{r}}}$\" if i == 0 else \"\",\n", + " color=color_map[j % len(color_map)],\n", + " )\n", + " # Adding annotations\n", + " for rect in ax.patches:\n", + " height = rect.get_height()\n", + " ax.annotate(\n", + " f\"{height:.1f}%\",\n", + " xy=(rect.get_x() + rect.get_width() / 2, height),\n", + " xytext=(0, 3), # 3 points vertical offset\n", + " textcoords=\"offset points\",\n", + " ha=\"center\",\n", + " va=\"bottom\",\n", + " fontsize=fontsettings[\"annotation_size\"],\n", + " )\n", + "\n", + " # Enhancements like axes labeling, ticks setting, and adding grid\n", + " ax.set_ylabel(rf\"$\\mathcal{{{column}}} (\\%)$\", fontsize=fontsettings[\"label_size\"])\n", + " ax.set_title(title, fontsize=fontsettings[\"title_size\"], weight=\"medium\")\n", + " ax.set_xticks(type_positions + total_width / 2 - width / 2)\n", + " ax.set_xticklabels(\n", + " [f\"$Q_{{\\\\text{{{t}}}}}$\" for t in df[\"Type\"].unique()],\n", + " fontsize=fontsettings[\"ticks_size\"],\n", + " )\n", + " ax.set_yticks(np.arange(0, 101, 20))\n", + " ax.set_yticklabels(\n", + " [f\"{i}%\" for i in range(0, 101, 20)], fontsize=fontsettings[\"ticks_size\"]\n", + " )\n", + " ax.grid(True, which=\"major\", linestyle=\"--\", linewidth=\"0.5\", color=\"grey\")\n", + " ax.set_axisbelow(True)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -112,15 +324,15 @@ " \"title_size\": 32,\n", " \"label_size\": 28,\n", " \"ticks_size\": 28,\n", - " \"annotation_size\": 18,\n", + " \"annotation_size\": 20,\n", "}\n", "\n", "# Plotting data\n", "plot_percentage(fw, ax1, \"C\", title=r\"A\", fontsettings=fontsettings)\n", "plot_percentage(bw, ax2, \"C\", title=r\"B\", fontsettings=fontsettings)\n", "\n", - "plot_percentage(fw, ax3, \"FPR\", title=r\"C\", fontsettings=fontsettings)\n", - "plot_percentage(bw, ax4, \"FPR\", title=r\"D\", fontsettings=fontsettings)\n", + "plot_percentage(fw, ax3, \"NR\", title=r\"C\", fontsettings=fontsettings)\n", + "plot_percentage(bw, ax4, \"NR\", title=r\"D\", fontsettings=fontsettings)\n", "\n", "\n", "fig.legend(\n", @@ -164,7 +376,7 @@ "\n", "results = load_results_from_json(\"../../Data/Temp/Benchmark/raw_results.json\")\n", "\n", - "valid = results[\"Test\"]\n", + "valid = results[\"Valid\"]\n", "\n", "valid_fw = valid[\"fw\"]\n", "valid_bw = valid[\"bw\"]\n", @@ -175,7 +387,7 @@ " 0: \"average_solution\",\n", " # 1: r'\\mathcal(C)',\n", " 1: \"C\",\n", - " 2: \"FPR\",\n", + " 2: \"NR\",\n", " },\n", " inplace=True,\n", ")\n", @@ -184,7 +396,7 @@ " 0: \"average_solution\",\n", " # 1: r'\\mathcal(C)',\n", " 1: \"C\",\n", - " 2: \"FPR\",\n", + " 2: \"NR\",\n", " },\n", " inplace=True,\n", ")\n", @@ -193,6 +405,128 @@ "bw[[\"Type\", \"Radii\"]] = bw.index.to_series().str.split(\"_\", expand=True)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_roc_curves(\n", + " df: pd.DataFrame,\n", + " ax: plt.Axes,\n", + " selected_types: Optional[List[str]] = None,\n", + " fontsettings: Optional[Dict[str, int]] = None,\n", + " title: str = \"A\",\n", + " add_legend: bool = False,\n", + ") -> List[Any]:\n", + " \"\"\"\n", + " Plot ROC curves for specified types from a DataFrame on a given matplotlib Axes.\n", + "\n", + " Parameters:\n", + " df (pd.DataFrame): DataFrame containing the data for plotting. Must include columns 'Type', 'C' for TPR,\n", + " and 'FPR' for FPR, where 'Type' differentiates data series.\n", + " ax (plt.Axes): The matplotlib Axes object where the ROC curves will be drawn.\n", + " selected_types (Optional[List[str]]): List of strings representing the types to be included in the plot.\n", + " If None, all types in the DataFrame will be plotted.\n", + " fontsettings (Optional[Dict[str, int]]): Dictionary containing font settings for titles, labels,\n", + " ticks, and annotations. If None, defaults will be applied.\n", + " title (str): Title of the plot.\n", + " add_legend (bool): If True, add a legend to the plot.\n", + "\n", + " Returns:\n", + " List[Any]: List containing matplotlib line handles for the legend, useful if further customization\n", + " or reference is needed.\n", + "\n", + " Raises:\n", + " ValueError: If selected_types is provided and contains non-string elements.\n", + "\n", + " Example:\n", + " >>> fig, ax = plt.subplots()\n", + " >>> data = pd.DataFrame({\n", + " ... 'Type': ['Type1', 'Type1', 'Type2', 'Type2'],\n", + " ... 'C': [90, 85, 88, 80],\n", + " ... 'FPR': [5, 10, 5, 10]\n", + " ... })\n", + " >>> plot_roc_curves(data, ax, ['Type1', 'Type2'])\n", + " >>> plt.show()\n", + " \"\"\"\n", + " if selected_types is not None:\n", + " if not all(isinstance(t, str) for t in selected_types):\n", + " raise ValueError(\"selected_types must be a list of strings.\")\n", + " original_types = [t for t in selected_types if t in df[\"Type\"].unique()]\n", + " else:\n", + " original_types = df[\"Type\"].unique()\n", + "\n", + " types = [f\"$Q_{{\\\\text{{{t}}}}}$\" for t in original_types]\n", + "\n", + " if fontsettings is None:\n", + " fontsettings = {\n", + " \"title_size\": 28,\n", + " \"label_size\": 24,\n", + " \"ticks_size\": 24,\n", + " \"annotation_size\": 18,\n", + " }\n", + "\n", + " markers = [\"o\", \"^\", \"s\", \"p\"]\n", + " markers.reverse()\n", + " marker_labels = [r\"$R_{0}$\", r\"$R_{1}$\", r\"$R_{2}$\", r\"$R_{3}$\"]\n", + " marker_labels.reverse()\n", + " marker_color = \"gray\"\n", + "\n", + " colors = plt.cm.coolwarm(np.linspace(0, 1, len(types)))\n", + " colors = [\"#3A8EBA\", \"#D6604D\"]\n", + " legend_handles = []\n", + "\n", + " for index, type_ in enumerate(original_types):\n", + " type_data = df[df[\"Type\"] == type_]\n", + " tpr = type_data[\"C\"].tolist()\n", + " fpr = type_data[\"NR\"].tolist()\n", + " tpr = [x / 100 for x in tpr]\n", + " fpr = [x / 100 for x in fpr]\n", + " tpr.reverse()\n", + " fpr.reverse()\n", + "\n", + " (line,) = ax.plot(\n", + " fpr, tpr, linestyle=\"-\", color=colors[index], label=f\"{types[index]}\"\n", + " )\n", + " legend_handles.append(line)\n", + "\n", + " for i, (f, t) in enumerate(zip(fpr, tpr)):\n", + " marker = ax.plot(\n", + " f, t, marker=markers[i % len(markers)], color=marker_color\n", + " )[0]\n", + " if index == 1:\n", + " marker_handle = plt.Line2D(\n", + " [0],\n", + " [0],\n", + " marker=markers[i % len(markers)],\n", + " color=\"none\",\n", + " markerfacecolor=marker_color,\n", + " markersize=10,\n", + " label=marker_labels[i],\n", + " )\n", + " legend_handles.append(marker_handle)\n", + "\n", + " ax.set_xlabel(r\"$\\mathcal{NR}\\ (\\%)$\", fontsize=fontsettings[\"label_size\"])\n", + " ax.set_ylabel(r\"$\\mathcal{C}\\ (\\%)$\", fontsize=fontsettings[\"label_size\"])\n", + " ax.set_title(rf\"{title}\", fontsize=fontsettings[\"title_size\"], weight=\"medium\")\n", + " ax.tick_params(axis=\"both\", which=\"major\", labelsize=fontsettings[\"ticks_size\"])\n", + " ax.grid(True)\n", + "\n", + " if add_legend:\n", + " ax.legend(\n", + " handles=legend_handles,\n", + " loc=\"lower right\",\n", + " fancybox=True,\n", + " title_fontsize=fontsettings[\"label_size\"],\n", + " fontsize=fontsettings[\"annotation_size\"],\n", + " ncol=3,\n", + " )\n", + "\n", + " ax.grid(True)\n", + " return legend_handles" + ] + }, { "cell_type": "code", "execution_count": null, @@ -201,7 +535,7 @@ "source": [ "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", - "from _analysis._rule_app_analysis import plot_roc_curves\n", + "#from _analysis._rule_app_analysis import plot_roc_curves\n", "\n", "plt.rc(\"text\", usetex=True)\n", "plt.rc(\"text.latex\", preamble=r\"\\usepackage{amsmath}\") # Ensure amsmath is loaded\n", @@ -215,14 +549,14 @@ "plot_roc_curves(\n", " fw,\n", " axs[0],\n", - " selected_types=[\"Complete\", \"Expand\"],\n", + " selected_types=[\"Complete\", \"Refine\"],\n", " fontsettings=fontsettings,\n", " title=\"A\",\n", ")\n", "legend_handles = plot_roc_curves(\n", " bw,\n", " axs[1],\n", - " selected_types=[\"Complete\", \"Expand\"],\n", + " selected_types=[\"Complete\", \"Refine\"],\n", " fontsettings=fontsettings,\n", " title=\"B\",\n", ")\n", @@ -239,7 +573,7 @@ "\n", "fig.tight_layout()\n", "fig.subplots_adjust(hspace=0.15, wspace=0.2, bottom=0.2)\n", - "fig.savefig(\"./fig/ROC_test.pdf\", dpi=600, bbox_inches=\"tight\", pad_inches=0)" + "#fig.savefig(\"./fig/ROC_test.pdf\", dpi=600, bbox_inches=\"tight\", pad_inches=0)" ] }, { @@ -272,19 +606,19 @@ "bw = pd.DataFrame(valid_bw).T\n", "fw.rename(\n", " columns={\n", - " 0: \"average_solution\",\n", + " 0: \"AG\",\n", " # 1: r'\\mathcal(C)',\n", " 1: \"C\",\n", - " 2: \"FPR\",\n", + " 2: \"NR\",\n", " },\n", " inplace=True,\n", ")\n", "bw.rename(\n", " columns={\n", - " 0: \"average_solution\",\n", + " 0: \"AG\",\n", " # 1: r'\\mathcal(C)',\n", " 1: \"C\",\n", - " 2: \"FPR\",\n", + " 2: \"NR\",\n", " },\n", " inplace=True,\n", ")\n", @@ -301,19 +635,28 @@ "source": [ "def gmean(tpr, fpr):\n", " tnr = 1 - fpr # True Negative Rate\n", - " g_mean = np.sqrt(tpr * tnr)\n", + " g_mean = (tpr+tnr)/2\n", " return g_mean\n", "\n", "\n", "# Calculate G-mean for each row and add it as a new column\n", - "fw[\"G-mean-forward\"] = fw.apply(\n", - " lambda row: gmean(row[\"C\"] / 100, row[\"FPR\"] / 100), axis=1\n", + "fw[\"Sc-forward\"] = fw.apply(\n", + " lambda row: gmean(row[\"C\"] / 100, row[\"NR\"] / 100), axis=1\n", ")\n", - "bw[\"G-mean-backward\"] = bw.apply(\n", - " lambda row: gmean(row[\"C\"] / 100, row[\"FPR\"] / 100), axis=1\n", + "bw[\"Sc-backward\"] = bw.apply(\n", + " lambda row: gmean(row[\"C\"] / 100, row[\"NR\"] / 100), axis=1\n", ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fw" + ] + }, { "cell_type": "code", "execution_count": null, @@ -321,10 +664,19 @@ "outputs": [], "source": [ "valid_result = pd.concat(\n", - " [fw[\"G-mean-forward\"], bw[[\"G-mean-backward\", \"Type\", \"Radii\"]]], axis=1\n", + " [fw[\"Sc-forward\"], bw[[\"Sc-backward\", \"Type\", \"Radii\"]]], axis=1\n", ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "valid_result" + ] + }, { "cell_type": "code", "execution_count": null, @@ -343,14 +695,16 @@ "outputs": [], "source": [ "def gmean(tpr, fpr):\n", + " \n", " tnr = 1 - fpr # True Negative Rate\n", - " g_mean = np.sqrt(tpr * tnr)\n", + "\n", + " g_mean = np.sqrt(tpr * fpr)\n", " return g_mean\n", "\n", "\n", "# Calculate G-mean for each row and add it as a new column\n", - "fw[\"G-mean\"] = fw.apply(lambda row: gmean(row[\"C\"] / 100, row[\"FPR\"] / 100), axis=1)\n", - "bw[\"G-mean\"] = bw.apply(lambda row: gmean(row[\"C\"] / 100, row[\"FPR\"] / 100), axis=1)" + "fw[\"G-mean\"] = fw.apply(lambda row: gmean(row[\"C\"] / 100, row[\"AG\"] / 100), axis=1)\n", + "bw[\"G-mean\"] = bw.apply(lambda row: gmean(row[\"C\"] / 100, row[\"AG\"] / 100), axis=1)" ] }, { @@ -362,15 +716,6 @@ "fw" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "bw" - ] - }, { "cell_type": "code", "execution_count": null, @@ -443,7 +788,7 @@ "\n", "\n", "# Example usage\n", - "log_file_path = \"../../Data/Temp/Benchmark/Raw/Log/Test\"" + "log_file_path = \"../../Data/Temp/Benchmark/Raw/Log/Valid\"" ] }, { @@ -480,9 +825,9 @@ "outputs": [], "source": [ "valid_times_compare = {\n", - " r\"$Q_{\\text{raw}}$\": [9169.663, 34687.0292, 212203.163, 450822.167],\n", - " r\"$Q_{\\text{complete}}$\": [10067.556, 36444.666, 213271.507, 458275.105],\n", - " r\"$Q_{\\text{expand}}$\": [10882.231, 36850.825, 215493.644, 465514.313],\n", + " r\"$Q_{\\text{raw}}$\": [7248.640, 19696.1595, 153555.485, 329958.076],\n", + " r\"$Q_{\\text{complete}}$\": [7488.947, 20696.089, 160297.79, 384013.142],\n", + " r\"$Q_{\\text{refine}}$\": [8560.03, 21858.446, 166933.66, 400155.498],\n", "}" ] }, @@ -505,7 +850,7 @@ "metadata": {}, "outputs": [], "source": [ - "from _analysis._rule_app_analysis import plot_roc_curves, plot_processing_times" + "from _analysis._rule_app_analysis import plot_processing_times" ] }, { @@ -514,8 +859,104 @@ "metadata": {}, "outputs": [], "source": [ - "import matplotlib.pyplot as plt\n", + "def plot_processing_times(\n", + " times: Dict[str, List[float]], ax: Optional[plt.Axes] = None, title: str = \"A\"\n", + ") -> None:\n", + " \"\"\"\n", + " Plot processing times for various methods across different stages.\n", + "\n", + " This function takes a dictionary of processing times, converts them into hours,\n", + " and plots them using a bar chart.\n", + "\n", + " Parameters:\n", + " times (Dict[str, List[float]]): A dictionary where keys are method names and values\n", + " are lists of processing times in seconds for each stage.\n", + " ax (Optional[plt.Axes]): Matplotlib Axes object where the plot will be drawn. If None,\n", + " the current active Axes will be used.\n", + " title (str): The title of the plot.\n", + "\n", + " Returns:\n", + " None: The function creates a plot but does not return any value.\n", "\n", + " Example:\n", + " >>> times = {\n", + " ... \"Method1\": [3600, 7200, 1800, 5400],\n", + " ... \"Method2\": [1800, 3600, 900, 2700],\n", + " ... }\n", + " >>> fig, ax = plt.subplots()\n", + " >>> plot_processing_times(times, ax=ax, title=\"Processing Times Analysis\")\n", + " >>> plt.show()\n", + " \"\"\"\n", + " plt.rc(\"text\", usetex=True)\n", + " plt.rc(\"text.latex\", preamble=r\"\\usepackage{amsmath}\") # Ensure amsmath is loaded\n", + " # Convert to hours\n", + " for key in times:\n", + " times[key] = np.array(times[key]) / 3600\n", + "\n", + " # Stages\n", + " stages = [r\"$R_{0}$\", r\"$R_{1}$\", r\"$R_{2}$\", r\"$R_{3}$\"]\n", + "\n", + " # Create a DataFrame\n", + " df = (\n", + " pd.DataFrame(times, index=stages)\n", + " .reset_index()\n", + " .melt(id_vars=\"index\", var_name=\"Method\", value_name=\"Time (hours)\")\n", + " )\n", + " df.rename(columns={\"index\": \"Stage\"}, inplace=True)\n", + "\n", + " # Create the plot on the provided ax\n", + " if ax is None:\n", + " ax = plt.gca() # Get current axis if not provided\n", + "\n", + " custom_colors = [\"#5e4fa2\", \"#3A8EBA\", \"#D6604D\"]\n", + " palette = sns.color_palette(custom_colors[: len(times.keys())])\n", + " bar_plot = sns.barplot(\n", + " x=\"Stage\", y=\"Time (hours)\", hue=\"Method\", data=df, palette=palette, ax=ax\n", + " )\n", + "\n", + " ax.set_title(rf\"{title}\", fontsize=24, weight=\"bold\")\n", + " ax.set_xlabel(None)\n", + " ax.set_ylabel(rf\"Time (Hours)\", fontsize=20)\n", + " ax.set_xticklabels(ax.get_xticklabels(), fontsize=20)\n", + " ax.set_yticklabels([rf\"{y:.0f}\" for y in ax.get_yticks()], fontsize=20)\n", + " ax.legend(\n", + " title=\"Template Type\",\n", + " title_fontsize=\"24\",\n", + " fontsize=\"20\",\n", + " loc=\"upper left\",\n", + " bbox_to_anchor=(0.01, 1),\n", + " )\n", + "\n", + " # Add text annotations on the bars\n", + " for p in bar_plot.patches:\n", + " bar_height = p.get_height()\n", + " if bar_height > 0.01: # Adjust this threshold as needed\n", + " annotation = format(\n", + " p.get_height(), \".1f\" if p.get_height() < 100 else \".0f\"\n", + " )\n", + " ax.annotate(\n", + " rf\"{annotation}\",\n", + " (p.get_x() + p.get_width() / 2, p.get_height()),\n", + " ha=\"center\",\n", + " va=\"center\",\n", + " xytext=(0, 9),\n", + " textcoords=\"offset points\",\n", + " fontsize=20,\n", + " )\n", + "\n", + " ax.grid(True, linestyle=\"--\", alpha=0.6)\n", + " plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.rc(\"text\", usetex=True)\n", + "plt.rc(\"text.latex\", preamble=r\"\\usepackage{amsmath}\") # Ensure amsmath is loaded\n", "fig, axs = plt.subplots(1, 1, figsize=(16, 8))\n", "plot_processing_times(valid_times_compare, ax=axs, title=\"A. Time benchmarking\")\n", "# plot_processing_times(test_times_compare, ax=axs[1], title = 'B')\n", @@ -548,7 +989,7 @@ " 0: \"average_solution\",\n", " # 1: r'\\mathcal(C)',\n", " 1: \"C\",\n", - " 2: \"FPR\",\n", + " 2: \"NR\",\n", " },\n", " inplace=True,\n", ")\n", @@ -557,7 +998,7 @@ " 0: \"average_solution\",\n", " # 1: r'\\mathcal(C)',\n", " 1: \"C\",\n", - " 2: \"FPR\",\n", + " 2: \"NR\",\n", " },\n", " inplace=True,\n", ")\n", @@ -602,14 +1043,14 @@ "legend_handles_fw = plot_roc_curves(\n", " fw,\n", " axs[1, 0],\n", - " selected_types=[\"Complete\", \"Expand\"],\n", + " selected_types=[\"Complete\", \"Refine\"],\n", " fontsettings=fontsettings,\n", " title=r\"B. ROC Curves Validation\",\n", ")\n", "legend_handles_bw = plot_roc_curves(\n", " bw,\n", " axs[1, 1],\n", - " selected_types=[\"Complete\", \"Expand\"],\n", + " selected_types=[\"Complete\", \"Refine\"],\n", " fontsettings=fontsettings,\n", " title=r\"C. ROC Curves Test\",\n", ")\n", diff --git a/Docs/Analysis/_analysis/_plot_analysis.py b/Docs/Analysis/_analysis/_plot_analysis.py index f09bdc1..8c9073d 100644 --- a/Docs/Analysis/_analysis/_plot_analysis.py +++ b/Docs/Analysis/_analysis/_plot_analysis.py @@ -65,20 +65,20 @@ def plot_top_rules_with_seaborn( # Add labels on top of each bar for p in barplot.patches: ax.annotate( - f"{p.get_height():.1f}%", + rf"{p.get_height():.1f}%", (p.get_x() + p.get_width() / 2.0, p.get_height()), ha="center", va="center", xytext=(0, 9), textcoords="offset points", - fontsize=16, + fontsize=20, ) # Setting plot labels and titles - ax.set_xlabel("Rule ID", fontsize=18) - ax.set_ylabel("Percentage (%)", fontsize=18) - ax.set_title(f"Top {top_n} Popular Rules", fontsize=24, weight="medium") - ax.tick_params(axis="both", labelsize=18) + ax.set_xlabel(r"Rule ID", fontsize=24) + ax.set_ylabel(r"Percentage (\%)", fontsize=24) + ax.set_title(rf"Top {top_n} Popular Rules", fontsize=32, weight="medium") + ax.tick_params(axis="both", labelsize=24) plt.xticks(rotation=45) # Rotate x-axis labels for better readability # Show the plot if ax was not provided @@ -131,7 +131,7 @@ def load_and_title_png( # Display the image on the specified axis ax.imshow(img) ax.set_title( - title, fontsize=24, weight="medium" + rf"{title}", fontsize=32, weight="medium" ) # Set the title with a specified fontsize ax.axis("off") # Hide the axes diff --git a/Docs/Analysis/_analysis/_rule_app_analysis.py b/Docs/Analysis/_analysis/_rule_app_analysis.py index c872306..2be6606 100644 --- a/Docs/Analysis/_analysis/_rule_app_analysis.py +++ b/Docs/Analysis/_analysis/_rule_app_analysis.py @@ -357,7 +357,7 @@ def plot_roc_curves( for index, type_ in enumerate(original_types): type_data = df[df["Type"] == type_] tpr = type_data["C"].tolist() - fpr = type_data["FPR"].tolist() + fpr = type_data["NR"].tolist() tpr = [x / 100 for x in tpr] fpr = [x / 100 for x in fpr] tpr.reverse() @@ -384,7 +384,7 @@ def plot_roc_curves( ) legend_handles.append(marker_handle) - ax.set_xlabel(r"$\mathcal{FPR}\ (\%)$", fontsize=fontsettings["label_size"]) + ax.set_xlabel(r"$\mathcal{NR}\ (\%)$", fontsize=fontsettings["label_size"]) ax.set_ylabel(r"$\mathcal{C}\ (\%)$", fontsize=fontsettings["label_size"]) ax.set_title(rf"{title}", fontsize=fontsettings["title_size"], weight="medium") ax.tick_params(axis="both", which="major", labelsize=fontsettings["ticks_size"]) diff --git a/Test/SynRule/test_rule_writing.py b/Test/SynRule/test_rule_writing.py index 766f603..e4dcf2b 100644 --- a/Test/SynRule/test_rule_writing.py +++ b/Test/SynRule/test_rule_writing.py @@ -10,8 +10,8 @@ def setUp(self) -> None: self.data = load_from_pickle("Data/Testcase/templates.pkl.gz")[0] def test_charge_to_string(self): - self.assertEqual(RuleWriting.charge_to_string(3), "+++") - self.assertEqual(RuleWriting.charge_to_string(-2), "--") + self.assertEqual(RuleWriting.charge_to_string(3), "3+") + self.assertEqual(RuleWriting.charge_to_string(-2), "2-") self.assertEqual(RuleWriting.charge_to_string(0), "") def test_convert_graph_to_gml_context(self): diff --git a/syntemp/SynITS/its_extraction.py b/syntemp/SynITS/its_extraction.py index e5d3b35..67350c0 100644 --- a/syntemp/SynITS/its_extraction.py +++ b/syntemp/SynITS/its_extraction.py @@ -8,6 +8,7 @@ from syntemp.SynITS.its_construction import ITSConstruction from syntemp.SynChemistry.mol_to_graph import MolToGraph from syntemp.SynRule.rules_extraction import RuleExtraction +from syntemp.SynUtils.chemutils import remove_atom_mapping class ITSExtraction: @@ -136,10 +137,13 @@ def process_mapped_smiles( graphs_by_map[mapper] = (one_node_graph, one_node_graph, one_node_graph) rules_by_map[mapper] = (one_node_graph, one_node_graph, one_node_graph) rules_graphs.append(one_node_graph) - if check_method == "RC": - _, equivariant = ITSExtraction.check_equivariant_graph(rules_graphs) - elif check_method == "ITS": - _, equivariant = ITSExtraction.check_equivariant_graph(its_graphs) + if len(rules_graphs) > 1: + if check_method == "RC": + _, equivariant = ITSExtraction.check_equivariant_graph(rules_graphs) + elif check_method == "ITS": + _, equivariant = ITSExtraction.check_equivariant_graph(its_graphs) + else: + equivariant = 0 # graphs_by_map['check_equivariant'] = classified graphs_by_map["equivariant"] = equivariant @@ -156,13 +160,16 @@ def process_mapped_smiles( # Check if mapper_names is not empty to avoid IndexError if mapper_names: - # Update the target dictionary based on the determined conditions - if confident_mapper in mapper_names: - target_dict["ITSGraph"] = graphs_by_map.get(confident_mapper, None) - target_dict["GraphRules"] = rules_by_map.get(confident_mapper, None) - else: + if "[O]" in remove_atom_mapping(mapped_smiles[mapper_names[0]]): target_dict["ITSGraph"] = graphs_by_map.get(mapper_names[0], None) target_dict["GraphRules"] = rules_by_map.get(mapper_names[0], None) + else: + if confident_mapper in mapper_names: + target_dict["ITSGraph"] = graphs_by_map.get(confident_mapper, None) + target_dict["GraphRules"] = rules_by_map.get(confident_mapper, None) + else: + target_dict["ITSGraph"] = graphs_by_map.get(mapper_names[0], None) + target_dict["GraphRules"] = rules_by_map.get(mapper_names[0], None) return graphs_by_map_correct, graphs_by_map_incorrect diff --git a/syntemp/SynITS/its_hadjuster.py b/syntemp/SynITS/its_hadjuster.py index bbf1ba1..6a60a13 100644 --- a/syntemp/SynITS/its_hadjuster.py +++ b/syntemp/SynITS/its_hadjuster.py @@ -22,6 +22,7 @@ def process_single_graph_data( return_all: bool = False, ignore_aromaticity: bool = False, balance_its: bool = True, + get_random_results=False, ) -> Dict: """ Processes a single dictionary containing graph information by applying @@ -63,6 +64,7 @@ def process_single_graph_data( ignore_aromaticity, return_all, balance_its, + get_random_results, ) else: graph_data = ITSHAdjuster.process_high_hcount_change( @@ -73,6 +75,7 @@ def process_single_graph_data( ignore_aromaticity, return_all, balance_its, + get_random_results, ) return graph_data @@ -105,6 +108,7 @@ def process_multiple_hydrogens( ignore_aromaticity, return_all, balance_its, + get_random_results=False, ): """ Handles cases with hydrogen count changes between 2 and 4, inclusive. @@ -142,6 +146,7 @@ def process_multiple_hydrogens( ignore_aromaticity, return_all, balance_its, + get_random_results, ) return graph_data @@ -154,6 +159,7 @@ def process_high_hcount_change( ignore_aromaticity, return_all, balance_its: bool = True, + get_random_results=False, ): """ Handles cases with hydrogen count changes of 5 or more. @@ -179,7 +185,11 @@ def process_high_hcount_change( for i in its_list ] - its_list, rc_list = get_priority(its_list, reaction_centers) + priority_indices = get_priority(reaction_centers) + rc_list = [reaction_centers[i] for i in priority_indices] + its_list = [its_list[i] for i in priority_indices] + combinations_solution = [combinations_solution[i] for i in priority_indices] + _, equivariant = ITSExtraction.check_equivariant_graph(rc_list) pairwise_combinations = len(its_list) - 1 if equivariant == pairwise_combinations: @@ -187,12 +197,18 @@ def process_high_hcount_change( graph_data, *combinations_solution[0], its_list[0] ) else: - if return_all: + if get_random_results is True: graph_data = ITSHAdjuster.update_graph_data( - graph_data, react_graph, prod_graph, its + graph_data, *combinations_solution[0], its_list[0] ) + else: - graph_data["ITSGraph"], graph_data["GraphRules"] = None, None + if return_all: + graph_data = ITSHAdjuster.update_graph_data( + graph_data, react_graph, prod_graph, its + ) + else: + graph_data["ITSGraph"], graph_data["GraphRules"] = None, None return graph_data @staticmethod @@ -204,6 +220,7 @@ def process_graph_data_parallel( return_all: bool = False, ignore_aromaticity: bool = False, balance_its: bool = True, + get_random_results: bool = False, ) -> List[Dict]: """ Processes a list of dictionaries containing graph information in parallel. @@ -220,7 +237,12 @@ def process_graph_data_parallel( """ processed_data = Parallel(n_jobs=n_jobs, verbose=verbose)( delayed(ITSHAdjuster.process_single_graph_data)( - graph_data, column, return_all, ignore_aromaticity, balance_its + graph_data, + column, + return_all, + ignore_aromaticity, + balance_its, + get_random_results, ) for graph_data in graph_data_list ) @@ -301,7 +323,7 @@ def add_hydrogen_nodes_multiple( """ react_graph_copy = deepcopy(react_graph) prod_graph_copy = deepcopy(prod_graph) - react_explicit_h, _ = check_explicit_hydrogen(react_graph_copy) + react_explicit_h, hydrogen_nodes = check_explicit_hydrogen(react_graph_copy) prod_explicit_h, _ = check_explicit_hydrogen(prod_graph_copy) hydrogen_nodes_form, hydrogen_nodes_break = [], [] @@ -328,22 +350,13 @@ def add_hydrogen_nodes_multiple( max(react_graph_copy.nodes, default=0), max(prod_graph_copy.nodes, default=0), ) - permutations = list( - itertools.permutations( - range( - max_index + 1 - react_explicit_h, - max_index + 1 + len(hydrogen_nodes_form) - react_explicit_h, - ) - ) + range_implicit_h = range( + max_index + 1, + max_index + 1 + len(hydrogen_nodes_form) - react_explicit_h, ) - permutations_seed = list( - itertools.permutations( - range( - max_index + 1 - prod_explicit_h, - max_index + 1 + len(hydrogen_nodes_break) - prod_explicit_h, - ) - ) - )[0] + combined_indices = list(range_implicit_h) + hydrogen_nodes + permutations = list(itertools.permutations(combined_indices)) + permutations_seed = permutations[0] updated_graphs = [] for permutation in permutations: diff --git a/syntemp/SynRule/rule_writing.py b/syntemp/SynRule/rule_writing.py index 8fcc421..8cdd63e 100644 --- a/syntemp/SynRule/rule_writing.py +++ b/syntemp/SynRule/rule_writing.py @@ -18,12 +18,12 @@ def charge_to_string(charge): """ if charge > 0: return ( - "+" * charge - ) # Repeat the '+' symbol 'charge' times for positive charges + "+" if charge == 1 else f"{charge}+" + ) # '+' for +1, '2+', '3+', etc., for higher values elif charge < 0: - return "-" * abs( - charge - ) # Repeat the '-' symbol 'abs(charge)' times for negative charges + return ( + "-" if charge == -1 else f"{-charge}-" + ) # '-' for -1, '2-', '3-', etc., for lower values else: return "" # No charge symbol for neutral atoms @@ -51,7 +51,11 @@ def convert_graph_to_gml( for node in graph.nodes(data=True): if node[0] not in changed_node_ids: element = node[1].get("element", "X") - gml_str += f' node [ id {node[0]} label "{element}" ]\n' + charge = node[1].get("charge", 0) + charge_str = RuleWriting.charge_to_string(charge) + gml_str += ( + f' node [ id {node[0]} label "{element}{charge_str}" ]\n' + ) if section != "context": for edge in graph.edges(data=True): diff --git a/syntemp/SynUtils/chemutils.py b/syntemp/SynUtils/chemutils.py index 682b3cc..ea05efc 100644 --- a/syntemp/SynUtils/chemutils.py +++ b/syntemp/SynUtils/chemutils.py @@ -126,30 +126,27 @@ def remove_hydrogens_and_sanitize(mol: Chem.Mol) -> Chem.Mol: return mol -def remove_atom_mapping(smiles: str) -> str: - """ - Removes atom mapping numbers and simplifies atomic notation in a SMILES string. - - This function processes a SMILES string to: - 1. Remove any atom mapping numbers denoted by ':' - followed by one or more digits. - 2. Simplify the atomic notation by removing square - brackets around atoms that do not need them. - - Parameters: - - smiles (str): The SMILES string to be processed. +def remove_atom_mapping(reaction_smiles): + # Split the reaction SMILES into reactants and products + parts = reaction_smiles.split(">>") + if len(parts) != 2: + raise ValueError("Invalid reaction SMILES format.") - Returns: - - str: The processed SMILES string with atom mappings - removed and simplified atomic notations. - """ - # Remove atom mapping numbers - pattern = re.compile(r":\d+") - smiles = pattern.sub("", smiles) - # Simplify atomic notation by removing unnecessary square brackets - pattern = re.compile(r"\[(?P(B|C|N|O|P|S|F|Cl|Br|I){1,2})(?:H\d?)?\]") - smiles = pattern.sub(r"\g", smiles) - return smiles + # Function to remove atom mappings from a SMILES string + def clean_smiles(smiles): + mol = Chem.MolFromSmiles(smiles) # Convert SMILES to an RDKit mol object + if mol is None: + raise ValueError("Invalid SMILES string.") + for atom in mol.GetAtoms(): + atom.SetAtomMapNum(0) # Remove atom mapping + return Chem.MolToSmiles(mol, True) # Convert mol back to SMILES + + # Apply the cleaning function to both reactants and products + reactants_clean = clean_smiles(parts[0]) + products_clean = clean_smiles(parts[1]) + + # Combine the cleaned reactants and products back into a reaction SMILES + return f"{reactants_clean}>>{products_clean}" def mol_from_smiles(smiles: str) -> Optional[Chem.Mol]: diff --git a/syntemp/SynUtils/graph_utils.py b/syntemp/SynUtils/graph_utils.py index b3dce67..bd44951 100644 --- a/syntemp/SynUtils/graph_utils.py +++ b/syntemp/SynUtils/graph_utils.py @@ -1,6 +1,6 @@ import networkx as nx import copy -from typing import List, Dict, Any, Tuple +from typing import List, Dict, Any def is_acyclic_graph(G: nx.Graph) -> bool: @@ -313,63 +313,50 @@ def check_graph_connectivity(graph): return "Disconnected." -def get_priority( - its_list: List[Any], - reaction_centers: List[Any], - priority_ring: List[int] = [4, 5, 6], # Standard priority rings - priority_pair: List[int] = [3, 5], # Special priority requiring both rings - not_priority_ring: List[int] = [3], # Non-priority ring that disqualifies alone -) -> Tuple[List[Any], List[Any]]: +def get_priority(reaction_centers: List[Any]) -> List[int]: """ - Filters reaction centers based on their connectivity and specific ring sizes, - including those with both rings in the priority pair, and excluding those with non- - priority ring sizes, - unless a specific pair condition is met (e.g., 3 must appear with 5). + Evaluate reaction centers for specific graph characteristics, selecting indices based + on the shortest reaction paths and maximum ring sizes, and adjusting for certain + graph types by modifying the ring information. Parameters: - - its_list (List[Any]): List of identifiers for the reaction centers. - - reaction_centers (List[Any]): List of reaction centers to evaluate. - - priority_ring (List[int], optional): List of ring sizes given priority. - Defaults to [4, 6]. - - priority_pair (List[int], optional): List of two ring sizes that must both appear - together to qualify. Defaults to [3, 5]. - - not_priority_ring (List[int], optional): List of ring sizes that disqualify a center - unless paired appropriately. Defaults to [3]. + - reaction_centers: List[Any], a list of reaction centers where each center should be + capable of being analyzed for graph type and ring sizes. Returns: - - Tuple[List[Any], List[Any]]: Tuple containing two lists: - - The first list contains the identifiers from its_list that meet all criteria. - - The second list contains the corresponding reaction centers that meet the - criteria. + - List[int]: A list of indices from the original list of reaction centers that meet + the criteria of having the shortest reaction steps and/or the largest ring sizes. + Returns indices with minimum reaction steps if no indices meet both criteria. """ - priority_set = set(priority_ring) - not_priority_set = set(not_priority_ring) - priority_pair_set = set(priority_pair) - - # Filter to include only connected reaction centers - connected_centers = [] - connected_its_list = [] - for index, center in enumerate(reaction_centers): - if check_graph_connectivity(center) == "Connected": - connected_centers.append(center) - connected_its_list.append(its_list[index]) - - cyclic = [get_cycle_member_rings(center) for center in connected_centers] - # Filter indices based on priority and non-priority ring sizes - final_indices = [] - for i, rings in enumerate(cyclic): - ring_set = set(rings) - # Check for priority conditions and special conditions for non-priority rings - if ( - ring_set.intersection(priority_set) or (priority_pair_set <= ring_set) - ) and not ( - ring_set.intersection(not_priority_set) - and not (5 in ring_set and 3 in ring_set) - ): - final_indices.append(i) + # Extract topology types and ring sizes from reaction centers + topo_type = [check_graph_type(center) for center in reaction_centers] + cyclic = [get_cycle_member_rings(center) for center in reaction_centers] + + # Adjust ring information based on the graph type + for index, graph_type in enumerate(topo_type): + if graph_type in ["Acyclic", "Complex Cyclic"]: + cyclic[index] = [0] + cyclic[index] + + # Determine minimum reaction steps + reaction_steps = [len(rings) for rings in cyclic] + min_reaction_step = min(reaction_steps) + + # Filter indices with the minimum reaction steps + indices_shortest = [ + i for i, steps in enumerate(reaction_steps) if steps == min_reaction_step + ] + + # Filter indices with the maximum ring size + max_size = max( + max(rings) for rings in cyclic if rings + ) # Safeguard against empty sublists + prior_indices = [i for i, rings in enumerate(cyclic) if max(rings) == max_size] + + # Combine criteria for final indices + final_indices = [index for index in prior_indices if index in indices_shortest] - # Retrieve final lists based on filtered indices - final_its_list = [connected_its_list[i] for i in final_indices] - final_centers = [connected_centers[i] for i in final_indices] + # Fallback to shortest indices if no indices meet both criteria + if not final_indices: + return indices_shortest - return final_its_list, final_centers + return final_indices diff --git a/syntemp/__main__.py b/syntemp/__main__.py index e078d32..e480032 100644 --- a/syntemp/__main__.py +++ b/syntemp/__main__.py @@ -16,7 +16,7 @@ def parse_arguments(): parser.add_argument( "--mapper_types", nargs="+", - default=["rxn_mapper", "graphormer", "local_mapper"], + default=["local_mapper", "rxn_mapper", "graphormer"], help="Types of atom map techniques used", ) parser.add_argument("--id", type=str, default="R-id", help="ID column") @@ -24,11 +24,11 @@ def parse_arguments(): "--rsmi", type=str, default="reactions", help="Reaction SMILES column" ) parser.add_argument( - "--n_jobs", type=int, default=4, help="Number of jobs to run in parallel" + "--n_jobs", type=int, default=8, help="Number of jobs to run in parallel" ) parser.add_argument("--verbose", type=int, default=2, help="Verbosity level") parser.add_argument( - "--batch_size", type=int, default=200, help="Batch size for processing" + "--batch_size", type=int, default=1000, help="Batch size for processing" ) parser.add_argument("--safe_mode", action="store_true", help="Enable safe mode") parser.add_argument( @@ -51,6 +51,11 @@ def parse_arguments(): parser.add_argument( "--log_level", type=str, default="INFO", help="File to log the process" ) + parser.add_argument( + "--get_random_hydrogen", + action="store_true", + help="Get random full ITS hydrogen", + ) return parser.parse_args() @@ -91,6 +96,7 @@ def main(): rerun_aam=args.rerun_aam, log_file=args.log_file, log_level=args.log_level, + get_random_hydrogen=args.get_random_hydrogen, ) auto.temp_extract(data, lib_path=args.lib_path) logging.info("Extraction successful.") diff --git a/syntemp/auto_template.py b/syntemp/auto_template.py index 02b1f7c..6a9ab4a 100644 --- a/syntemp/auto_template.py +++ b/syntemp/auto_template.py @@ -62,6 +62,7 @@ def __init__( log_file: str = None, log_level: str = "INFO", clean_data: bool = True, + get_random_hydrogen: bool = False, ): """ Initializes the AutoTemp class with specified settings for processing chemical @@ -120,6 +121,7 @@ def __init__( self.reindex = reindex self.rerun_aam = rerun_aam self.clean_data = clean_data + self.get_random_hydrogen = get_random_hydrogen log_level = getattr(logging, log_level.upper(), None) if not isinstance(log_level, int): @@ -181,6 +183,7 @@ def temp_extract( self.fix_hydrogen, self.refinement_its, self.save_dir, + get_random_results=self.get_random_hydrogen, ) # Step 4: Extract rules from the correct ITS graphs diff --git a/syntemp/pipeline.py b/syntemp/pipeline.py index b0eb0a3..4387c1c 100644 --- a/syntemp/pipeline.py +++ b/syntemp/pipeline.py @@ -162,6 +162,7 @@ def extract_its( save_dir: Optional[str] = None, data_name: str = "", symbol: str = ">>", + get_random_results: bool = False, ) -> List[dict]: """ Executes the extraction of ITS graphs from reaction data in batches, @@ -216,8 +217,13 @@ def extract_its( if i == 1 or (i % 10 == 0 and i >= 10): logging.info(f"Fixing hydrogen for batch {i + 1}/{num_batches}.") batch_processed = ITSHAdjuster.process_graph_data_parallel( - batch_correct, "ITSGraph", n_jobs=n_jobs, verbose=verbose + batch_correct, + "ITSGraph", + n_jobs=n_jobs, + verbose=verbose, + get_random_results=get_random_results, ) + uncertain_hydrogen = [ value for value in batch_processed if value["ITSGraph"] is None ] @@ -244,12 +250,16 @@ def extract_its( its_correct = collect_data(num_batches, temp_dir, "batch_correct_{}.pkl") logging.info("Processing unequivalent ITS correct") its_incorrect = collect_data(num_batches, temp_dir, "batch_incorrect_{}.pkl") - all_uncertain_hydrogen = [] - if fix_hydrogen: - logging.info("Processing ambiguous hydrogen-ITS") - all_uncertain_hydrogen = collect_data( - num_batches, temp_dir, "uncertain_hydrogen_{}.pkl" - ) + try: + all_uncertain_hydrogen = [] + if fix_hydrogen: + logging.info("Processing ambiguous hydrogen-ITS") + all_uncertain_hydrogen = collect_data( + num_batches, temp_dir, "uncertain_hydrogen_{}.pkl" + ) + except Exception as e: + logging.error(f"{e}") + all_uncertain_hydrogen = [] # logging.info(f"Number of correct mappers before refinement: {len(its_correct)}") if refinement_its: @@ -267,7 +277,10 @@ def extract_its( logging.info(f"Number of correct mappers: {len(its_correct)}") logging.info(f"Number of incorrect mappers: {len(its_incorrect)}") - logging.info(f"Number of uncertain hydrogen:{len(all_uncertain_hydrogen)}") + logging.info( + "Number of uncertain hydrogen:" + + f"{len(data)-len(its_correct)-len(its_incorrect)}" + ) if save_dir: logging.info("Combining and saving data") diff --git a/syntemp/run_compose.py b/syntemp/run_compose.py index 15b944a..9931477 100644 --- a/syntemp/run_compose.py +++ b/syntemp/run_compose.py @@ -91,4 +91,4 @@ def main(args): # python run_compose.py -s Data/Temp/RuleComp/Single/R0 -# -c Data/Temp/RuleComp/Compose_expand -d Data/Temp/RuleComp/Double/R0/ +# -c Data/Temp/RuleComp/Compose -d Data/Temp/RuleComp/Double/R0/