Skip to content

Commit

Permalink
use a plain method for handling errors
Browse files Browse the repository at this point in the history
  • Loading branch information
danieleades committed Feb 11, 2024
1 parent 2877d73 commit 4010d97
Showing 1 changed file with 18 additions and 34 deletions.
52 changes: 18 additions & 34 deletions copier/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
copier --help-all
```
"""
import inspect

import sys
from os import PathLike
from pathlib import Path
Expand All @@ -56,38 +56,28 @@

import yaml
from plumbum import cli, colors
from typing_extensions import ParamSpec

from .errors import UnsafeTemplateError, UserMessageError
from .main import Worker
from .tools import copier_version
from .types import AnyByStrDict, OptStr, StrSeq

P = ParamSpec("P")


def _handle_exceptions(method: Callable[P, int]) -> Callable[P, int]:
def inner(*args: P.args, **kwargs: P.kwargs) -> int:
def handle_exceptions(method: Callable[[], None]) -> int:
"""Handle keyboard interruption while running a method."""
try:
try:
try:
return method(*args, **kwargs)
except KeyboardInterrupt:
raise UserMessageError("Execution stopped by user")
except UserMessageError as error:
print(colors.red | "\n".join(error.args), file=sys.stderr)
return 1
except UnsafeTemplateError as error:
print(colors.red | "\n".join(error.args), file=sys.stderr)
# DOCS https://github.com/copier-org/copier/issues/1328#issuecomment-1723214165
return 0b100

# See https://github.com/copier-org/copier/pull/1513
if sys.version_info >= (3, 10):
inner.__signature__ = inspect.signature(method, eval_str=True) # type: ignore[attr-defined]
else:
inner.__signature__ = inspect.signature(method) # type: ignore[attr-defined]

return inner
method()
except KeyboardInterrupt:
raise UserMessageError("Execution stopped by user")
except UserMessageError as error:
print(colors.red | "\n".join(error.args), file=sys.stderr)
return 1
except UnsafeTemplateError as error:
print(colors.red | "\n".join(error.args), file=sys.stderr)
# DOCS https://github.com/copier-org/copier/issues/1328#issuecomment-1723214165
return 0b100
return 0


class CopierApp(cli.Application):
Expand Down Expand Up @@ -254,7 +244,6 @@ class CopierCopySubApp(_Subcommand):
help="Overwrite files that already exist, without asking.",
)

@_handle_exceptions
def main(self, template_src: str, destination_path: str) -> int:
"""Call [run_copy][copier.main.Worker.run_copy].
Expand All @@ -274,8 +263,7 @@ def main(self, template_src: str, destination_path: str) -> int:
defaults=self.force or self.defaults,
overwrite=self.force or self.overwrite,
) as worker:
worker.run_copy()
return 0
return handle_exceptions(worker.run_copy)


@CopierApp.subcommand("recopy")
Expand Down Expand Up @@ -322,7 +310,6 @@ class CopierRecopySubApp(_Subcommand):
help="Skip questions that have already been answered",
)

@_handle_exceptions
def main(self, destination_path: cli.ExistingDirectory = ".") -> int:
"""Call [run_recopy][copier.main.Worker.run_recopy].
Expand All @@ -340,8 +327,7 @@ def main(self, destination_path: cli.ExistingDirectory = ".") -> int:
overwrite=self.force or self.overwrite,
skip_answered=self.skip_answered,
) as worker:
worker.run_recopy()
return 0
return handle_exceptions(worker.run_recopy)


@CopierApp.subcommand("update")
Expand Down Expand Up @@ -394,7 +380,6 @@ class CopierUpdateSubApp(_Subcommand):
help="Skip questions that have already been answered",
)

@_handle_exceptions
def main(self, destination_path: cli.ExistingDirectory = ".") -> int:
"""Call [run_update][copier.main.Worker.run_update].
Expand All @@ -414,5 +399,4 @@ def main(self, destination_path: cli.ExistingDirectory = ".") -> int:
skip_answered=self.skip_answered,
overwrite=True,
) as worker:
worker.run_update()
return 0
return handle_exceptions(worker.run_update)

0 comments on commit 4010d97

Please sign in to comment.