diff --git a/src/macro.jl b/src/macro.jl index bf46a31..3cf45d1 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -25,6 +25,7 @@ end """ @check_allocs ignore_throw=true (function def) + @check_allocs ignore_throw=true func(...) Wraps the provided function definition so that all calls to it will be automatically checked for allocations. @@ -32,9 +33,14 @@ checked for allocations. If the check fails, an `AllocCheckFailure` exception is thrown containing the detailed failures, including the backtrace for each defect. -Note: All calls to the wrapped function are effectively a dynamic dispatch, which -means they are type-unstable and may allocate memory at function _entry_. `@check_allocs` -only guarantees the absence of allocations after the function has started running. +`@check_allocs` can also be applied to a function call, which operates by creating +an anonymous function that is passed to `@check_allocs` and then immediately calling +the wrapped result. + +!!! note + All calls to the wrapped function are effectively a dynamic dispatch, which + means they are type-unstable and may allocate memory at function _entry_. `@check_allocs` + only guarantees the absence of allocations after the function has started running. # Example ```jldoctest @@ -45,23 +51,27 @@ julia> multiply(1.5, 3.5) # no allocations for Float64 5.25 julia> multiply(rand(3,3), rand(3,3)) # matmul needs to allocate the result -ERROR: @check_alloc function contains 1 allocations. - +ERROR: @check_alloc function contains 1 allocations (1 allocations / 0 dynamic dispatches). Stacktrace: [1] macro expansion - @ ~/repos/AllocCheck/src/macro.jl:134 [inlined] + @ ~/.julia/dev/AllocCheck/src/macro.jl:157 [inlined] [2] multiply(x::Matrix{Float64}, y::Matrix{Float64}) - @ Main ./REPL[2]:133 + @ Main ./REPL[2]:156 [3] top-level scope - @ REPL[5]:1 + @ REPL[4]:1 + +julia> @check_allocs 1.5 * 3.5 # check a call +5.25 ``` """ macro check_allocs(ex...) kws, body = extract_keywords(ex) if _is_func_def(body) - return _check_allocs_macro(body, __module__, __source__; kws...) + return _check_allocs_defun(body, __module__, __source__; kws...) + elseif Meta.isexpr(body, :call) + return _check_allocs_call(body, __module__, __source__; kws...) else - error("@check_allocs used on something other than a function definition") + error("@check_allocs used on anything other than a function definition or call") end end @@ -117,13 +127,20 @@ function forward_args!(func_def) args, kwargs end -function _check_allocs_macro(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true) +function _check_allocs_defun(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true) + (; original_fn, f_sym, wrapper_fn) = _check_allocs_wrap_fn(ex, mod, source; ignore_throw) + quote + local $f_sym = $(esc(original_fn)) + $wrapper_fn + end +end +function _check_allocs_wrap_fn(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true) # Transform original function to a renamed version with flattened args def = splitdef(deepcopy(ex)) normalize_args!(def) original_fn = combinedef(def) - f_sym = haskey(def, :name) ? gensym(def[:name]) : gensym() + f_sym = haskey(def, :name) ? gensym(def[:name]) : gensym("fn_alias") # Next, create a wrapper function that will compile the original function on-the-fly. def = splitdef(ex) @@ -149,8 +166,26 @@ function _check_allocs_macro(ex::Expr, mod::Module, source::LineNumberNode; igno def[:body].args[1] = source wrapper_fn = combinedef(def) - return quote - local $f_sym = $(esc(original_fn)) - $(wrapper_fn) + + (; original_fn, f_sym, wrapper_fn) +end + +function _check_allocs_call(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true) + fn = first(ex.args) + args = ex.args[2:end] + args_template = if !isempty(args) && Meta.isexpr(first(args), :parameters) + kwargs = Expr(:parameters, map(a -> if Meta.isexpr(a, :kw) first(a.args) else a end::Symbol, first(args).args)...) + [kwargs, map(_ -> gensym("arg"), 2:length(args))...] + else + [map(_ -> gensym("arg"), 1:length(args))...] + end + passthrough_defun = Expr(:function, Expr(:tuple, args_template...), Expr(:call, fn, args_template...)) + (original_fn, f_sym, wrapper_fn) = _check_allocs_wrap_fn(passthrough_defun, mod, source; ignore_throw) + af_sym = gensym("alloccheck_fn") + quote + let $f_sym = $(esc(original_fn)) + $af_sym = $wrapper_fn + $(Expr(:call, af_sym, map(esc, args)...)) + end end end diff --git a/test/runtests.jl b/test/runtests.jl index 455ec48..ef30930 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -141,6 +141,11 @@ end @test mysum2(x, y) == x + y @check_allocs (x::Bar)(y::Bar) = x.val + y.val @test Bar(x)(Bar(y)) == x + y + + # Callsite forms + @test 1 + x == @check_allocs 1 + x + @test x^2 == @check_allocs (a -> a^2)(x) + @test_throws AllocCheck.AllocCheckFailure @check_allocs same_ccall() end