Skip to content

Commit

Permalink
simplify some redundant type aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
danieleades committed Feb 10, 2024
1 parent 878a7e0 commit 5b2b541
Show file tree
Hide file tree
Showing 17 changed files with 110 additions and 132 deletions.
27 changes: 14 additions & 13 deletions copier/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@
copier --help-all
```
"""
from __future__ import annotations

import sys
from os import PathLike
from pathlib import Path
from textwrap import dedent
from typing import Any

import yaml
from decorator import decorator
Expand All @@ -60,7 +62,7 @@
from .errors import UnsafeTemplateError, UserMessageError
from .main import Worker
from .tools import copier_version
from .types import AnyByStrDict, OptStr, StrSeq
from .types import StrSeq


@decorator
Expand All @@ -84,21 +86,18 @@ class CopierApp(cli.Application):
"""The Copier CLI application."""

DESCRIPTION = "Create a new project from a template."
DESCRIPTION_MORE = (
dedent(
"""\
DESCRIPTION_MORE = dedent(
"""\
Docs in https://copier.readthedocs.io/
"""
)
+ (
colors.yellow
| dedent(
"""\
) + (
colors.yellow
| dedent(
"""\
WARNING! Use only trusted project templates, as they might
execute code with the same level of access as your user.\n
"""
)
)
)
VERSION = copier_version()
Expand All @@ -109,7 +108,7 @@ class _Subcommand(cli.Application):
"""Base class for Copier subcommands."""

def __init__(self, executable: PathLike) -> None:
self.data: AnyByStrDict = {}
self.data: dict[str, Any] = {}
super().__init__(executable)

answers_file = cli.SwitchAttr(
Expand Down Expand Up @@ -188,14 +187,16 @@ def data_file_switch(self, path: cli.ExistingFile) -> None:
path: The path to the YAML file to load.
"""
with open(path) as f:
file_updates: AnyByStrDict = yaml.safe_load(f)
file_updates: dict[str, Any] = yaml.safe_load(f)

updates_without_cli_overrides = {
k: v for k, v in file_updates.items() if k not in self.data
}
self.data.update(updates_without_cli_overrides)

def _worker(self, src_path: OptStr = None, dst_path: str = ".", **kwargs) -> Worker:
def _worker(
self, src_path: str | None = None, dst_path: str = ".", **kwargs
) -> Worker:
"""Run Copier's internal API using CLI switches.
Arguments:
Expand Down
30 changes: 11 additions & 19 deletions copier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pathlib import Path
from shutil import rmtree
from tempfile import TemporaryDirectory
from typing import Callable, Iterable, Literal, Mapping, Sequence, get_args
from typing import Any, Callable, Iterable, Literal, Mapping, Sequence, get_args
from unicodedata import normalize

from jinja2.loaders import FileSystemLoader
Expand All @@ -36,15 +36,7 @@
from .subproject import Subproject
from .template import Task, Template
from .tools import OS, Style, normalize_git_path, printf, readlink
from .types import (
MISSING,
AnyByStrDict,
JSONSerializable,
OptStr,
RelativePath,
StrOrPath,
StrSeq,
)
from .types import MISSING, JSONSerializable, RelativePath, StrSeq
from .user_data import DEFAULT_DATA, AnswersMap, Question
from .vcs import get_git

Expand Down Expand Up @@ -162,14 +154,14 @@ class Worker:
src_path: str | None = None
dst_path: Path = Path(".")
answers_file: RelativePath | None = None
vcs_ref: OptStr = None
data: AnyByStrDict = field(default_factory=dict)
vcs_ref: str | None = None
data: dict[str, Any] = field(default_factory=dict)
exclude: StrSeq = ()
use_prereleases: bool = False
skip_if_exists: StrSeq = ()
cleanup_on_error: bool = True
defaults: bool = False
user_defaults: AnyByStrDict = field(default_factory=dict)
user_defaults: dict[str, Any] = field(default_factory=dict)
overwrite: bool = False
pretend: bool = False
quiet: bool = False
Expand Down Expand Up @@ -229,7 +221,7 @@ def _print_message(self, message: str) -> None:
def _answers_to_remember(self) -> Mapping:
"""Get only answers that will be remembered in the copier answers file."""
# All internal values must appear first
answers: AnyByStrDict = {}
answers: dict[str, Any] = {}
commit = self.template.commit
src = self.template.url
for key, value in (("_commit", commit), ("_src_path", src)):
Expand Down Expand Up @@ -999,8 +991,8 @@ def _git_initialize_repo(self):

def run_copy(
src_path: str,
dst_path: StrOrPath = ".",
data: AnyByStrDict | None = None,
dst_path: str | Path = ".",
data: dict[str, Any] | None = None,
**kwargs,
) -> Worker:
"""Copy a template to a destination, from zero.
Expand All @@ -1017,7 +1009,7 @@ def run_copy(


def run_recopy(
dst_path: StrOrPath = ".", data: AnyByStrDict | None = None, **kwargs
dst_path: str | Path = ".", data: dict[str, Any] | None = None, **kwargs
) -> Worker:
"""Update a subproject from its template, discarding subproject evolution.
Expand All @@ -1033,8 +1025,8 @@ def run_recopy(


def run_update(
dst_path: StrOrPath = ".",
data: AnyByStrDict | None = None,
dst_path: str | Path = ".",
data: dict[str, Any] | None = None,
**kwargs,
) -> Worker:
"""Update a subproject, from its template.
Expand Down
8 changes: 4 additions & 4 deletions copier/subproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from dataclasses import field
from functools import cached_property
from pathlib import Path
from typing import Callable
from typing import Any, Callable

import yaml
from plumbum.machines import local
from pydantic.dataclasses import dataclass

from .template import Template
from .types import AbsolutePath, AnyByStrDict, VCSTypes
from .types import AbsolutePath, VCSTypes
from .vcs import get_git, is_in_git_repo


Expand Down Expand Up @@ -51,7 +51,7 @@ def _cleanup(self):
method()

@property
def _raw_answers(self) -> AnyByStrDict:
def _raw_answers(self) -> dict[str, Any]:
"""Get last answers, loaded raw as yaml."""
try:
return yaml.safe_load(
Expand All @@ -61,7 +61,7 @@ def _raw_answers(self) -> AnyByStrDict:
return {}

@cached_property
def last_answers(self) -> AnyByStrDict:
def last_answers(self) -> dict[str, Any]:
"""Last answers, excluding private ones (except _src_path and _commit)."""
return {
key: value
Expand Down
28 changes: 14 additions & 14 deletions copier/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from functools import cached_property
from pathlib import Path
from shutil import rmtree
from typing import Literal, Mapping, Sequence
from typing import Any, Literal, Mapping, Sequence
from warnings import warn

import dunamai
Expand All @@ -29,7 +29,7 @@
UnsupportedVersionError,
)
from .tools import copier_version, handle_remove_readonly
from .types import AnyByStrDict, Env, OptStr, StrSeq, Union, VCSTypes
from .types import Env, StrSeq, VCSTypes
from .vcs import checkout_latest_tag, clone, get_git, get_repo

# Default list of files in the template to exclude from the rendered project
Expand All @@ -47,9 +47,9 @@
DEFAULT_TEMPLATES_SUFFIX = ".jinja"


def filter_config(data: AnyByStrDict) -> tuple[AnyByStrDict, AnyByStrDict]:
def filter_config(data: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
"""Separates config and questions data."""
config_data: AnyByStrDict = {}
config_data: dict[str, Any] = {}
questions_data = {}
for k, v in data.items():
if k.startswith("_"):
Expand All @@ -62,7 +62,7 @@ def filter_config(data: AnyByStrDict) -> tuple[AnyByStrDict, AnyByStrDict]:
return config_data, questions_data


def load_template_config(conf_path: Path, quiet: bool = False) -> AnyByStrDict:
def load_template_config(conf_path: Path, quiet: bool = False) -> dict[str, Any]:
"""Load the `copier.yml` file.
This is like a simple YAML load, but applying all specific quirks needed
Expand Down Expand Up @@ -153,7 +153,7 @@ class Task:
Additional environment variables to set while executing the command.
"""

cmd: Union[str, Sequence[str]]
cmd: str | Sequence[str]
extra_env: Env = field(default_factory=dict)


Expand Down Expand Up @@ -195,7 +195,7 @@ class Template:
"""

url: str
ref: OptStr = None
ref: str | None = None
use_prereleases: bool = False

def _cleanup(self) -> None:
Expand Down Expand Up @@ -231,7 +231,7 @@ def _temp_clone(self) -> Path | None:
return None

@cached_property
def _raw_config(self) -> AnyByStrDict:
def _raw_config(self) -> dict[str, Any]:
"""Get template configuration, raw.
It reads [the `copier.yml` file][the-copieryml-file].
Expand Down Expand Up @@ -260,20 +260,20 @@ def answers_relpath(self) -> Path:
return result

@cached_property
def commit(self) -> OptStr:
def commit(self) -> str | None:
"""If the template is VCS-tracked, get its commit description."""
if self.vcs == "git":
with local.cwd(self.local_abspath):
return get_git()("describe", "--tags", "--always").strip()

@cached_property
def commit_hash(self) -> OptStr:
def commit_hash(self) -> str | None:
"""If the template is VCS-tracked, get its commit full hash."""
if self.vcs == "git":
return get_git()("-C", self.local_abspath, "rev-parse", "HEAD").strip()

@cached_property
def config_data(self) -> AnyByStrDict:
def config_data(self) -> dict[str, Any]:
"""Get config from the template.
It reads [the `copier.yml` file][the-copieryml-file] to get its
Expand Down Expand Up @@ -340,13 +340,13 @@ def message_before_update(self) -> str:
return self.config_data.get("message_before_update", "")

@cached_property
def metadata(self) -> AnyByStrDict:
def metadata(self) -> dict[str, Any]:
"""Get template metadata.
This data, if any, should be saved in the answers file to be able to
restore the template to this same state.
"""
result: AnyByStrDict = {"_src_path": self.url}
result: dict[str, Any] = {"_src_path": self.url}
if self.commit:
result["_commit"] = self.commit
return result
Expand Down Expand Up @@ -399,7 +399,7 @@ def min_copier_version(self) -> Version | None:
return None

@cached_property
def questions_data(self) -> AnyByStrDict:
def questions_data(self) -> dict[str, Any]:
"""Get questions from the template.
See [questions][].
Expand Down
23 changes: 3 additions & 20 deletions copier/types.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
"""Complex types, annotations, validators."""

from __future__ import annotations

import sys
from pathlib import Path
from typing import (
Any,
Dict,
Literal,
Mapping,
NewType,
Optional,
Sequence,
TypeVar,
Union,
)
from typing import Literal, Mapping, NewType, Sequence, TypeVar

from pydantic import AfterValidator

Expand All @@ -21,20 +13,11 @@
else:
from typing_extensions import Annotated

# simple types
StrOrPath = Union[str, Path]
AnyByStrDict = Dict[str, Any]

# sequences
IntSeq = Sequence[int]
StrSeq = Sequence[str]
PathSeq = Sequence[Path]

# optional types
OptBool = Optional[bool]
OptStrOrPath = Optional[StrOrPath]
OptStr = Optional[str]

# miscellaneous
T = TypeVar("T")
JSONSerializable = (dict, list, str, int, float, bool, type(None))
Expand Down
Loading

0 comments on commit 5b2b541

Please sign in to comment.