diff --git a/tdvisu/construct_dpdb_visu.py b/tdvisu/construct_dpdb_visu.py index fd3b8b3..0c3a847 100644 --- a/tdvisu/construct_dpdb_visu.py +++ b/tdvisu/construct_dpdb_visu.py @@ -27,6 +27,7 @@ import argparse import json import logging +import sys from datetime import datetime, timedelta from pathlib import Path from time import sleep @@ -38,8 +39,7 @@ from tdvisu.utilities import convert_to_adj from tdvisu.reader import TwReader from tdvisu.visualization import flatten -from tdvisu.version import __date__, __version__ as version -from tdvisu.utilities import read_yml_or_cfg, logging_cfg, LOGLEVEL_EPILOG +from tdvisu.utilities import read_yml_or_cfg, logging_cfg, get_parser LOGGER = logging.getLogger('construct_dpdb_visu.py') @@ -573,79 +573,69 @@ def create_json( return {} -def main(args: argparse.Namespace) -> None: +def main(args: List[str]) -> None: """ - Main method running construct_dpdb_visu for arguments in 'args' + Main method running construct_dpdb_visu for arguments in 'args'. Parameters ---------- - args : argparse.Namespace - The namespace containing all (command-line) parameters. + args : List[str] + The array containing all (command-line) flags. Returns ------- None """ + parser = get_parser("Extracts Information from " + "https://github.com/hmarkus/dp_on_dbs runs " + "for further visualization.") - logging_cfg(filename='logging.yml', loglevel=args.loglevel) - LOGGER.info("Called with '%s'", args) + parser.add_argument('problemnumber', type=int, + help="selected problem-id in the postgres-database.") + parser.add_argument('--twfile', + type=argparse.FileType('r', encoding='UTF-8'), + help="tw-file containing the edges of the graph - " + "obtained from dpdb with option --gr-file GR_FILE.") + parser.add_argument('--outfile', default='dbjson%d.json', + help="default:'dbjson%%d.json'") + parser.add_argument('--pretty', action='store_true', + help="pretty-print the JSON.") + parser.add_argument('--inter-nodes', action='store_true', + help="calculate and animate the shortest path between " + "successive bags in the order of evaluation.") + # get cmd-arguments + options = parser.parse_args(args) - problem_ = args.problemnumber + logging_cfg(filename='logging.yml', loglevel=options.loglevel) + LOGGER.info("Called with '%s'", options) + problem_ = options.problemnumber # get twfile if supplied try: - tw_file_ = args.twfile + tw_file_ = options.twfile except AttributeError: tw_file_ = None + + # create JSON result_json = create_json(problem=problem_, tw_file=tw_file_) - # build json filename, can be supplied with problem-number - try: - outfile = args.outfile % problem_ + try: # build json filename, can be supplied with problem-number + outfile = options.outfile % problem_ except TypeError: - outfile = args.outfile + outfile = options.outfile LOGGER.info("Output file-name: %s", outfile) with open(outfile, 'w') as file: json.dump( result_json, file, sort_keys=True, - indent=2 if args.pretty else None, + indent=2 if options.pretty else None, ensure_ascii=False) LOGGER.debug("Wrote to %s", file) -if __name__ == "__main__": - # Parse args, call main +def init(): + """Initialization that is executed at the time of the module import.""" + if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) # call main function - PARSER = argparse.ArgumentParser( - description=""" - Copyright (C) 2020 Martin Röbke - This program comes with ABSOLUTELY NO WARRANTY - This is free software, and you are welcome to redistribute it - under certain conditions; see COPYING for more information. - Extracts Information from https://github.com/hmarkus/dp_on_dbs runs - for further visualization.""", - epilog=LOGLEVEL_EPILOG, - formatter_class=argparse.RawDescriptionHelpFormatter - ) - - PARSER.add_argument('problemnumber', type=int, - help="selected problem-id in the postgres-database.") - PARSER.add_argument('--twfile', - type=argparse.FileType('r', encoding='UTF-8'), - help="tw-file containing the edges of the graph - " - "obtained from dpdb with option --gr-file GR_FILE.") - PARSER.add_argument('--loglevel', help="set the minimal loglevel for root") - PARSER.add_argument('--outfile', default='dbjson%d.json', - help="default:'dbjson%%d.json'") - PARSER.add_argument('--pretty', action='store_true', - help="pretty-print the JSON.") - PARSER.add_argument('--inter-nodes', action='store_true', - help="calculate and animate the shortest path between " - "successive bags in the order of evaluation.") - PARSER.add_argument('--version', action='version', - version='%(prog)s ' + version + ', ' + __date__) - - # get cmd-arguments - _args = PARSER.parse_args() - main(_args) +init() diff --git a/tdvisu/utilities.py b/tdvisu/utilities.py index 5645a09..142a188 100644 --- a/tdvisu/utilities.py +++ b/tdvisu/utilities.py @@ -20,16 +20,18 @@ If not, see https://www.gnu.org/licenses/gpl-3.0.html """ - +import argparse from collections.abc import Iterable as iter_type from configparser import ConfigParser, ParsingError, Error as CfgError +from itertools import chain import logging import logging.config -from itertools import chain from pathlib import Path from typing import Any, Generator, Iterable, Iterator, List, Tuple, TypeVar, Union import yaml +from tdvisu.version import __date__, __version__ + LOGGER = logging.getLogger('utilities.py') CFG_EXT = ('.ini', '.cfg', '.conf', '.config') @@ -446,3 +448,35 @@ def solution_node( result += '|' + bottomlabel return '{' + result + '}' + + +def get_parser(extra_desc: str = '') -> argparse.ArgumentParser: + """ + Prepare an argument parser for TDVisu scripts. + + Parameters + ---------- + extra_desc : str, optional + Description about the script using the parser. The default is ''. + + Returns + ------- + parser : argparse.ArgumentParser + The prepared argument parser object. + + """ + parser = argparse.ArgumentParser( + description=""" + Copyright (C) 2020 Martin Röbke + This program comes with ABSOLUTELY NO WARRANTY + This is free software, and you are welcome to redistribute it + under certain conditions; see COPYING for more information. + """ + + "\n" + extra_desc, + epilog=LOGLEVEL_EPILOG, + formatter_class=argparse.RawDescriptionHelpFormatter) + + parser.add_argument('--version', action='version', + version='%(prog)s ' + __version__ + ', ' + __date__) + parser.add_argument('--loglevel', help="set the minimal loglevel for root") + return parser diff --git a/tdvisu/visualization.py b/tdvisu/visualization.py index 0c38ad8..4420ea9 100644 --- a/tdvisu/visualization.py +++ b/tdvisu/visualization.py @@ -31,17 +31,16 @@ import itertools import json import logging +import sys from dataclasses import asdict from pathlib import Path -from sys import stdin from typing import Iterable, List, Optional, Union, NewType from graphviz import Digraph, Graph from tdvisu.visualization_data import (VisualizationData, IncidenceGraphData, GeneralGraphData, SvgJoinData) -from tdvisu.version import __date__, __version__ as version from tdvisu.svgjoin import svg_join -from tdvisu.utilities import flatten, LOGLEVEL_EPILOG, logging_cfg +from tdvisu.utilities import flatten, logging_cfg, get_parser from tdvisu.utilities import bag_node, solution_node, base_style from tdvisu.utilities import style_hide_edge, style_hide_node, emphasise_node @@ -699,25 +698,38 @@ def call_svgjoin(self) -> None: svg_join(**asdict(sj_data)) -def main(args: argparse.Namespace) -> None: +def main(args: List[str]) -> None: """ - Main method running construct_dpdb_visu for arguments in 'args' + Main method running visualization for arguments in 'args'. Parameters ---------- - args : argparse.Namespace - The namespace containing all (command-line) parameters. + args : List[str] + The array containing all (command-line) flags. Returns ------- None """ + parser = get_parser( + "Visualizing Dynamic Programming on Tree-Decompositions.") + # possible to use stdin for the file. + parser.add_argument('infile', nargs='?', + type=argparse.FileType('r', encoding='UTF-8'), + default=sys.stdin, + help="Input file for the visualization " + "must conform with the 'JsonAPI.md'") + parser.add_argument('outfolder', + help="Folder to output the visualization results") + + # get cmd-arguments + options = parser.parse_args(args) - logging_cfg(filename='logging.yml', loglevel=args.loglevel) - LOGGER.info("Called with '%s'", args) + logging_cfg(filename='logging.yml', loglevel=options.loglevel) + LOGGER.info("Called with '%s'", options) - infile = args.infile - outfolder = args.outfolder + infile = options.infile + outfolder = options.outfolder if not outfolder: outfolder = 'outfolder' outfolder = Path(outfolder).resolve() @@ -728,33 +740,10 @@ def main(args: argparse.Namespace) -> None: visu.tree_dec_timeline() -if __name__ == "__main__": - # Parse args, call main - - PARSER = argparse.ArgumentParser( - description=""" - Copyright (C) 2020 Martin Röbke - This program comes with ABSOLUTELY NO WARRANTY - This is free software, and you are welcome to redistribute it - under certain conditions; see COPYING for more information. +def init(): + """Initialization that is executed at the time of the module import.""" + if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) # call main function - Visualizing Dynamic Programming on Tree-Decompositions.""", - epilog=LOGLEVEL_EPILOG, - formatter_class=argparse.RawDescriptionHelpFormatter) - # possible to use stdin for the file. - PARSER.add_argument('infile', nargs='?', - type=argparse.FileType('r', encoding='UTF-8'), - default=stdin, - help="Input file for the visualization " - "must conform with the 'JsonAPI.md'") - PARSER.add_argument('outfolder', - help="Folder to output the visualization results") - PARSER.add_argument('--version', action='version', - version='%(prog)s ' + version + ', ' + __date__) - PARSER.add_argument('--loglevel', help="set the minimal loglevel for root") - - # get cmd-arguments - _args = PARSER.parse_args() - # call main() - main(_args) +init() diff --git a/test/test_construct_dpdb.py b/test/test_construct_dpdb.py index 67557ac..cd699c4 100644 --- a/test/test_construct_dpdb.py +++ b/test/test_construct_dpdb.py @@ -21,12 +21,11 @@ """ -import argparse import datetime from pathlib import Path import psycopg2 as pg - +from tdvisu import construct_dpdb_visu as module from tdvisu.construct_dpdb_visu import (read_cfg, db_config, DEFAULT_DBCONFIG, IDpdbVisuConstruct, DpdbSharpSatVisu, DpdbSatVisu, DpdbMinVcVisu, main) @@ -137,28 +136,10 @@ def test_main(mocker, tmp_path): query_edgearray = mocker.patch( 'tdvisu.construct_dpdb_visu.query_edgearray', return_value=[ (2, 1), (3, 2), (4, 2), (5, 4)]) - - parser = argparse.ArgumentParser() - parser.add_argument('problemnumber', type=int, - help="selected problem-id in the postgres-database.") - parser.add_argument('--twfile', - type=argparse.FileType('r', encoding='UTF-8'), - help="tw-file containing the edges of the graph - " - "obtained from dpdb with option --gr-file GR_FILE.") - parser.add_argument('--loglevel', help="set the minimal loglevel for root") - parser.add_argument('--outfile', default='dbjson%d.json', - help="default:'dbjson%%d.json'") - parser.add_argument('--pretty', action='store_true', - help="pretty-print the JSON.") - parser.add_argument('--inter-nodes', action='store_true', - help="calculate and animate the shortest path between " - "successive bags in the order of evaluation.") - # set cmd-arguments outfile = str(tmp_path / 'test_main.json') - _args = parser.parse_args(['1', '--outfile', outfile]) # one mocked run - main(_args) + main(['1', '--outfile', outfile]) # Assertions mock_connect.assert_called_once() @@ -172,3 +153,15 @@ def test_main(mocker, tmp_path): assert query_column_name.call_count == 5 assert query_td_bag.call_count == 5 assert query_td_node_status.call_count == 5 + + +def test_init(mocker): + """Test that main is called correctly if called as __main__.""" + expected = -1000 + main = mocker.patch.object(module, "main", return_value=expected) + mock_exit = mocker.patch.object(module.sys, 'exit') + mocker.patch.object(module, "__name__", "__main__") + module.init() + + main.assert_called_once() + assert mock_exit.call_args[0][0] == expected diff --git a/test/test_visualization.py b/test/test_visualization.py index 5b4e277..8737d51 100644 --- a/test/test_visualization.py +++ b/test/test_visualization.py @@ -21,9 +21,8 @@ """ -import argparse from pathlib import Path - +from tdvisu import visualization as module from tdvisu.visualization import main EXPECT_DIR = Path(__file__).parent / 'expected_files' @@ -31,21 +30,13 @@ def test_sat_and_join(tmpdir): """Complete visualization run with svgjoin.""" - parser = argparse.ArgumentParser() - parser.add_argument('infile', - type=argparse.FileType('r', encoding='UTF-8'), - help="Input file for the visualization " - "must conform with the 'JsonAPI.md'") - parser.add_argument('outfolder', - help="Folder to output the visualization results") - parser.add_argument('--loglevel', help="set the minimal loglevel for root") # get cmd-arguments infile = Path(__file__).parent / 'dbjson4andjoin.json' outfolder = Path(tmpdir) / 'temp-test_sat_and_join' - _args = parser.parse_args([str(infile), str(outfolder)]) + args = [str(infile), str(outfolder)] # call main() - main(_args) + main(args) files = [file for file in outfolder.iterdir() if file.is_file()] assert len(files ) == 42, "total files" @@ -67,21 +58,13 @@ def test_sat_and_join(tmpdir): def test_vc_multiple_and_join(tmp_path): """Complete visualization run with svgjoin for MinVC and sorted graph.""" - parser = argparse.ArgumentParser() - parser.add_argument('infile', - type=argparse.FileType('r', encoding='UTF-8'), - help="Input file for the visualization " - "must conform with the 'JsonAPI.md'") - parser.add_argument('outfolder', - help="Folder to output the visualization results") - parser.add_argument('--loglevel', help="set the minimal loglevel for root") # get cmd-arguments infile = Path(__file__).parent / 'visualization_wheelgraph_2graphs.json' outfolder = tmp_path / 'temp-test_vc_multiple_and_join' - _args = parser.parse_args([str(infile), str(outfolder)]) + args = [str(infile), str(outfolder)] # call main() - main(_args) + main(args) files = [file for file in outfolder.iterdir() if file.is_file()] assert len(files ) == 35, "total files" @@ -99,3 +82,15 @@ def test_vc_multiple_and_join(tmp_path): with open(outfolder / file) as result: assert result.read() == expected.read( ), f"{file} should be the same" + + +def test_init(mocker): + """Test that main is called correctly if called as __main__.""" + expected = -1000 + main = mocker.patch.object(module, "main", return_value=expected) + mock_exit = mocker.patch.object(module.sys, 'exit') + mocker.patch.object(module, "__name__", "__main__") + module.init() + + main.assert_called_once() + assert mock_exit.call_args[0][0] == expected