Skip to content

Commit

Permalink
Implement show_locals feature
Browse files Browse the repository at this point in the history
  • Loading branch information
dhzdhd committed Aug 30, 2024
1 parent 9ce478d commit aa46424
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 74 deletions.
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Check the [examples directory](https://github.com/dhzdhd/pysvt/tree/master/examp
```python
from pysvt import test

@test("<path_to_TOML_file>")
@test(file="<path_to_TOML_file>")
def function(arg1: int, arg2: int) -> int:
return arg1 + arg2
```
Expand Down Expand Up @@ -64,7 +64,7 @@ Check the [examples directory](https://github.com/dhzdhd/pysvt/tree/master/examp
from pysvt import test

# Specify the name of the method as the second argument
@test("<path_to_TOML_file>", "function")
@test(file="<path_to_TOML_file>", "function")
class Solution:
def function(self, arg1: int, arg2: int) -> int:
return arg1 + arg2
Expand Down Expand Up @@ -104,3 +104,10 @@ Check the [examples directory](https://github.com/dhzdhd/pysvt/tree/master/examp
## Running examples

`poetry run python -m examples.<example_file_name>`

## Creating a new release

- Update the version in `pyproject.toml`
- Update `CHANGELOG.md`
- Run local lint/format/tests
- Create new git tag and push
2 changes: 1 addition & 1 deletion examples/inspect_locals.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
def hello():
a = 5

for _ in range(5):
for i in range(5):
a += 1
print("Hello, World!")

Expand Down
57 changes: 0 additions & 57 deletions examples/inspect_locals_deco.py

This file was deleted.

35 changes: 21 additions & 14 deletions pysvt/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rich.console import Console

from pysvt.utils.ctx import Timer
from pysvt.utils.validation import get_result_locals
from pysvt.utils.models import Result, _ClsModel, _FuncModel
from pysvt.utils.printer import Printer

Expand Down Expand Up @@ -74,9 +75,6 @@ def __init__(
redirect_stdout: bool = True,
show_locals: bool = False,
) -> None:
if show_locals:
raise NotImplementedError("show_locals has not been implemented yet")

if (file is None and data is None) or (file is not None and data is not None):
raise ValueError("Either of file or data argument should be filled")

Expand Down Expand Up @@ -118,7 +116,6 @@ def __call__(self, obj: object) -> Any:
"The decorator cannot be applied to non-instance methods. Instead, use it directly on the function"
)

# with self._printer.init() as _:
failures = 0

for index, data in enumerate(self._data.data):
Expand All @@ -137,7 +134,6 @@ def __call__(self, obj: object) -> Any:
"The decorator cannot be applied to instance methods. Instead, apply it on the class and pass the name of the method as an argument"
)

# with self._printer.init() as _:
failures = 0

for index, data in enumerate(self._data):
Expand Down Expand Up @@ -306,6 +302,7 @@ def _validate(self, data: _FuncModel, func: Callable[..., Any]) -> Result:
"""
partial_fn = partial(func)
stdout = None
local_vars = None

if data.inputs is not None:
if not isinstance(data.inputs, list):
Expand All @@ -317,29 +314,39 @@ def _validate(self, data: _FuncModel, func: Callable[..., Any]) -> Result:
try:
if self._redirect_stdout:
with redirect_stdout(StringIO()) as f:
result = partial_fn()
if self._show_locals:
result, local_vars = get_result_locals(partial_fn)
else:
result = partial_fn()
stdout = f.getvalue()
else:
result = partial_fn()
if self._show_locals:
result, local_vars = get_result_locals(partial_fn)
else:
result = partial_fn()
except Exception:
console.print_exception(show_locals=True)
else:
if self._redirect_stdout:
with redirect_stdout(StringIO()) as f:
result = partial_fn()
if self._show_locals:
result, local_vars = get_result_locals(partial_fn)
else:
result = partial_fn()
stdout = f.getvalue()
else:
result = partial_fn()
if self._show_locals:
result, local_vars = get_result_locals(partial_fn)
else:
result = partial_fn()

if self._postprocess is not None:
result = self._postprocess(result)

return Result(result, stdout, result == data.output)
return Result(result, stdout, result == data.output, local_vars)


class inspect_locals:
def __init__(self) -> None:
...
def __init__(self) -> None: ...

def __call__(self, obj: object) -> Any:
...
def __call__(self, obj: object) -> Any: ...
1 change: 1 addition & 0 deletions pysvt/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ class Result:
data: Any
stdout: str | None
valid: bool
local_vars: dict[str, Any] | None
8 changes: 8 additions & 0 deletions pysvt/utils/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def post_validation(
input_str = "\n".join(
map(lambda t: f" {t[0]} - {t[1]}", zip(input_args, data.inputs))
)
input_str = " None" if input_str.strip() == "" else input_str

exp_out_str = f"""{Printer.bold("Expected output")} - {data.output}"""
act_out_str = f"""{Printer.bold("Actual output")} - {res.data}"""

Expand All @@ -83,6 +85,12 @@ def post_validation(
if res.stdout is not None and res.stdout.strip() != "":
out_str += f"""\n\n{Printer.bold("Stdout")} -\n{res.stdout.strip()}"""

if res.local_vars is not None:
out_str += f"\n\n{Printer.bold("Local variables")} -"

for k, v in res.local_vars.items():
out_str += f"\n {k} - {v}"

emoji = ":white_check_mark:" if res.valid else ":cross_mark:"
time_str = (
f"{time_taken * 1000:.3f} ms" if time_taken < 1.0 else f"{time_taken:.3f} s"
Expand Down
25 changes: 25 additions & 0 deletions pysvt/utils/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import sys
import types
from typing import Callable, Any


def get_result_locals(
func: Callable[..., Any], *args, **kwargs
) -> tuple[Any, dict[str, Any]]:
frame: types.FrameType | None = None
trace = sys.gettrace()

def snatch_locals(_frame, name, arg):
nonlocal frame
if frame is None and name == "call":
frame = _frame
sys.settrace(trace)
return trace

sys.settrace(snatch_locals)
try:
result = func(*args, **kwargs)
finally:
sys.settrace(trace)

return (result, frame.f_locals)

0 comments on commit aa46424

Please sign in to comment.