From 9e154bf21e35c87f12317b622bdfcc707ff940ea Mon Sep 17 00:00:00 2001
From: J08nY <johny@neuromancer.sk>
Date: Wed, 26 Jul 2023 22:33:46 +0200
Subject: [PATCH] Add rudimentary profiling.

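Introduce a `staged` decorator and a `log_stage` context manager in
`src/sec_certs/utils/profiling.py`. Both log the start and end of a
processing stage, its wall-clock duration, and (at debug level) process
memory usage, and can optionally run garbage collection on exit. The
decorator replaces the ad-hoc `logger.info` calls at the start of the
stage methods of the CC, FIPS, and base dataset classes. Along the way,
the `_extract_targets_*` helpers are renamed to `_extract_target_*` and
a stray debug `print` is removed.

For illustration, a stage method is now wrapped like this (taken from
this patch):

    @staged(logger, "Downloading PDFs of CC certification reports.")
    def _download_reports(self, fresh: bool = True) -> None:
        ...
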
Fixes #288.
---
 src/sec_certs/dataset/cc.py      | 66 ++++++++++++++++----------------
 src/sec_certs/dataset/dataset.py |  8 +++-
 src/sec_certs/dataset/fips.py    | 23 +++++------
 src/sec_certs/utils/profiling.py | 54 +++++++++++++++++++++++++++++++++++
 4 files changed, 102 insertions(+), 49 deletions(-)
 create mode 100644 src/sec_certs/utils/profiling.py

diff --git a/src/sec_certs/dataset/cc.py b/src/sec_certs/dataset/cc.py
index 0db196c9..16c95e84 100644
--- a/src/sec_certs/dataset/cc.py
+++ b/src/sec_certs/dataset/cc.py
@@ -33,6 +33,7 @@
 from sec_certs.serialization.json import ComplexSerializableType, serialize
 from sec_certs.utils import helpers
 from sec_certs.utils import parallel_processing as cert_processing
+from sec_certs.utils.profiling import staged
 
 
 @dataclass
@@ -270,6 +271,7 @@ def _download_csv_html_resources(self, get_active: bool = True, get_archived: bo
         helpers.download_parallel(csv_urls, csv_paths)
 
     @serialize
+    @staged(logger, "Downloading and processing CSV and HTML files of certificates.")
     def get_certs_from_web(
         self, to_download: bool = True, keep_metadata: bool = True, get_active: bool = True, get_archived: bool = True
     ) -> None:
@@ -520,12 +522,11 @@ def _download_all_artifacts_body(self, fresh: bool = True) -> None:
         self._download_reports(fresh)
         self._download_targets(fresh)
 
+    @staged(logger, "Downloading PDFs of CC certification reports.")
     def _download_reports(self, fresh: bool = True) -> None:
         self.reports_pdf_dir.mkdir(parents=True, exist_ok=True)
         certs_to_process = [x for x in self if x.state.report_is_ok_to_download(fresh) and x.report_link]
 
-        if fresh:
-            logger.info("Downloading PDFs of CC certification reports.")
         if not fresh and certs_to_process:
             logger.info(
                 f"Downloading {len(certs_to_process)} PDFs of CC certification reports for which previous download failed."
@@ -537,12 +538,11 @@ def _download_reports(self, fresh: bool = True) -> None:
             progress_bar_desc="Downloading PDFs of CC certification reports",
         )
 
+    @staged(logger, "Downloading PDFs of CC security targets.")
     def _download_targets(self, fresh: bool = True) -> None:
         self.targets_pdf_dir.mkdir(parents=True, exist_ok=True)
         certs_to_process = [x for x in self if x.state.report_is_ok_to_download(fresh)]
 
-        if fresh:
-            logger.info("Downloading PDFs of CC security targets.")
         if not fresh and certs_to_process:
             logger.info(
                 f"Downloading {len(certs_to_process)} PDFs of CC security targets for which previous download failed.."
@@ -554,12 +554,11 @@ def _download_targets(self, fresh: bool = True) -> None:
             progress_bar_desc="Downloading PDFs of CC security targets",
         )
 
+    @staged(logger, "Converting PDFs of certification reports to txt.")
     def _convert_reports_to_txt(self, fresh: bool = True) -> None:
         self.reports_txt_dir.mkdir(parents=True, exist_ok=True)
         certs_to_process = [x for x in self if x.state.report_is_ok_to_convert(fresh)]
 
-        if fresh:
-            logger.info("Converting PDFs of certification reports to txt.")
         if not fresh and certs_to_process:
             logger.info(
                 f"Converting {len(certs_to_process)} PDFs of certification reports to txt for which previous conversion failed."
@@ -571,6 +570,7 @@ def _convert_reports_to_txt(self, fresh: bool = True) -> None:
             progress_bar_desc="Converting PDFs of certification reports to txt",
         )
 
+    @staged(logger, "Converting PDFs of security targets to txt.")
     def _convert_targets_to_txt(self, fresh: bool = True) -> None:
         self.targets_txt_dir.mkdir(parents=True, exist_ok=True)
         certs_to_process = [x for x in self if x.state.st_is_ok_to_convert(fresh)]
@@ -592,8 +592,8 @@ def _convert_all_pdfs_body(self, fresh: bool = True) -> None:
         self._convert_reports_to_txt(fresh)
         self._convert_targets_to_txt(fresh)
 
+    @staged(logger, "Extracting report metadata")
     def _extract_report_metadata(self) -> None:
-        logger.info("Extracting report metadata")
         certs_to_process = [x for x in self if x.state.report_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             CCCertificate.extract_report_pdf_metadata,
@@ -603,8 +603,8 @@ def _extract_report_metadata(self) -> None:
         )
         self.update_with_certs(processed_certs)
 
-    def _extract_targets_metadata(self) -> None:
-        logger.info("Extracting target metadata")
+    @staged(logger, "Extracting target metadata")
+    def _extract_target_metadata(self) -> None:
         certs_to_process = [x for x in self if x.state.st_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             CCCertificate.extract_st_pdf_metadata,
@@ -616,10 +616,10 @@ def _extract_targets_metadata(self) -> None:
 
     def _extract_pdf_metadata(self) -> None:
         self._extract_report_metadata()
-        self._extract_targets_metadata()
+        self._extract_target_metadata()
 
+    @staged(logger, "Extracting report frontpages")
     def _extract_report_frontpage(self) -> None:
-        logger.info("Extracting report frontpages")
         certs_to_process = [x for x in self if x.state.report_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             CCCertificate.extract_report_pdf_frontpage,
@@ -629,8 +629,8 @@ def _extract_report_frontpage(self) -> None:
         )
         self.update_with_certs(processed_certs)
 
-    def _extract_targets_frontpage(self) -> None:
-        logger.info("Extracting target frontpages")
+    @staged(logger, "Extracting target frontpages")
+    def _extract_target_frontpage(self) -> None:
         certs_to_process = [x for x in self if x.state.st_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             CCCertificate.extract_st_pdf_frontpage,
@@ -642,10 +642,10 @@ def _extract_targets_frontpage(self) -> None:
 
     def _extract_pdf_frontpage(self) -> None:
         self._extract_report_frontpage()
-        self._extract_targets_frontpage()
+        self._extract_target_frontpage()
 
+    @staged(logger, "Extracting report keywords")
     def _extract_report_keywords(self) -> None:
-        logger.info("Extracting report keywords")
         certs_to_process = [x for x in self if x.state.report_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             CCCertificate.extract_report_pdf_keywords,
@@ -655,8 +655,8 @@ def _extract_report_keywords(self) -> None:
         )
         self.update_with_certs(processed_certs)
 
-    def _extract_targets_keywords(self) -> None:
-        logger.info("Extracting target keywords")
+    @staged(logger, "Extracting target keywords")
+    def _extract_target_keywords(self) -> None:
         certs_to_process = [x for x in self if x.state.st_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             CCCertificate.extract_st_pdf_keywords,
@@ -668,7 +668,7 @@ def _extract_targets_keywords(self) -> None:
 
     def _extract_pdf_keywords(self) -> None:
         self._extract_report_keywords()
-        self._extract_targets_keywords()
+        self._extract_target_keywords()
 
     def extract_data(self) -> None:
         logger.info("Extracting various data from certification artifacts")
@@ -676,19 +676,19 @@ def extract_data(self) -> None:
         self._extract_pdf_frontpage()
         self._extract_pdf_keywords()
 
+    @staged(logger, "Computing heuristics: Deriving information about laboratories involved in certification.")
     def _compute_cert_labs(self) -> None:
-        logger.info("Computing heuristics: Deriving information about laboratories involved in certification.")
         certs_to_process = [x for x in self if x.state.report_is_ok_to_analyze()]
         for cert in certs_to_process:
             cert.compute_heuristics_cert_lab()
 
+    @staged(logger, "Computing heuristics: Deriving information about certificate ids from artifacts.")
     def _compute_normalized_cert_ids(self) -> None:
-        logger.info("Computing heuristics: Deriving information about certificate ids from artifacts.")
         for cert in self:
             cert.compute_heuristics_cert_id()
 
+    @staged(logger, "Computing heuristics: Transitive vulnerabilities in referenc(ed/ing) certificates.")
     def _compute_transitive_vulnerabilities(self):
-        logger.info("omputing heuristics: computing transitive vulnerabilities in referenc(ed/ing) certificates.")
         transitive_cve_finder = TransitiveVulnerabilityFinder(lambda cert: cert.heuristics.cert_id)
         transitive_cve_finder.fit(self.certs, lambda cert: cert.heuristics.report_references)
 
@@ -698,9 +698,9 @@ def _compute_transitive_vulnerabilities(self):
             self.certs[dgst].heuristics.direct_transitive_cves = transitive_cve.direct_transitive_cves
             self.certs[dgst].heuristics.indirect_transitive_cves = transitive_cve.indirect_transitive_cves
 
+    @staged(logger, "Computing heuristics: Matching scheme data.")
     def _compute_scheme_data(self):
         if self.auxiliary_datasets.scheme_dset:
-            print("here")
             for scheme in self.auxiliary_datasets.scheme_dset:
                 if certified := scheme.lists.get(EntryType.Certified):
                     certs = [cert for cert in self if cert.status == "active"]
@@ -713,6 +713,12 @@ def _compute_scheme_data(self):
                     for dgst, match in matches.items():
                         self[dgst].heuristics.scheme_data = match
 
+    @staged(logger, "Computing heuristics: SARs")
+    def _compute_sars(self) -> None:
+        transformer = SARTransformer().fit(self.certs.values())
+        for cert in self:
+            cert.heuristics.extracted_sars = transformer.transform_single_cert(cert)
+
     def _compute_heuristics(self) -> None:
         self._compute_normalized_cert_ids()
         super()._compute_heuristics()
@@ -720,12 +726,7 @@ def _compute_heuristics(self) -> None:
         self._compute_cert_labs()
         self._compute_sars()
 
-    def _compute_sars(self) -> None:
-        logger.info("Computing heuristics: Computing SARs")
-        transformer = SARTransformer().fit(self.certs.values())
-        for cert in self:
-            cert.heuristics.extracted_sars = transformer.transform_single_cert(cert)
-
+    @staged(logger, "Computing heuristics: references between certificates.")
     def _compute_references(self) -> None:
         def ref_lookup(kw_attr):
             def func(cert):
@@ -744,7 +745,6 @@ def func(cert):
 
             return func
 
-        logger.info("omputing heuristics: references between certificates.")
         for ref_source in ("report", "st"):
             kw_source = f"{ref_source}_keywords"
             dep_attr = f"{ref_source}_references"
@@ -768,6 +768,7 @@ def process_auxiliary_datasets(self, download_fresh: bool = False) -> None:
             to_download=download_fresh, only_schemes={cert.scheme for cert in self}
         )
 
+    @staged(logger, "Processing protection profiles.")
     def process_protection_profiles(
         self, to_download: bool = True, keep_metadata: bool = True
     ) -> ProtectionProfileDataset:
@@ -779,7 +780,6 @@ def process_protection_profiles(
         :param bool keep_metadata: If json related to the PP dataset should be kept on drive, defaults to True
         :raises RuntimeError: When building of PPDataset fails
         """
-        logger.info("Processing protection profiles.")
 
         self.auxiliary_datasets_dir.mkdir(parents=True, exist_ok=True)
 
@@ -798,13 +798,12 @@ def process_protection_profiles(
 
         return pp_dataset
 
+    @staged(logger, "Processing maintenace updates.")
     def process_maintenance_updates(self, to_download: bool = True) -> CCDatasetMaintenanceUpdates:
         """
         Downloads or loads from json a dataset of maintenance updates. Runs analysis on that dataset if it's not completed.
         :return CCDatasetMaintenanceUpdates: the resulting dataset of maintenance updates
         """
-
-        logger.info("Processing maintenace updates")
         self.mu_dataset_dir.mkdir(parents=True, exist_ok=True)
 
         if to_download or not self.mu_dataset_path.exists():
@@ -827,12 +826,11 @@ def process_maintenance_updates(self, to_download: bool = True) -> CCDatasetMain
 
         return update_dset
 
+    @staged(logger, "Processing CC scheme dataset.")
     def process_schemes(self, to_download: bool = True, only_schemes: set[str] | None = None) -> CCSchemeDataset:
         """
         Downloads or loads from json a dataset of CC scheme data.
         """
-        logger.info("Processing CC schemes")
-
         self.auxiliary_datasets_dir.mkdir(parents=True, exist_ok=True)
 
         if to_download or not self.scheme_dataset_path.exists():
diff --git a/src/sec_certs/dataset/dataset.py b/src/sec_certs/dataset/dataset.py
index e836f492..6466630f 100644
--- a/src/sec_certs/dataset/dataset.py
+++ b/src/sec_certs/dataset/dataset.py
@@ -25,6 +25,7 @@
 from sec_certs.serialization.json import ComplexSerializableType, get_class_fullname, serialize
 from sec_certs.utils import helpers
 from sec_certs.utils.nvd_dataset_builder import CpeMatchNvdDatasetBuilder, CpeNvdDatasetBuilder, CveNvdDatasetBuilder
+from sec_certs.utils.profiling import staged
 from sec_certs.utils.tqdm import tqdm
 
 logger = logging.getLogger(__name__)
@@ -348,6 +349,7 @@ def _compute_references(self) -> None:
     def _compute_transitive_vulnerabilities(self) -> None:
         raise NotImplementedError("Not meant to be implemented by the base class.")
 
+    @staged(logger, "Processing CPEDataset.")
     def _prepare_cpe_dataset(self, download_fresh: bool = False) -> CPEDataset:
         if not self.auxiliary_datasets_dir.exists():
             self.auxiliary_datasets_dir.mkdir(parents=True)
@@ -371,6 +373,7 @@ def _prepare_cpe_dataset(self, download_fresh: bool = False) -> CPEDataset:
 
         return cpe_dataset
 
+    @staged(logger, "Processing CVEDataset.")
     def _prepare_cve_dataset(self, download_fresh: bool = False) -> CVEDataset:
         if not self.auxiliary_datasets_dir.exists():
             logger.info("Loading CVEDataset from json.")
@@ -395,6 +398,7 @@ def _prepare_cve_dataset(self, download_fresh: bool = False) -> CVEDataset:
 
         return cve_dataset
 
+    @staged(logger, "Processing CPE match dict.")
     def _prepare_cpe_match_dict(self, download_fresh: bool = False) -> dict:
         if self.cpe_match_json_path.exists():
             logger.info("Preparing CPE Match feed from json.")
@@ -433,6 +437,7 @@ def _prepare_cpe_match_dict(self, download_fresh: bool = False) -> dict:
         return cpe_match_dict
 
     @serialize
+    @staged(logger, "Computing heuristics: Finding CPE matches for certificates")
     def compute_cpe_heuristics(self) -> CPEClassifier:
         """
         Computes matching CPEs for the certificates.
@@ -465,7 +470,6 @@ def filter_condition(cpe: CPE) -> bool:
                 return False
             return True
 
-        logger.info("Computing heuristics: Finding CPE matches for certificates")
         if not self.auxiliary_datasets.cpe_dset:
             self.auxiliary_datasets.cpe_dset = self._prepare_cpe_dataset()
 
@@ -574,11 +578,11 @@ def _get_all_cpes_in_dataset(self) -> set[CPE]:
         return set(itertools.chain.from_iterable(cpe_matches))
 
     @serialize
+    @staged(logger, "Computing heuristics: CVEs in certificates.")
     def compute_related_cves(self) -> None:
         """
         Computes CVEs for the certificates, given their CPE matches.
         """
-        logger.info("Computing heuristics: CVEs in certificates.")
 
         if not self.auxiliary_datasets.cpe_dset:
             self.auxiliary_datasets.cpe_dset = self._prepare_cpe_dataset()
diff --git a/src/sec_certs/dataset/fips.py b/src/sec_certs/dataset/fips.py
index 24536b9b..a3d24b15 100644
--- a/src/sec_certs/dataset/fips.py
+++ b/src/sec_certs/dataset/fips.py
@@ -24,6 +24,7 @@
 from sec_certs.utils import helpers
 from sec_certs.utils import parallel_processing as cert_processing
 from sec_certs.utils.helpers import fips_dgst
+from sec_certs.utils.profiling import staged
 
 logger = logging.getLogger(__name__)
 
@@ -230,13 +231,13 @@ def _set_local_paths(self) -> None:
             cert.set_local_paths(self.policies_pdf_dir, self.policies_txt_dir, self.module_dir)
 
     @serialize
+    @staged(logger, "Downloading and processing certificates.")
     def get_certs_from_web(self, to_download: bool = True, keep_metadata: bool = True) -> None:
         self.web_dir.mkdir(parents=True, exist_ok=True)
 
         if to_download:
             self._download_html_resources()
 
-        logger.info("Adding unprocessed FIPS certificates into FIPSDataset.")
         self.certs = {x.dgst: x for x in self._get_all_certs_from_html_sources()}
         logger.info(f"The dataset now contains {len(self)} certificates.")
 
@@ -251,8 +252,8 @@ def process_auxiliary_datasets(self, download_fresh: bool = False) -> None:
         super().process_auxiliary_datasets(download_fresh)
         self.auxiliary_datasets.algorithm_dset = self._prepare_algorithm_dataset(download_fresh)
 
+    @staged(logger, "Processing FIPSAlgorithm dataset.")
     def _prepare_algorithm_dataset(self, download_fresh_algs: bool = False) -> FIPSAlgorithmDataset:
-        logger.info("Preparing FIPSAlgorithm dataset.")
         if not self.algorithm_dataset_path.exists() or download_fresh_algs:
             alg_dset = FIPSAlgorithmDataset.from_web(self.algorithm_dataset_path)
             alg_dset.to_json()
@@ -261,8 +262,8 @@ def _prepare_algorithm_dataset(self, download_fresh_algs: bool = False) -> FIPSA
 
         return alg_dset
 
+    @staged(logger, "Extracting Algorithms from policy tables")
     def _extract_algorithms_from_policy_tables(self):
-        logger.info("Extracting Algorithms from policy tables")
         certs_to_process = [x for x in self if x.state.policy_is_ok_to_analyze()]
         cert_processing.process_parallel(
             FIPSCertificate.get_algorithms_from_policy_tables,
@@ -271,8 +272,8 @@ def _extract_algorithms_from_policy_tables(self):
             progress_bar_desc="Extracting Algorithms from policy tables",
         )
 
+    @staged(logger, "Extracting security policy metadata from the pdfs")
     def _extract_policy_pdf_metadata(self) -> None:
-        logger.info("Extracting security policy metadata from the pdfs")
         certs_to_process = [x for x in self if x.state.policy_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             FIPSCertificate.extract_policy_pdf_metadata,
@@ -282,8 +283,8 @@ def _extract_policy_pdf_metadata(self) -> None:
         )
         self.update_with_certs(processed_certs)
 
+    @staged(logger, "Computing heuristics: Transitive vulnerabilities in referenc(ed/ing) certificates.")
     def _compute_transitive_vulnerabilities(self) -> None:
-        logger.info("Computing heuristics: Computing transitive vulnerabilities in referenc(ed/ing) certificates.")
         transitive_cve_finder = TransitiveVulnerabilityFinder(lambda cert: str(cert.cert_id))
         transitive_cve_finder.fit(self.certs, lambda cert: cert.heuristics.policy_processed_references)
 
@@ -292,20 +293,16 @@ def _compute_transitive_vulnerabilities(self) -> None:
             self.certs[dgst].heuristics.direct_transitive_cves = transitive_cve.direct_transitive_cves
             self.certs[dgst].heuristics.indirect_transitive_cves = transitive_cve.indirect_transitive_cves
 
-    def _prune_reference_candidates(self) -> None:
-        for cert in self:
-            cert.prune_referenced_cert_ids()
-
+    @staged(logger, "Computing heuristics: references between certificates.")
+    def _compute_references(self, keep_unknowns: bool = False) -> None:
         # Previously, a following procedure was used to prune reference_candidates:
         #   - A set of algorithms was obtained via self.auxiliary_datasets.algorithm_dset.get_algorithms_by_id(reference_candidate)
         #   - If any of these algorithms had the same vendor as the reference_candidate, the candidate was rejected
         #   - The rationale is that if an ID appears in a certificate s.t. an algorithm with the same ID was produced by the same vendor, the reference likely refers to alg.
         #   - Such reference should then be discarded.
         #   - We are uncertain of the effectivity of such measure, disabling it for now.
-
-    def _compute_references(self, keep_unknowns: bool = False) -> None:
-        logger.info("Computing heuristics: Recovering references between certificates")
-        self._prune_reference_candidates()
+        for cert in self:
+            cert.prune_referenced_cert_ids()
 
         policy_reference_finder = ReferenceFinder()
         policy_reference_finder.fit(
diff --git a/src/sec_certs/utils/profiling.py b/src/sec_certs/utils/profiling.py
new file mode 100644
index 00000000..7da67070
--- /dev/null
+++ b/src/sec_certs/utils/profiling.py
@@ -0,0 +1,54 @@
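+"""Rudimentary profiling helpers.
+
+`log_stage` (a context manager) and `staged` (the decorator built on it) log
+the start and end of a processing stage, its duration, and memory usage:
+
+    with log_stage(logger, "Converting PDFs.", collect_garbage=True):
+        ...  # stage body
+"""
+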
+import gc
+from contextlib import contextmanager
+from datetime import datetime
+from functools import wraps
+from logging import Logger
+
+import psutil
+
+
+@contextmanager
+def log_stage(logger: Logger, msg: str, collect_garbage: bool = False):
+    """Contextmanager that logs a message to the logger when it is entered and exited.
+    The message has debug information about memory use. Optionally, it can
+    run garbage collection when exiting.
+    """
+    meminfo = psutil.Process().memory_full_info()
+    logger.info(f">> Starting >> {msg}")
+    logger.debug(str(meminfo))
+    start_time = datetime.now()
+
+    try:
+        yield
+    finally:
+        end_time = datetime.now()
+        duration = end_time - start_time
+        meminfo = psutil.Process().memory_full_info()
+        logger.info(f"<< Finished << {msg} ({duration})")
+        logger.debug(str(meminfo))
+
+        if collect_garbage:
+            gc.collect()
+
+
+def staged(logger: Logger, log_message: str, collect_garbage: bool = False):
+    """Like log_stage but a decorator."""
+
+    def deco(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            with log_stage(logger, log_message, collect_garbage):
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return deco