Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for click #12

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "springs"
version = "1.12"
version = "1.13"
description = """\
A set of utilities to create and manage typed configuration files \
effectively, built on top of OmegaConf.\
Expand Down
135 changes: 90 additions & 45 deletions src/springs/commandline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
import sys
from argparse import Action
from argparse import Action, Namespace
from dataclasses import dataclass, fields, is_dataclass
from inspect import getfullargspec, isclass
from pathlib import Path
Expand All @@ -9,6 +9,7 @@
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Type,
Expand Down Expand Up @@ -65,6 +66,10 @@ class Flag:
def short(self) -> str:
return f"-{self.name[0]}"

@property
def dest(self) -> str:
return self.name.replace("-", "_")

@property
def usage(self) -> str:
extras = (
Expand All @@ -77,7 +82,7 @@ def long(self) -> str:
return f"--{self.name}"

def add_argparse(self, parser: RichArgumentParser) -> Action:
kwargs: Dict[str, Any] = {"help": self.help}
kwargs: Dict[str, Any] = {"help": self.help, "dest": self.dest}
if self.action is not MISSING:
kwargs["action"] = self.action
if self.default is not MISSING:
Expand All @@ -89,6 +94,17 @@ def add_argparse(self, parser: RichArgumentParser) -> Action:

return parser.add_argument(self.short, self.long, **kwargs)

@property
def value(self) -> Any:
try:
return self.__flag_value__
except AttributeError:
raise RuntimeError("Flag value not set.")

@value.setter
def value(self, value: Any) -> None:
self.__flag_value__ = value

def __str__(self) -> str:
return f"{self.short}/{self.long}"

Expand Down Expand Up @@ -185,6 +201,35 @@ def make_cli(self, func: Callable, name: str) -> RichArgumentParser:
self.add_argparse(ap)
return ap

@property
def leftovers(self) -> List[str]:
try:
return self.__leftovers__
except AttributeError:
raise RuntimeError("Leftovers not set.")

@leftovers.setter
def leftovers(self, value: List[str]) -> None:
self.__leftovers__ = value

def add_opts(self, opts: Union[Dict[str, Any], Namespace]) -> None:
"""Parses the options and sets the values of the flags."""
opts = vars(opts) if isinstance(opts, Namespace) else opts
for flag in self.flags:
flag.value = opts[flag.dest]

@classmethod
def parse_args(cls, func: Callable, name: str) -> "CliFlags":
"""Parses the arguments and returns the namespace."""

ap = (cli_flags := cls()).make_cli(func=func, name=name)

opts, leftovers = ap.parse_known_args()
cli_flags.leftovers = leftovers
cli_flags.add_opts(opts)

return cli_flags


def check_if_callable_can_be_decorated(func: Callable):
expected_args = getfullargspec(func).args
Expand Down Expand Up @@ -280,31 +325,25 @@ def load_from_file_or_nickname(
return loaded_config


def wrap_main_method(
func: Callable[Concatenate[Any, MP], RT],
name: str,
def parse_input_config(
func: Callable[Concatenate[CT, MP], Any],
flags: CliFlags,
config_node: DictConfig,
*args: MP.args,
**kwargs: MP.kwargs,
) -> RT:
) -> CT:

if not isinstance(config_node, DictConfig):
raise TypeError("Config node must be a DictConfig")

# Making sure I can decorate this function
check_if_callable_can_be_decorated(func=func)
check_if_valid_main_args(func=func, args=args)

# Get argument parser and arguments
ap = CliFlags().make_cli(func=func, name=name)
opts, leftover_args = ap.parse_known_args()

# Checks if the args are a match for the 'path.to.key=value′ pattern
# expected for configuration overrides.
validate_leftover_args(leftover_args)
validate_leftover_args(flags.leftovers)

# setup logging level for the root logger
configure_logging(logging_level="DEBUG" if opts.debug else opts.log_level)
configure_logging(
logging_level="DEBUG" if flags.debug.value else flags.log_level.value
)

# set up parsers for the various config nodes and tables
tree_parser = ConfigTreeParser()
Expand All @@ -313,15 +352,15 @@ def wrap_main_method(
# We don't run the main program if the user
# has requested to print the any of the config.
do_no_run = (
opts.options
or opts.inputs
or opts.parsed
or opts.resolvers
or opts.nicknames
or opts.save
flags.options.value
or flags.inputs.value
or flags.parsed.value
or flags.resolvers.value
or flags.nicknames.value
or flags.save.value
)

if opts.resolvers:
if flags.resolvers.value:
# relative import here not to mess things up
from .resolvers import all_resolvers

Expand All @@ -337,7 +376,7 @@ def wrap_main_method(
borders=True,
)

if opts.nicknames:
if flags.nicknames.value:
table_parser(
title="Registered Nicknames",
columns=["Nickname", "Path"],
Expand All @@ -351,7 +390,7 @@ def wrap_main_method(
)

# Print default options if requested py the user
if opts.options:
if flags.options.value:
config_name = getattr(get_type(config_node), "__name__", None)
tree_parser(
title="Default Options",
Expand All @@ -366,12 +405,12 @@ def wrap_main_method(

# load options from one or more config files; if multiple config files
# are provided, the latter ones can override the former ones.
for config_file in opts.config:
for config_file in flags.config.value:
# Load config file
file_config = load_from_file_or_nickname(config_file)

# print the configuration if requested by the user
if opts.inputs:
if flags.inputs.value:
tree_parser(
title="Input From File",
subtitle=f"(path: '{config_file}')",
Expand All @@ -383,10 +422,10 @@ def wrap_main_method(
accumulator_config = unsafe_merge(accumulator_config, file_config)

# load options from cli
cli_config = from_options(leftover_args)
cli_config = from_options(flags.leftovers)

# print the configuration if requested by the user
if opts.inputs:
if flags.inputs.value:
tree_parser(
title="Input From Command Line",
config=cli_config,
Expand All @@ -397,7 +436,7 @@ def wrap_main_method(
# so that cli takes precedence over config files.
accumulator_config = unsafe_merge(accumulator_config, cli_config)

if do_no_run and not opts.parsed:
if do_no_run and not flags.parsed.value:
# if the user hasn't requested to print the parsed config
# and we are not running the main program, we can exit here.
sys.exit(0)
Expand All @@ -408,25 +447,24 @@ def wrap_main_method(
parsed_config = merge_and_catch(config_node, accumulator_config)

# print it if requested
if not (opts.quiet) or opts.parsed:
if not (flags.quiet.value) or flags.parsed.value:
tree_parser(
title="Parsed Config",
config=parsed_config,
print_help=False,
)

if opts.save is not None:
if flags.save.value is not None:
# save the parsed config to a file
with open(opts.save, "w") as f:
with open(flags.save.value, "w") as f:
f.write(to_yaml(parsed_config))

if do_no_run:
# we are not running because the user has requested to print
# either the options, inputs, or parsed config.
sys.exit(0)
else:
# we execute the main method and pass the parsed config to it
return func(parsed_config, *args, **kwargs)

return parsed_config


def cli(
Expand Down Expand Up @@ -490,17 +528,24 @@ def main(cfg: Config):

def wrapper(func: Callable[Concatenate[CT, MP], RT]) -> Callable[MP, RT]:
def wrapping(*args: MP.args, **kwargs: MP.kwargs) -> RT:
# I could have used a functools.partial here, but defining
# my own function instead allows me to provide nice typing
# annotations for mypy.
return wrap_main_method(
func,
name,
config_node,
*args,
**kwargs,

# Making sure I can decorate this function
check_if_callable_can_be_decorated(func=func)
check_if_valid_main_args(func=func, args=args)

# Parse the command line arguments
flags = CliFlags.parse_args(func=func, name=name)

# Parse the input config(s)
config = parse_input_config(
func=func,
flags=flags,
config_node=config_node,
)

# Call the main function
return func(config, *args, **kwargs)

return wrapping

return wrapper