From 9e154bf21e35c87f12317b622bdfcc707ff940ea Mon Sep 17 00:00:00 2001
From: J08nY <johny@neuromancer.sk>
Date: Wed, 26 Jul 2023 22:33:46 +0200
Subject: [PATCH] Add rudimentary profiling.

Fixes #288.
---
 src/sec_certs/dataset/cc.py      | 66 ++++++++++++++++----------------
 src/sec_certs/dataset/dataset.py |  8 +++-
 src/sec_certs/dataset/fips.py    | 23 +++++------
 src/sec_certs/utils/profiling.py | 45 ++++++++++++++++++++++
 4 files changed, 93 insertions(+), 49 deletions(-)
 create mode 100644 src/sec_certs/utils/profiling.py

diff --git a/src/sec_certs/dataset/cc.py b/src/sec_certs/dataset/cc.py
index 0db196c9..16c95e84 100644
--- a/src/sec_certs/dataset/cc.py
+++ b/src/sec_certs/dataset/cc.py
@@ -33,6 +33,7 @@
 from sec_certs.serialization.json import ComplexSerializableType, serialize
 from sec_certs.utils import helpers
 from sec_certs.utils import parallel_processing as cert_processing
+from sec_certs.utils.profiling import staged
 
 
 @dataclass
@@ -270,6 +271,7 @@ def _download_csv_html_resources(self, get_active: bool = True, get_archived: bo
         helpers.download_parallel(csv_urls, csv_paths)
 
     @serialize
+    @staged(logger, "Downloading and processing CSV and HTML files of certificates.")
     def get_certs_from_web(
         self, to_download: bool = True, keep_metadata: bool = True, get_active: bool = True, get_archived: bool = True
     ) -> None:
@@ -520,12 +522,11 @@ def _download_all_artifacts_body(self, fresh: bool = True) -> None:
         self._download_reports(fresh)
         self._download_targets(fresh)
 
+    @staged(logger, "Downloading PDFs of CC certification reports.")
     def _download_reports(self, fresh: bool = True) -> None:
         self.reports_pdf_dir.mkdir(parents=True, exist_ok=True)
         certs_to_process = [x for x in self if x.state.report_is_ok_to_download(fresh) and x.report_link]
 
-        if fresh:
-            logger.info("Downloading PDFs of CC certification reports.")
         if not fresh and certs_to_process:
             logger.info(
                 f"Downloading {len(certs_to_process)} PDFs of CC certification reports for which previous download failed."
             )
@@ -537,12 +538,11 @@ def _download_reports(self, fresh: bool = True) -> None:
             progress_bar_desc="Downloading PDFs of CC certification reports",
         )
 
+    @staged(logger, "Downloading PDFs of CC security targets.")
     def _download_targets(self, fresh: bool = True) -> None:
         self.targets_pdf_dir.mkdir(parents=True, exist_ok=True)
         certs_to_process = [x for x in self if x.state.report_is_ok_to_download(fresh)]
 
-        if fresh:
-            logger.info("Downloading PDFs of CC security targets.")
         if not fresh and certs_to_process:
             logger.info(
                 f"Downloading {len(certs_to_process)} PDFs of CC security targets for which previous download failed.."
             )
@@ -554,12 +554,11 @@ def _download_targets(self, fresh: bool = True) -> None:
             progress_bar_desc="Downloading PDFs of CC security targets",
         )
 
+    @staged(logger, "Converting PDFs of certification reports to txt.")
     def _convert_reports_to_txt(self, fresh: bool = True) -> None:
         self.reports_txt_dir.mkdir(parents=True, exist_ok=True)
         certs_to_process = [x for x in self if x.state.report_is_ok_to_convert(fresh)]
 
-        if fresh:
-            logger.info("Converting PDFs of certification reports to txt.")
         if not fresh and certs_to_process:
             logger.info(
                 f"Converting {len(certs_to_process)} PDFs of certification reports to txt for which previous conversion failed."
@@ -571,6 +570,7 @@ def _convert_reports_to_txt(self, fresh: bool = True) -> None:
             progress_bar_desc="Converting PDFs of certification reports to txt",
         )
 
+    @staged(logger, "Converting PDFs of security targets to txt.")
     def _convert_targets_to_txt(self, fresh: bool = True) -> None:
         self.targets_txt_dir.mkdir(parents=True, exist_ok=True)
         certs_to_process = [x for x in self if x.state.st_is_ok_to_convert(fresh)]
@@ -592,8 +592,8 @@ def _convert_all_pdfs_body(self, fresh: bool = True) -> None:
         self._convert_reports_to_txt(fresh)
         self._convert_targets_to_txt(fresh)
 
+    @staged(logger, "Extracting report metadata")
     def _extract_report_metadata(self) -> None:
-        logger.info("Extracting report metadata")
         certs_to_process = [x for x in self if x.state.report_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             CCCertificate.extract_report_pdf_metadata,
@@ -603,8 +603,8 @@ def _extract_report_metadata(self) -> None:
         )
         self.update_with_certs(processed_certs)
 
-    def _extract_targets_metadata(self) -> None:
-        logger.info("Extracting target metadata")
+    @staged(logger, "Extracting target metadata")
+    def _extract_target_metadata(self) -> None:
         certs_to_process = [x for x in self if x.state.st_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             CCCertificate.extract_st_pdf_metadata,
@@ -616,10 +616,10 @@ def _extract_targets_metadata(self) -> None:
 
     def _extract_pdf_metadata(self) -> None:
         self._extract_report_metadata()
-        self._extract_targets_metadata()
+        self._extract_target_metadata()
 
+    @staged(logger, "Extracting report frontpages")
     def _extract_report_frontpage(self) -> None:
-        logger.info("Extracting report frontpages")
         certs_to_process = [x for x in self if x.state.report_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             CCCertificate.extract_report_pdf_frontpage,
@@ -629,8 +629,8 @@ def _extract_report_frontpage(self) -> None:
         )
         self.update_with_certs(processed_certs)
 
-    def _extract_targets_frontpage(self) -> None:
-        logger.info("Extracting target frontpages")
+    @staged(logger, "Extracting target frontpages")
+    def _extract_target_frontpage(self) -> None:
         certs_to_process = [x for x in self if x.state.st_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             CCCertificate.extract_st_pdf_frontpage,
@@ -642,10 +642,10 @@ def _extract_targets_frontpage(self) -> None:
 
     def _extract_pdf_frontpage(self) -> None:
         self._extract_report_frontpage()
-        self._extract_targets_frontpage()
+        self._extract_target_frontpage()
 
+    @staged(logger, "Extracting report keywords")
     def _extract_report_keywords(self) -> None:
-        logger.info("Extracting report keywords")
         certs_to_process = [x for x in self if x.state.report_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             CCCertificate.extract_report_pdf_keywords,
@@ -655,8 +655,8 @@ def _extract_report_keywords(self) -> None:
         )
         self.update_with_certs(processed_certs)
 
-    def _extract_targets_keywords(self) -> None:
-        logger.info("Extracting target keywords")
+    @staged(logger, "Extracting target keywords")
+    def _extract_target_keywords(self) -> None:
         certs_to_process = [x for x in self if x.state.st_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             CCCertificate.extract_st_pdf_keywords,
@@ -668,7 +668,7 @@ def _extract_targets_keywords(self) -> None:
 
     def _extract_pdf_keywords(self) -> None:
         self._extract_report_keywords()
-        self._extract_targets_keywords()
+        self._extract_target_keywords()
 
     def extract_data(self) -> None:
         logger.info("Extracting various data from certification artifacts")
@@ -676,19 +676,19 @@ def extract_data(self) -> None:
         self._extract_pdf_frontpage()
         self._extract_pdf_keywords()
 
+    @staged(logger, "Computing heuristics: Deriving information about laboratories involved in certification.")
     def _compute_cert_labs(self) -> None:
-        logger.info("Computing heuristics: Deriving information about laboratories involved in certification.")
         certs_to_process = [x for x in self if x.state.report_is_ok_to_analyze()]
         for cert in certs_to_process:
             cert.compute_heuristics_cert_lab()
 
+    @staged(logger, "Computing heuristics: Deriving information about certificate ids from artifacts.")
     def _compute_normalized_cert_ids(self) -> None:
-        logger.info("Computing heuristics: Deriving information about certificate ids from artifacts.")
         for cert in self:
             cert.compute_heuristics_cert_id()
 
+    @staged(logger, "Computing heuristics: Transitive vulnerabilities in referenc(ed/ing) certificates.")
     def _compute_transitive_vulnerabilities(self):
-        logger.info("omputing heuristics: computing transitive vulnerabilities in referenc(ed/ing) certificates.")
         transitive_cve_finder = TransitiveVulnerabilityFinder(lambda cert: cert.heuristics.cert_id)
         transitive_cve_finder.fit(self.certs, lambda cert: cert.heuristics.report_references)
 
@@ -698,9 +698,9 @@ def _compute_transitive_vulnerabilities(self):
             self.certs[dgst].heuristics.direct_transitive_cves = transitive_cve.direct_transitive_cves
             self.certs[dgst].heuristics.indirect_transitive_cves = transitive_cve.indirect_transitive_cves
 
+    @staged(logger, "Computing heuristics: Matching scheme data.")
     def _compute_scheme_data(self):
         if self.auxiliary_datasets.scheme_dset:
-            print("here")
             for scheme in self.auxiliary_datasets.scheme_dset:
                 if certified := scheme.lists.get(EntryType.Certified):
                     certs = [cert for cert in self if cert.status == "active"]
@@ -713,6 +713,12 @@ def _compute_scheme_data(self):
                 for dgst, match in matches.items():
                     self[dgst].heuristics.scheme_data = match
 
+    @staged(logger, "Computing heuristics: SARs")
+    def _compute_sars(self) -> None:
+        transformer = SARTransformer().fit(self.certs.values())
+        for cert in self:
+            cert.heuristics.extracted_sars = transformer.transform_single_cert(cert)
+
     def _compute_heuristics(self) -> None:
         self._compute_normalized_cert_ids()
         super()._compute_heuristics()
@@ -720,12 +726,7 @@ def _compute_heuristics(self) -> None:
         self._compute_cert_labs()
         self._compute_sars()
 
-    def _compute_sars(self) -> None:
-        logger.info("Computing heuristics: Computing SARs")
-        transformer = SARTransformer().fit(self.certs.values())
-        for cert in self:
-            cert.heuristics.extracted_sars = transformer.transform_single_cert(cert)
-
+    @staged(logger, "Computing heuristics: references between certificates.")
     def _compute_references(self) -> None:
         def ref_lookup(kw_attr):
             def func(cert):
@@ -744,7 +745,6 @@ def func(cert):
 
             return func
 
-        logger.info("omputing heuristics: references between certificates.")
         for ref_source in ("report", "st"):
             kw_source = f"{ref_source}_keywords"
             dep_attr = f"{ref_source}_references"
@@ -768,6 +768,7 @@ def process_auxiliary_datasets(self, download_fresh: bool = False) -> None:
             to_download=download_fresh, only_schemes={cert.scheme for cert in self}
         )
 
+    @staged(logger, "Processing protection profiles.")
     def process_protection_profiles(
         self, to_download: bool = True, keep_metadata: bool = True
     ) -> ProtectionProfileDataset:
@@ -779,7 +780,6 @@ def process_protection_profiles(
         :param bool keep_metadata: If json related to the PP dataset should be kept on drive, defaults to True
         :raises RuntimeError: When building of PPDataset fails
         """
-        logger.info("Processing protection profiles.")
 
         self.auxiliary_datasets_dir.mkdir(parents=True, exist_ok=True)
@@ -798,13 +798,12 @@ def process_protection_profiles(
 
         return pp_dataset
 
+    @staged(logger, "Processing maintenance updates.")
     def process_maintenance_updates(self, to_download: bool = True) -> CCDatasetMaintenanceUpdates:
         """
         Downloads or loads from json a dataset of maintenance updates.
         Runs analysis on that dataset if it's not completed.
         :return CCDatasetMaintenanceUpdates: the resulting dataset of maintenance updates
         """
-
-        logger.info("Processing maintenace updates")
         self.mu_dataset_dir.mkdir(parents=True, exist_ok=True)
 
         if to_download or not self.mu_dataset_path.exists():
@@ -827,12 +826,11 @@ def process_maintenance_updates(self, to_download: bool = True) -> CCDatasetMain
 
         return update_dset
 
+    @staged(logger, "Processing CC scheme dataset.")
     def process_schemes(self, to_download: bool = True, only_schemes: set[str] | None = None) -> CCSchemeDataset:
         """
         Downloads or loads from json a dataset of CC scheme data.
         """
-        logger.info("Processing CC schemes")
-
         self.auxiliary_datasets_dir.mkdir(parents=True, exist_ok=True)
 
         if to_download or not self.scheme_dataset_path.exists():
diff --git a/src/sec_certs/dataset/dataset.py b/src/sec_certs/dataset/dataset.py
index e836f492..6466630f 100644
--- a/src/sec_certs/dataset/dataset.py
+++ b/src/sec_certs/dataset/dataset.py
@@ -25,6 +25,7 @@
 from sec_certs.serialization.json import ComplexSerializableType, get_class_fullname, serialize
 from sec_certs.utils import helpers
 from sec_certs.utils.nvd_dataset_builder import CpeMatchNvdDatasetBuilder, CpeNvdDatasetBuilder, CveNvdDatasetBuilder
+from sec_certs.utils.profiling import staged
 from sec_certs.utils.tqdm import tqdm
 
 logger = logging.getLogger(__name__)
@@ -348,6 +349,7 @@ def _compute_references(self) -> None:
     def _compute_transitive_vulnerabilities(self) -> None:
         raise NotImplementedError("Not meant to be implemented by the base class.")
 
+    @staged(logger, "Processing CPEDataset.")
     def _prepare_cpe_dataset(self, download_fresh: bool = False) -> CPEDataset:
         if not self.auxiliary_datasets_dir.exists():
             self.auxiliary_datasets_dir.mkdir(parents=True)
@@ -371,6 +373,7 @@ def _prepare_cpe_dataset(self, download_fresh: bool = False) -> CPEDataset:
 
         return cpe_dataset
 
+    @staged(logger, "Processing CVEDataset.")
     def _prepare_cve_dataset(self, download_fresh: bool = False) -> CVEDataset:
         if not self.auxiliary_datasets_dir.exists():
             logger.info("Loading CVEDataset from json.")
@@ -395,6 +398,7 @@ def _prepare_cve_dataset(self, download_fresh: bool = False) -> CVEDataset:
 
         return cve_dataset
 
+    @staged(logger, "Processing CPE match dict.")
     def _prepare_cpe_match_dict(self, download_fresh: bool = False) -> dict:
         if self.cpe_match_json_path.exists():
             logger.info("Preparing CPE Match feed from json.")
@@ -433,6 +437,7 @@ def _prepare_cpe_match_dict(self, download_fresh: bool = False) -> dict:
         return cpe_match_dict
 
     @serialize
+    @staged(logger, "Computing heuristics: Finding CPE matches for certificates")
     def compute_cpe_heuristics(self) -> CPEClassifier:
         """
         Computes matching CPEs for the certificates.
@@ -465,7 +470,6 @@ def filter_condition(cpe: CPE) -> bool:
                 return False
             return True
 
-        logger.info("Computing heuristics: Finding CPE matches for certificates")
         if not self.auxiliary_datasets.cpe_dset:
             self.auxiliary_datasets.cpe_dset = self._prepare_cpe_dataset()
 
@@ -574,11 +578,11 @@ def _get_all_cpes_in_dataset(self) -> set[CPE]:
         return set(itertools.chain.from_iterable(cpe_matches))
 
     @serialize
+    @staged(logger, "Computing heuristics: CVEs in certificates.")
    def compute_related_cves(self) -> None:
         """
         Computes CVEs for the certificates, given their CPE matches.
         """
-        logger.info("Computing heuristics: CVEs in certificates.")
         if not self.auxiliary_datasets.cpe_dset:
             self.auxiliary_datasets.cpe_dset = self._prepare_cpe_dataset()
diff --git a/src/sec_certs/dataset/fips.py b/src/sec_certs/dataset/fips.py
index 24536b9b..a3d24b15 100644
--- a/src/sec_certs/dataset/fips.py
+++ b/src/sec_certs/dataset/fips.py
@@ -24,6 +24,7 @@
 from sec_certs.utils import helpers
 from sec_certs.utils import parallel_processing as cert_processing
 from sec_certs.utils.helpers import fips_dgst
+from sec_certs.utils.profiling import staged
 
 logger = logging.getLogger(__name__)
 
@@ -230,13 +231,13 @@ def _set_local_paths(self) -> None:
             cert.set_local_paths(self.policies_pdf_dir, self.policies_txt_dir, self.module_dir)
 
     @serialize
+    @staged(logger, "Downloading and processing certificates.")
     def get_certs_from_web(self, to_download: bool = True, keep_metadata: bool = True) -> None:
         self.web_dir.mkdir(parents=True, exist_ok=True)
 
         if to_download:
             self._download_html_resources()
 
-        logger.info("Adding unprocessed FIPS certificates into FIPSDataset.")
         self.certs = {x.dgst: x for x in self._get_all_certs_from_html_sources()}
         logger.info(f"The dataset now contains {len(self)} certificates.")
 
@@ -251,8 +252,8 @@ def process_auxiliary_datasets(self, download_fresh: bool = False) -> None:
         super().process_auxiliary_datasets(download_fresh)
         self.auxiliary_datasets.algorithm_dset = self._prepare_algorithm_dataset(download_fresh)
 
+    @staged(logger, "Processing FIPSAlgorithm dataset.")
     def _prepare_algorithm_dataset(self, download_fresh_algs: bool = False) -> FIPSAlgorithmDataset:
-        logger.info("Preparing FIPSAlgorithm dataset.")
         if not self.algorithm_dataset_path.exists() or download_fresh_algs:
             alg_dset = FIPSAlgorithmDataset.from_web(self.algorithm_dataset_path)
             alg_dset.to_json()
@@ -261,8 +262,8 @@ def _prepare_algorithm_dataset(self, download_fresh_algs: bool = False) -> FIPSA
 
         return alg_dset
 
+    @staged(logger, "Extracting Algorithms from policy tables")
     def _extract_algorithms_from_policy_tables(self):
-        logger.info("Extracting Algorithms from policy tables")
         certs_to_process = [x for x in self if x.state.policy_is_ok_to_analyze()]
         cert_processing.process_parallel(
             FIPSCertificate.get_algorithms_from_policy_tables,
@@ -271,8 +272,8 @@ def _extract_algorithms_from_policy_tables(self):
             progress_bar_desc="Extracting Algorithms from policy tables",
         )
 
+    @staged(logger, "Extracting security policy metadata from the pdfs")
     def _extract_policy_pdf_metadata(self) -> None:
-        logger.info("Extracting security policy metadata from the pdfs")
         certs_to_process = [x for x in self if x.state.policy_is_ok_to_analyze()]
         processed_certs = cert_processing.process_parallel(
             FIPSCertificate.extract_policy_pdf_metadata,
@@ -282,8 +283,8 @@ def _extract_policy_pdf_metadata(self) -> None:
         )
         self.update_with_certs(processed_certs)
 
+    @staged(logger, "Computing heuristics: Transitive vulnerabilities in referenc(ed/ing) certificates.")
     def _compute_transitive_vulnerabilities(self) -> None:
-        logger.info("Computing heuristics: Computing transitive vulnerabilities in referenc(ed/ing) certificates.")
         transitive_cve_finder = TransitiveVulnerabilityFinder(lambda cert: str(cert.cert_id))
         transitive_cve_finder.fit(self.certs, lambda cert: cert.heuristics.policy_processed_references)
 
@@ -292,20 +293,16 @@ def _compute_transitive_vulnerabilities(self) -> None:
             self.certs[dgst].heuristics.direct_transitive_cves = transitive_cve.direct_transitive_cves
             self.certs[dgst].heuristics.indirect_transitive_cves = transitive_cve.indirect_transitive_cves
 
-    def _prune_reference_candidates(self) -> None:
-        for cert in self:
-            cert.prune_referenced_cert_ids()
-
+    @staged(logger, "Computing heuristics: references between certificates.")
+    def _compute_references(self, keep_unknowns: bool = False) -> None:
         # Previously, a following procedure was used to prune reference_candidates:
         #   - A set of algorithms was obtained via self.auxiliary_datasets.algorithm_dset.get_algorithms_by_id(reference_candidate)
         #   - If any of these algorithms had the same vendor as the reference_candidate, the candidate was rejected
         #   - The rationale is that if an ID appears in a certificate s.t. an algorithm with the same ID was produced by the same vendor, the reference likely refers to alg.
         #   - Such reference should then be discarded.
         #   - We are uncertain of the effectivity of such measure, disabling it for now.
-
-    def _compute_references(self, keep_unknowns: bool = False) -> None:
-        logger.info("Computing heuristics: Recovering references between certificates")
-        self._prune_reference_candidates()
+        for cert in self:
+            cert.prune_referenced_cert_ids()
 
         policy_reference_finder = ReferenceFinder()
         policy_reference_finder.fit(
diff --git a/src/sec_certs/utils/profiling.py b/src/sec_certs/utils/profiling.py
new file mode 100644
index 00000000..7da67070
--- /dev/null
+++ b/src/sec_certs/utils/profiling.py
@@ -0,0 +1,45 @@
+import gc
+from contextlib import contextmanager
+from datetime import datetime
+from functools import wraps
+from logging import Logger
+
+import psutil
+
+
+@contextmanager
+def log_stage(logger: Logger, msg: str, collect_garbage: bool = False):
+    """Context manager that logs a message to the logger when it is entered and exited.
+    Memory usage details are logged at debug level. Optionally, it can
+    run garbage collection when exiting.
+    """
+    meminfo = psutil.Process().memory_full_info()
+    logger.info(f">> Starting >> {msg}")
+    logger.debug(str(meminfo))
+    start_time = datetime.now()
+
+    try:
+        yield
+    finally:
+        end_time = datetime.now()
+        duration = end_time - start_time
+        meminfo = psutil.Process().memory_full_info()
+        logger.info(f"<< Finished << {msg} ({duration})")
+        logger.debug(str(meminfo))
+
+        if collect_garbage:
+            gc.collect()
+
+
+def staged(logger: Logger, log_message: str, collect_garbage: bool = False):
+    """Decorator variant of log_stage: wraps the whole function call in a logged stage."""
+
+    def deco(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            with log_stage(logger, log_message, collect_garbage):
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return deco
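
For illustration, a minimal usage sketch of the two new helpers. Only the API
added in src/sec_certs/utils/profiling.py above is assumed; the do_work
function and the logging setup are illustrative, not part of the patch:

    import logging

    from sec_certs.utils.profiling import log_stage, staged

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger(__name__)

    # Decorator form: every call to do_work() becomes a logged stage,
    # optionally followed by a garbage-collection pass when it finishes.
    @staged(logger, "Doing expensive work.", collect_garbage=True)
    def do_work() -> None:
        sum(i * i for i in range(10**6))  # stand-in for real work

    # Context-manager form for an ad-hoc block of code:
    with log_stage(logger, "Ad-hoc stage."):
        do_work()  # the decorated call logs its own nested stage

Both forms emit the ">> Starting >>" and "<< Finished <<" lines (including the
stage duration) at INFO level; the psutil memory snapshots only become visible
when the logger is configured to emit DEBUG records, as in the basicConfig
call above.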