Skip to content

Commit

Permalink
Refactor entrypoints (#28)
Browse files Browse the repository at this point in the history
* Create utilities.get_parser

Factor out common logic into utilities.
Removes some dependencies and code duplications.

* Use parser in main()

See answer https://stackoverflow.com/questions/37697502/python-unittest-for-argparse
for reference.
Also updated docstrings

* Make testable as init

* Refactor and test init()

Using sys.exit as insurance that test call is executed correclty. #7

Needs import sys in module

* Refactored calls in tests to main

Now even simpler tests! No argparse.

* format
  • Loading branch information
VaeterchenFrost authored Jul 16, 2020
1 parent 51dd14a commit 1ce18ee
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 133 deletions.
88 changes: 39 additions & 49 deletions tdvisu/construct_dpdb_visu.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import argparse
import json
import logging
import sys
from datetime import datetime, timedelta
from pathlib import Path
from time import sleep
Expand All @@ -38,8 +39,7 @@
from tdvisu.utilities import convert_to_adj
from tdvisu.reader import TwReader
from tdvisu.visualization import flatten
from tdvisu.version import __date__, __version__ as version
from tdvisu.utilities import read_yml_or_cfg, logging_cfg, LOGLEVEL_EPILOG
from tdvisu.utilities import read_yml_or_cfg, logging_cfg, get_parser

LOGGER = logging.getLogger('construct_dpdb_visu.py')

Expand Down Expand Up @@ -573,79 +573,69 @@ def create_json(
return {}


def main(args: argparse.Namespace) -> None:
def main(args: List[str]) -> None:
"""
Main method running construct_dpdb_visu for arguments in 'args'
Main method running construct_dpdb_visu for arguments in 'args'.
Parameters
----------
args : argparse.Namespace
The namespace containing all (command-line) parameters.
args : List[str]
The array containing all (command-line) flags.
Returns
-------
None
"""
parser = get_parser("Extracts Information from "
"https://github.com/hmarkus/dp_on_dbs runs "
"for further visualization.")

logging_cfg(filename='logging.yml', loglevel=args.loglevel)
LOGGER.info("Called with '%s'", args)
parser.add_argument('problemnumber', type=int,
help="selected problem-id in the postgres-database.")
parser.add_argument('--twfile',
type=argparse.FileType('r', encoding='UTF-8'),
help="tw-file containing the edges of the graph - "
"obtained from dpdb with option --gr-file GR_FILE.")
parser.add_argument('--outfile', default='dbjson%d.json',
help="default:'dbjson%%d.json'")
parser.add_argument('--pretty', action='store_true',
help="pretty-print the JSON.")
parser.add_argument('--inter-nodes', action='store_true',
help="calculate and animate the shortest path between "
"successive bags in the order of evaluation.")
# get cmd-arguments
options = parser.parse_args(args)

problem_ = args.problemnumber
logging_cfg(filename='logging.yml', loglevel=options.loglevel)
LOGGER.info("Called with '%s'", options)
problem_ = options.problemnumber
# get twfile if supplied
try:
tw_file_ = args.twfile
tw_file_ = options.twfile
except AttributeError:
tw_file_ = None

# create JSON
result_json = create_json(problem=problem_, tw_file=tw_file_)
# build json filename, can be supplied with problem-number
try:
outfile = args.outfile % problem_
try: # build json filename, can be supplied with problem-number
outfile = options.outfile % problem_
except TypeError:
outfile = args.outfile
outfile = options.outfile
LOGGER.info("Output file-name: %s", outfile)
with open(outfile, 'w') as file:
json.dump(
result_json,
file,
sort_keys=True,
indent=2 if args.pretty else None,
indent=2 if options.pretty else None,
ensure_ascii=False)
LOGGER.debug("Wrote to %s", file)


if __name__ == "__main__":
# Parse args, call main
def init():
"""Initialization that is executed at the time of the module import."""
if __name__ == "__main__":
sys.exit(main(sys.argv[1:])) # call main function

PARSER = argparse.ArgumentParser(
description="""
Copyright (C) 2020 Martin Röbke
This program comes with ABSOLUTELY NO WARRANTY
This is free software, and you are welcome to redistribute it
under certain conditions; see COPYING for more information.

Extracts Information from https://github.com/hmarkus/dp_on_dbs runs
for further visualization.""",
epilog=LOGLEVEL_EPILOG,
formatter_class=argparse.RawDescriptionHelpFormatter
)

PARSER.add_argument('problemnumber', type=int,
help="selected problem-id in the postgres-database.")
PARSER.add_argument('--twfile',
type=argparse.FileType('r', encoding='UTF-8'),
help="tw-file containing the edges of the graph - "
"obtained from dpdb with option --gr-file GR_FILE.")
PARSER.add_argument('--loglevel', help="set the minimal loglevel for root")
PARSER.add_argument('--outfile', default='dbjson%d.json',
help="default:'dbjson%%d.json'")
PARSER.add_argument('--pretty', action='store_true',
help="pretty-print the JSON.")
PARSER.add_argument('--inter-nodes', action='store_true',
help="calculate and animate the shortest path between "
"successive bags in the order of evaluation.")
PARSER.add_argument('--version', action='version',
version='%(prog)s ' + version + ', ' + __date__)

# get cmd-arguments
_args = PARSER.parse_args()
main(_args)
init()
38 changes: 36 additions & 2 deletions tdvisu/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@
If not, see https://www.gnu.org/licenses/gpl-3.0.html
"""

import argparse
from collections.abc import Iterable as iter_type
from configparser import ConfigParser, ParsingError, Error as CfgError
from itertools import chain
import logging
import logging.config
from itertools import chain
from pathlib import Path
from typing import Any, Generator, Iterable, Iterator, List, Tuple, TypeVar, Union
import yaml

from tdvisu.version import __date__, __version__

LOGGER = logging.getLogger('utilities.py')

CFG_EXT = ('.ini', '.cfg', '.conf', '.config')
Expand Down Expand Up @@ -446,3 +448,35 @@ def solution_node(
result += '|' + bottomlabel

return '{' + result + '}'


def get_parser(extra_desc: str = '') -> argparse.ArgumentParser:
"""
Prepare an argument parser for TDVisu scripts.
Parameters
----------
extra_desc : str, optional
Description about the script using the parser. The default is ''.
Returns
-------
parser : argparse.ArgumentParser
The prepared argument parser object.
"""
parser = argparse.ArgumentParser(
description="""
Copyright (C) 2020 Martin Röbke
This program comes with ABSOLUTELY NO WARRANTY
This is free software, and you are welcome to redistribute it
under certain conditions; see COPYING for more information.
"""
+ "\n" + extra_desc,
epilog=LOGLEVEL_EPILOG,
formatter_class=argparse.RawDescriptionHelpFormatter)

parser.add_argument('--version', action='version',
version='%(prog)s ' + __version__ + ', ' + __date__)
parser.add_argument('--loglevel', help="set the minimal loglevel for root")
return parser
67 changes: 28 additions & 39 deletions tdvisu/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,16 @@
import itertools
import json
import logging
import sys
from dataclasses import asdict
from pathlib import Path
from sys import stdin
from typing import Iterable, List, Optional, Union, NewType

from graphviz import Digraph, Graph
from tdvisu.visualization_data import (VisualizationData, IncidenceGraphData,
GeneralGraphData, SvgJoinData)
from tdvisu.version import __date__, __version__ as version
from tdvisu.svgjoin import svg_join
from tdvisu.utilities import flatten, LOGLEVEL_EPILOG, logging_cfg
from tdvisu.utilities import flatten, logging_cfg, get_parser
from tdvisu.utilities import bag_node, solution_node, base_style
from tdvisu.utilities import style_hide_edge, style_hide_node, emphasise_node

Expand Down Expand Up @@ -699,25 +698,38 @@ def call_svgjoin(self) -> None:
svg_join(**asdict(sj_data))


def main(args: argparse.Namespace) -> None:
def main(args: List[str]) -> None:
"""
Main method running construct_dpdb_visu for arguments in 'args'
Main method running visualization for arguments in 'args'.
Parameters
----------
args : argparse.Namespace
The namespace containing all (command-line) parameters.
args : List[str]
The array containing all (command-line) flags.
Returns
-------
None
"""
parser = get_parser(
"Visualizing Dynamic Programming on Tree-Decompositions.")
# possible to use stdin for the file.
parser.add_argument('infile', nargs='?',
type=argparse.FileType('r', encoding='UTF-8'),
default=sys.stdin,
help="Input file for the visualization "
"must conform with the 'JsonAPI.md'")
parser.add_argument('outfolder',
help="Folder to output the visualization results")

# get cmd-arguments
options = parser.parse_args(args)

logging_cfg(filename='logging.yml', loglevel=args.loglevel)
LOGGER.info("Called with '%s'", args)
logging_cfg(filename='logging.yml', loglevel=options.loglevel)
LOGGER.info("Called with '%s'", options)

infile = args.infile
outfolder = args.outfolder
infile = options.infile
outfolder = options.outfolder
if not outfolder:
outfolder = 'outfolder'
outfolder = Path(outfolder).resolve()
Expand All @@ -728,33 +740,10 @@ def main(args: argparse.Namespace) -> None:
visu.tree_dec_timeline()


if __name__ == "__main__":
# Parse args, call main

PARSER = argparse.ArgumentParser(
description="""
Copyright (C) 2020 Martin Röbke
This program comes with ABSOLUTELY NO WARRANTY
This is free software, and you are welcome to redistribute it
under certain conditions; see COPYING for more information.
def init():
"""Initialization that is executed at the time of the module import."""
if __name__ == "__main__":
sys.exit(main(sys.argv[1:])) # call main function

Visualizing Dynamic Programming on Tree-Decompositions.""",
epilog=LOGLEVEL_EPILOG,
formatter_class=argparse.RawDescriptionHelpFormatter)

# possible to use stdin for the file.
PARSER.add_argument('infile', nargs='?',
type=argparse.FileType('r', encoding='UTF-8'),
default=stdin,
help="Input file for the visualization "
"must conform with the 'JsonAPI.md'")
PARSER.add_argument('outfolder',
help="Folder to output the visualization results")
PARSER.add_argument('--version', action='version',
version='%(prog)s ' + version + ', ' + __date__)
PARSER.add_argument('--loglevel', help="set the minimal loglevel for root")

# get cmd-arguments
_args = PARSER.parse_args()
# call main()
main(_args)
init()
35 changes: 14 additions & 21 deletions test/test_construct_dpdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
"""

import argparse
import datetime

from pathlib import Path
import psycopg2 as pg

from tdvisu import construct_dpdb_visu as module
from tdvisu.construct_dpdb_visu import (read_cfg, db_config, DEFAULT_DBCONFIG,
IDpdbVisuConstruct, DpdbSharpSatVisu,
DpdbSatVisu, DpdbMinVcVisu, main)
Expand Down Expand Up @@ -137,28 +136,10 @@ def test_main(mocker, tmp_path):
query_edgearray = mocker.patch(
'tdvisu.construct_dpdb_visu.query_edgearray', return_value=[
(2, 1), (3, 2), (4, 2), (5, 4)])

parser = argparse.ArgumentParser()
parser.add_argument('problemnumber', type=int,
help="selected problem-id in the postgres-database.")
parser.add_argument('--twfile',
type=argparse.FileType('r', encoding='UTF-8'),
help="tw-file containing the edges of the graph - "
"obtained from dpdb with option --gr-file GR_FILE.")
parser.add_argument('--loglevel', help="set the minimal loglevel for root")
parser.add_argument('--outfile', default='dbjson%d.json',
help="default:'dbjson%%d.json'")
parser.add_argument('--pretty', action='store_true',
help="pretty-print the JSON.")
parser.add_argument('--inter-nodes', action='store_true',
help="calculate and animate the shortest path between "
"successive bags in the order of evaluation.")

# set cmd-arguments
outfile = str(tmp_path / 'test_main.json')
_args = parser.parse_args(['1', '--outfile', outfile])
# one mocked run
main(_args)
main(['1', '--outfile', outfile])

# Assertions
mock_connect.assert_called_once()
Expand All @@ -172,3 +153,15 @@ def test_main(mocker, tmp_path):
assert query_column_name.call_count == 5
assert query_td_bag.call_count == 5
assert query_td_node_status.call_count == 5


def test_init(mocker):
"""Test that main is called correctly if called as __main__."""
expected = -1000
main = mocker.patch.object(module, "main", return_value=expected)
mock_exit = mocker.patch.object(module.sys, 'exit')
mocker.patch.object(module, "__name__", "__main__")
module.init()

main.assert_called_once()
assert mock_exit.call_args[0][0] == expected
Loading

0 comments on commit 1ce18ee

Please sign in to comment.