Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support @check_allocs at callsites #54

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 50 additions & 15 deletions src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,22 @@ end

"""
@check_allocs ignore_throw=true (function def)
@check_allocs ignore_throw=true func(...)

Wraps the provided function definition so that all calls to it will be automatically
checked for allocations.

If the check fails, an `AllocCheckFailure` exception is thrown containing the detailed
failures, including the backtrace for each defect.

Note: All calls to the wrapped function are effectively a dynamic dispatch, which
means they are type-unstable and may allocate memory at function _entry_. `@check_allocs`
only guarantees the absence of allocations after the function has started running.
`@check_allocs` can also be applied to a function call, which operates by creating
an anonymous function that is passed to `@check_allocs` and then immediately calling
the wrapped result.

!!! note
All calls to the wrapped function are effectively a dynamic dispatch, which
means they are type-unstable and may allocate memory at function _entry_. `@check_allocs`
only guarantees the absence of allocations after the function has started running.

# Example
```jldoctest
Expand All @@ -45,23 +51,27 @@ julia> multiply(1.5, 3.5) # no allocations for Float64
5.25

julia> multiply(rand(3,3), rand(3,3)) # matmul needs to allocate the result
ERROR: @check_alloc function contains 1 allocations.

ERROR: @check_alloc function contains 1 allocations (1 allocations / 0 dynamic dispatches).
Stacktrace:
[1] macro expansion
@ ~/repos/AllocCheck/src/macro.jl:134 [inlined]
@ ~/.julia/dev/AllocCheck/src/macro.jl:157 [inlined]
[2] multiply(x::Matrix{Float64}, y::Matrix{Float64})
@ Main ./REPL[2]:133
@ Main ./REPL[2]:156
[3] top-level scope
@ REPL[5]:1
@ REPL[4]:1

julia> @check_allocs 1.5 * 3.5 # check a call
5.25
```
"""
macro check_allocs(ex...)
kws, body = extract_keywords(ex)
if _is_func_def(body)
return _check_allocs_macro(body, __module__, __source__; kws...)
return _check_allocs_defun(body, __module__, __source__; kws...)
elseif Meta.isexpr(body, :call)
return _check_allocs_call(body, __module__, __source__; kws...)
else
error("@check_allocs used on something other than a function definition")
error("@check_allocs used on anything other than a function definition or call")
end
end

Expand Down Expand Up @@ -117,13 +127,20 @@ function forward_args!(func_def)
args, kwargs
end

function _check_allocs_macro(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true)
function _check_allocs_defun(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true)
(; original_fn, f_sym, wrapper_fn) = _check_allocs_wrap_fn(ex, mod, source; ignore_throw)
quote
local $f_sym = $(esc(original_fn))
$wrapper_fn
end
end

function _check_allocs_wrap_fn(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true)
# Transform original function to a renamed version with flattened args
def = splitdef(deepcopy(ex))
normalize_args!(def)
original_fn = combinedef(def)
f_sym = haskey(def, :name) ? gensym(def[:name]) : gensym()
f_sym = haskey(def, :name) ? gensym(def[:name]) : gensym("fn_alias")

