Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
nicoddemus authored Dec 17, 2024
1 parent 8c5b1b3 commit 05c4c4b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
26 changes: 16 additions & 10 deletions src/pytest_regressions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,22 +197,28 @@ def make_location_message(banner: str, filename: Path, aux_files: List[str]) ->
T = TypeVar("T", bound=Union[MutableSequence, MutableMapping])


def round_digits(data: T, precision: int) -> T:
def round_digits(data: T, digits: int) -> T:
"""
Recursively Round the values of any float value in a collection to the given precision.
:param data: The collection to round.
:param precision: The number of decimal places to round to.
:return: The collection with all float values rounded to the given precision.
Recursively round the values of any float value in a collection to the given number of digits. The rounding is done in-place.
:param data:
The collection to round.
:param digits:
The number of digits to round to.
:return:
The collection with all float values rounded to the given precision.
Note that the rounding is done in-place, so this return value only exists
because we use the function recursively.
"""
# change the generator depending on the collection type
# Change the generator depending on the collection type.
generator = enumerate(data) if isinstance(data, MutableSequence) else data.items()
for k, v in generator:
if isinstance(v, (MutableSequence, MutableMapping)):
data[k] = round_digits(v, precision)
data[k] = round_digits(v, digits)
elif isinstance(v, float):
data[k] = round(v, precision)
data[k] = round(v, digits)
else:
data[k] = v
return data
9 changes: 5 additions & 4 deletions src/pytest_regressions/data_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def check(
data_dict: Dict[str, Any],
basename: Optional[str] = None,
fullpath: Optional["os.PathLike[str]"] = None,
precision: Optional[int] = None,
round_digits: Optional[int] = None,
) -> None:
"""
Checks the given dict against a previously recorded version, or generate a new file.
Expand All @@ -48,14 +48,15 @@ def check(
will ignore ``datadir`` fixture when reading *expected* files but will still use it to
write *obtained* files. Useful if a reference file is located in the session data dir for example.
:param precision: if given, round all floats in the dict to the given number of digits.
:param round_digits:
If given, round all floats in the dict to the given number of digits.
``basename`` and ``fullpath`` are exclusive.
"""
__tracebackhide__ = True

if precision is not None:
round_digits(data_dict, precision)
if round_digits is not None:
round_digits(data_dict, round_digits)

def dump(filename: Path) -> None:
"""Dump dict contents to the given filename"""
Expand Down

0 comments on commit 05c4c4b

Please sign in to comment.