diff --git a/handcalcs/decorator.py b/handcalcs/decorator.py index 814a231..a0ebdaf 100644 --- a/handcalcs/decorator.py +++ b/handcalcs/decorator.py @@ -1,7 +1,7 @@ __all__ = ["handcalc"] -from typing import Optional -from functools import wraps +from typing import Optional, Callable +from functools import wraps, update_wrapper import inspect import innerscope from .handcalcs import LatexRenderer @@ -13,43 +13,112 @@ def handcalc( left: str = "", right: str = "", scientific_notation: Optional[bool] = None, - decimal_separator: str = ".", jupyter_display: bool = False, + record: bool = False, ): def handcalc_decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - line_args = { - "override": override, - "precision": precision, - "sci_not": scientific_notation, - } - func_source = inspect.getsource(func) - cell_source = _func_source_to_cell(func_source) - # use innerscope to get the values of locals, closures, and globals when calling func - scope = innerscope.call(func, *args, **kwargs) - LatexRenderer.dec_sep = decimal_separator - renderer = LatexRenderer(cell_source, scope, line_args) - latex_code = renderer.render() - if jupyter_display: - try: - from IPython.display import Latex, display - except ModuleNotFoundError: - ModuleNotFoundError( - "jupyter_display option requires IPython.display to be installed." - ) - display(Latex(latex_code)) - return scope.return_value + if record: + decorated = HandcalcsCallRecorder( + func, + override, + precision, + left, + right, + scientific_notation, + jupyter_display, + ) + else: - # https://stackoverflow.com/questions/9943504/right-to-left-string-replace-in-python - latex_code = "".join(latex_code.replace("\\[", "", 1).rsplit("\\]", 1)) - return (left + latex_code + right, scope.return_value) + @wraps(func) + def decorated(*args, **kwargs): + line_args = { + "override": override, + "precision": precision, + "sci_not": scientific_notation, + } + func_source = inspect.getsource(func) + cell_source = _func_source_to_cell(func_source) + # innerscope retrieves values of locals, closures, and globals + scope = innerscope.call(func, *args, **kwargs) + renderer = LatexRenderer(cell_source, scope, line_args) + latex_code = renderer.render() + raw_latex_code = "".join( + latex_code.replace("\\[", "", 1).rsplit("\\]", 1) + ) + if jupyter_display: + try: + from IPython.display import Latex, display + except ModuleNotFoundError: + ModuleNotFoundError( + "jupyter_display option requires IPython.display to be installed." + ) + display(Latex(latex_code)) + return scope.return_value + return (left + raw_latex_code + right, scope.return_value) - return wrapper + return decorated return handcalc_decorator +class HandcalcsCallRecorder: + """ + Records function calls for the func stored in .callable + """ + + def __init__( + self, + func: Callable, + _override: str = "", + _precision: int = 3, + _left: str = "", + _right: str = "", + _scientific_notation: Optional[bool] = None, + _jupyter_display: bool = False, + ): + self.callable = func + self.history = list() + self._override = _override + self._precision = _precision + self._left = _left + self._right = _right + self._scientific_notation = _scientific_notation + self._jupyter_display = _jupyter_display + update_wrapper(self, func) + + def __repr__(self): + return f"{self.__class__.__name__}({self.callable.__name__}, num_of_calls: {len(self.history)})" + + @property + def calls(self): + return len(self.history) + + def __call__(self, *args, **kwargs): + line_args = { + "override": self._override, + "precision": self._precision, + "sci_not": self._scientific_notation, + } + func_source = inspect.getsource(self.callable) + cell_source = _func_source_to_cell(func_source) + # innerscope retrieves values of locals, closures, and globals + scope = innerscope.call(self.callable, *args, **kwargs) + renderer = LatexRenderer(cell_source, scope, line_args) + latex_code = renderer.render() + raw_latex_code = "".join(latex_code.replace("\\[", "", 1).rsplit("\\]", 1)) + self.history.append({"return": scope.return_value, "latex": raw_latex_code}) + if self._jupyter_display: + try: + from IPython.display import Latex, display + except ModuleNotFoundError: + ModuleNotFoundError( + "jupyter_display option requires IPython.display to be installed." + ) + display(Latex(latex_code)) + return scope.return_value + return (self._left + raw_latex_code + self._right, scope.return_value) + + def _func_source_to_cell(source: str): """ Returns a string that represents `source` but with no signature, doc string, diff --git a/handcalcs/handcalcs.py b/handcalcs/handcalcs.py index 0412d12..e3799a5 100644 --- a/handcalcs/handcalcs.py +++ b/handcalcs/handcalcs.py @@ -173,8 +173,6 @@ def dict_get(d: dict, item: Any) -> Any: # The renderer class ("output" class) class LatexRenderer: - # dec_sep = "." - def __init__(self, python_code_str: str, results: dict, line_args: dict): self.source = python_code_str self.results = results diff --git a/test_handcalcs/test_decorator_file.py b/test_handcalcs/test_decorator_file.py index e69de29..4af0cae 100644 --- a/test_handcalcs/test_decorator_file.py +++ b/test_handcalcs/test_decorator_file.py @@ -0,0 +1,36 @@ +from handcalcs.decorator import HandcalcsCallRecorder, handcalc +import pytest + +# Define a simple arithmetic function for testing +def simple_func(a: float, b: float) -> float: + c = a + b + return c + +@pytest.fixture +def recorder(): + return HandcalcsCallRecorder(simple_func) + +def test_simple_arithmetic(recorder): + result = recorder(1, 2) + assert result[1] == 3 + assert result[0] == '\n\\begin{aligned}\nc &= a + b = 1 + 2 &= 3 \n\\end{aligned}\n' + +def test_call_recording(recorder): + recorder(1.0, 2.0) + recorder(3.5, 4.5) + assert recorder.calls == 2 # There should be two recorded calls. + assert recorder.history[0]['return'] == 3.0 + assert recorder.history[1]['return'] == 8.0 + +def test_decorator_with_recording(): + decorated_func = handcalc(record=True)(simple_func) + result = decorated_func(1.0, 2.0) + assert result[1] == 3.0 + assert decorated_func.calls == 1 + assert decorated_func.history[0]['return'] == 3.0 + assert decorated_func.history[0]['latex'] == '\n\\begin{aligned}\nc &= a + b = 1.000 + 2.000 &= 3.000 \n\\end{aligned}\n' + + decorated_func = handcalc(record=False)(simple_func) + latex, result = decorated_func(1.0, 2.0) + assert result == 3.0 + assert latex == '\n\\begin{aligned}\nc &= a + b = 1.000 + 2.000 &= 3.000 \n\\end{aligned}\n'