# Next, create a wrapper function that will compile the original function on-the-fly.
def = splitdef(ex)
Expand All @@ -149,8 +166,26 @@ function _check_allocs_macro(ex::Expr, mod::Module, source::LineNumberNode; igno
def[:body].args[1] = source

wrapper_fn = combinedef(def)
return quote
local $f_sym = $(esc(original_fn))
$(wrapper_fn)

(; original_fn, f_sym, wrapper_fn)
end

function _check_allocs_call(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true)
fn = first(ex.args)
args = ex.args[2:end]
args_template = if !isempty(args) && Meta.isexpr(first(args), :parameters)
kwargs = Expr(:parameters, map(a -> if Meta.isexpr(a, :kw) first(a.args) else a end::Symbol, first(args).args)...)
[kwargs, map(_ -> gensym("arg"), 2:length(args))...]
else
[map(_ -> gensym("arg"), 1:length(args))...]
end
passthrough_defun = Expr(:function, Expr(:tuple, args_template...), Expr(:call, fn, args_template...))
(original_fn, f_sym, wrapper_fn) = _check_allocs_wrap_fn(passthrough_defun, mod, source; ignore_throw)
af_sym = gensym("alloccheck_fn")
quote
let $f_sym = $(esc(original_fn))
$af_sym = $wrapper_fn
$(Expr(:call, af_sym, map(esc, args)...))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would $af_sym($(esc.(args)...)) work just as well?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd think so, it's just a difference in coding style.

end
end
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ end
@test mysum2(x, y) == x + y
@check_allocs (x::Bar)(y::Bar) = x.val + y.val
@test Bar(x)(Bar(y)) == x + y

# Callsite forms
@test 1 + x == @check_allocs 1 + x
@test x^2 == @check_allocs (a -> a^2)(x)
@test_throws AllocCheck.AllocCheckFailure @check_allocs same_ccall()
Copy link
Member

@topolarity topolarity Jan 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the spirit of your recent issues, I double-checked and I don't think this handles kwargs correctly right now:

julia> foo(a,b;c=1) = a + b + c
foo (generic function with 1 method)
julia> @check_allocs foo(1,2;c=50)
4
julia> foo(1,2;c=50)
53

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for doing this check! I'm currently trying to work out why exactly this is happening. With @macroexpand and a @show invocation, this is what I see with your example:

@show local variables
ex = :(foo(1, 2; c = 30))
mod = Main
source = :(#= REPL[16]:1 =#)
args = Any[:($(Expr(:parameters, :($(Expr(:kw, :c, 30)))))), 1, 2]
args_template = Any[:($(Expr(:parameters, :c))), Symbol("##arg#255"), Symbol("##arg#256")]
passthrough_defun = :(function (var"##arg#255", var"##arg#256"; c)
      foo(var"##arg#255", var"##arg#256"; c)
  end)
original_fn = :(function (var"##arg#255", var"##arg#256", c;)
      foo(var"##arg#255", var"##arg#256"; )
  end)
f_sym = Symbol("##fn_alias#257")
wrapper_fn = :(function ($(Expr(:escape, :(var"##arg#255"::Any))), $(Expr(:escape, :(var"##arg#256"::Any))); $(Expr(:escape, :c)))
      #= REPL[16]:1 =#
      callable_tt = Tuple{map(Core.Typeof, ($(Expr(:escape, Symbol("##arg#255"))), $(Expr(:escape, Symbol("##arg#256"))), $(Expr(:escape, :c))))...}
      #= /home/tec/.julia/dev/AllocCheck/src/macro.jl:157 =#
      callable = (AllocCheck.compile_callable)(var"##fn_alias#257", callable_tt; ignore_throw = true)
      #= /home/tec/.julia/dev/AllocCheck/src/macro.jl:158 =#
      if length(callable.analysis) > 0
          #= /home/tec/.julia/dev/AllocCheck/src/macro.jl:159 =#
          throw(AllocCheckFailure(callable.analysis))
      end
      #= /home/tec/.julia/dev/AllocCheck/src/macro.jl:161 =#
      callable($(Expr(:escape, Symbol("##arg#255"))), $(Expr(:escape, Symbol("##arg#256"))), $(Expr(:escape, :c)))
  end)

@macroexpand
quote
    #= /home/tec/.julia/dev/AllocCheck/src/macro.jl:188 =#
    let var"#98###fn_alias#257" = function (var"##arg#255", var"##arg#256", c;)
                foo(var"##arg#255", var"##arg#256"; )
            end
        #= /home/tec/.julia/dev/AllocCheck/src/macro.jl:189 =#
        var"#99###alloccheck_fn#260" = function (var"##arg#255"::Any, var"##arg#256"::Any; c)
                #= REPL[16]:1 =#
                var"#102#callable_tt" = AllocCheck.Tuple{AllocCheck.map((AllocCheck.Core).Typeof, (var"##arg#255", var"##arg#256", c))...}
                #= /home/tec/.julia/dev/AllocCheck/src/macro.jl:157 =#
                var"#103#callable" = (AllocCheck.compile_callable)(var"#98###fn_alias#257", var"#102#callable_tt"; ignore_throw = true)
                #= /home/tec/.julia/dev/AllocCheck/src/macro.jl:158 =#
                if AllocCheck.length((var"#103#callable").analysis) > 0
                    #= /home/tec/.julia/dev/AllocCheck/src/macro.jl:159 =#
                    AllocCheck.throw(AllocCheck.AllocCheckFailure((var"#103#callable").analysis))
                end
                #= /home/tec/.julia/dev/AllocCheck/src/macro.jl:161 =#
                var"#103#callable"(var"##arg#255", var"##arg#256", c)
            end
        #= /home/tec/.julia/dev/AllocCheck/src/macro.jl:190 =#
        var"#99###alloccheck_fn#260"(1, 2; c = 30)
    end
end

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be clear, I'm a bit stuck on this, and help from someone else would be good to move this forward.

end


Expand Down