Skip to content

Commit

Permalink
replace macros in Gaius with Base.threads macros for SnpLinAlg (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
biona001 authored Mar 2, 2022
1 parent db152fd commit 9bc2da3
Showing 1 changed file with 28 additions and 44 deletions.
72 changes: 28 additions & 44 deletions src/linalg_direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,6 @@ function Base.getindex(s::SnpLinAlg{T}, i::Int, j::Int) where T
return x
end

# macros taken from Gaius.jl
macro _spawn(ex)
if Threads.nthreads() > 1
esc(Expr(:macrocall, Expr(:(.), :Threads, QuoteNode(Symbol("@spawn"))), __source__, ex))
else
esc(ex)
end
end
macro _sync(ex)
if Threads.nthreads() > 1
esc(Expr(:macrocall, Symbol("@sync"), __source__, ex))
else
esc(ex)
end
end

"""
LinearAlgebra.mul!(out, sla::SnpLinAlg, v)
Expand Down Expand Up @@ -213,11 +197,11 @@ function _snparray_ax_tile!(c, A, b, model, μ, impute, rows_filled)
Nrem = N & (hstep - 1)
taskarray = Array{Any}(undef, Miter + 1)
fill!(taskarray, nothing)
@_sync begin
@sync begin
GC.@preserve c A b for n in 0:Niter - 1
for m in 0:Miter - 1
wait(taskarray[m+1])
taskarray[m+1] = @_spawn _ftn!(
taskarray[m+1] = Threads.@spawn _ftn!(
gesp(stridedpointer(c), (4 * vstep * m,)),
gesp(stridedpointer(A), (vstep * m, hstep * n)),
gesp(stridedpointer(b), (hstep * n,)),
Expand All @@ -226,7 +210,7 @@ function _snparray_ax_tile!(c, A, b, model, μ, impute, rows_filled)
end
if Mrem != 0
wait(taskarray[Miter+1])
taskarray[Miter+1] = @_spawn _ftn!(
taskarray[Miter+1] = Threads.@spawn _ftn!(
@view(c[4 * vstep * Miter + 1:end]),
@view(A[vstep * Miter + 1:end, hstep * n + 1:hstep * (n + 1)]),
@view(b[hstep * n + 1:hstep * (n + 1)]),
Expand All @@ -238,7 +222,7 @@ function _snparray_ax_tile!(c, A, b, model, μ, impute, rows_filled)
if Nrem != 0
for m in 0:Miter-1
wait(taskarray[m+1])
taskarray[m+1] = @_spawn _ftn!(
taskarray[m+1] = Threads.@spawn _ftn!(
@view(c[4 * vstep * m + 1:4 * vstep * (m + 1)]),
@view(A[vstep * m + 1:vstep * (m + 1), hstep * Niter + 1:end]),
@view(b[hstep * Niter + 1:end]),
Expand All @@ -248,7 +232,7 @@ function _snparray_ax_tile!(c, A, b, model, μ, impute, rows_filled)
end
if Mrem != 0
wait(taskarray[Miter + 1])
taskarray[Miter + 1] = @_spawn _ftn!(
taskarray[Miter + 1] = Threads.@spawn _ftn!(
@view(c[4 * vstep * Miter+1:end]),
@view(A[vstep * Miter + 1:end, hstep * Niter + 1:end]),
@view(b[hstep * Niter + 1:end]),
Expand Down Expand Up @@ -298,12 +282,12 @@ function _snparray_AX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
Prem = P & (pstep - 1)
taskarray = Array{Any}(undef, Miter + 1)
fill!(taskarray, nothing)
@_sync begin
@sync begin
GC.@preserve C A B for p in 0:Piter - 1
for n in 0:Niter - 1
for m in 0:Miter - 1
wait(taskarray[m+1])
taskarray[m+1] = @_spawn _ftn!(
taskarray[m+1] = Threads.@spawn _ftn!(
@view(C[4 * vstep * m + 1:4 * vstep * (m + 1), pstep * p + 1:pstep * (p + 1)]),
@view(A[vstep * m + 1:vstep * (m + 1), hstep * n + 1:hstep * (n + 1)]),
@view(B[hstep * n + 1:hstep * (n + 1), pstep * p + 1:pstep * (p + 1)]),
Expand All @@ -315,7 +299,7 @@ function _snparray_AX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
end
if Mrem != 0
wait(taskarray[Miter+1])
taskarray[Miter+1] = @_spawn _ftn!(
taskarray[Miter+1] = Threads.@spawn _ftn!(
@view(C[4 * vstep * Miter + 1:end, pstep * p + 1:pstep * (p + 1)]),
@view(A[vstep * Miter + 1:end, hstep * n + 1:hstep * (n + 1)]),
@view(B[hstep * n + 1:hstep * (n + 1), pstep * p + 1:pstep * (p + 1)]),
Expand All @@ -329,7 +313,7 @@ function _snparray_AX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
if Nrem != 0
for m in 0:Miter-1
wait(taskarray[m+1])
taskarray[m+1] = @_spawn _ftn!(
taskarray[m+1] = Threads.@spawn _ftn!(
@view(C[4 * vstep * m + 1:4 * vstep * (m + 1), pstep * p + 1:pstep * (p + 1)]),
@view(A[vstep * m + 1:vstep * (m + 1), hstep * Niter + 1:end]),
@view(B[hstep * Niter + 1:end, pstep * p + 1:pstep * (p + 1)]),
Expand All @@ -341,7 +325,7 @@ function _snparray_AX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
end
if Mrem != 0
wait(taskarray[Miter + 1])
taskarray[Miter + 1] = @_spawn _ftn!(
taskarray[Miter + 1] = Threads.@spawn _ftn!(
@view(C[4 * vstep * Miter+1:end, pstep * p + 1:pstep * (p + 1)]),
@view(A[vstep * Miter + 1:end, hstep * Niter + 1:end]),
@view(B[hstep * Niter + 1:end, pstep * p + 1:pstep * (p + 1)]),
Expand All @@ -357,7 +341,7 @@ function _snparray_AX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
for n in 0:Niter - 1
for m in 0:Miter - 1
wait(taskarray[m+1])
taskarray[m+1] = @_spawn _ftn!(
taskarray[m+1] = Threads.@spawn _ftn!(
@view(C[4 * vstep * m + 1:4 * vstep * (m + 1), pstep * Piter + 1:end]),
@view(A[vstep * m + 1:vstep * (m + 1), hstep * n + 1:hstep * (n + 1)]),
@view(B[hstep * n + 1:hstep * (n + 1), pstep * Piter + 1:end]),
Expand All @@ -369,7 +353,7 @@ function _snparray_AX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
end
if Mrem != 0
wait(taskarray[Miter+1])
taskarray[Miter+1] = @_spawn _ftn!(
taskarray[Miter+1] = Threads.@spawn _ftn!(
@view(C[4 * vstep * Miter + 1:end, pstep * Piter + 1:end]),
@view(A[vstep * Miter + 1:end, hstep * n + 1:hstep * (n + 1)]),
@view(B[hstep * n + 1:hstep * (n + 1), pstep * Piter + 1:end]),
Expand All @@ -383,7 +367,7 @@ function _snparray_AX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
if Nrem != 0
for m in 0:Miter-1
wait(taskarray[m+1])
taskarray[m+1] = @_spawn _ftn!(
taskarray[m+1] = Threads.@spawn _ftn!(
@view(C[4 * vstep * m + 1:4 * vstep * (m + 1), pstep * Piter + 1:end]),
@view(A[vstep * m + 1:vstep * (m + 1), hstep * Niter + 1:end]),
@view(B[hstep * Niter + 1:end, pstep * Piter + 1:end]),
Expand All @@ -395,7 +379,7 @@ function _snparray_AX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
end
if Mrem != 0
wait(taskarray[Miter + 1])
taskarray[Miter + 1] = @_spawn _ftn!(
taskarray[Miter + 1] = Threads.@spawn _ftn!(
@view(C[4 * vstep * Miter+1:end, pstep * Piter + 1:end]),
@view(A[vstep * Miter + 1:end, hstep * Niter + 1:end]),
@view(B[hstep * Niter + 1:end, pstep * Piter + 1:end]),
Expand Down Expand Up @@ -442,11 +426,11 @@ function _snparray_atx_tile!(c, A, b, model, μ, impute, rows_filled)
Nrem = N & (hstep - 1)
taskarray = Array{Any}(undef, Niter+1)
fill!(taskarray, nothing)
@_sync begin
@sync begin
GC.@preserve c A b for m in 0:Miter - 1
for n in 0:Niter - 1
wait(taskarray[n + 1])
taskarray[n + 1] = @_spawn _ftn!(
taskarray[n + 1] = Threads.@spawn _ftn!(
gesp(stridedpointer(c), (hstep * n,)),
gesp(stridedpointer(A), (vstep * m, hstep * n)),
gesp(stridedpointer(b), (4 * vstep * m,)),
Expand All @@ -455,7 +439,7 @@ function _snparray_atx_tile!(c, A, b, model, μ, impute, rows_filled)
end
if Nrem != 0
wait(taskarray[Niter + 1])
taskarray[Niter + 1] = @_spawn _ftn!(
taskarray[Niter + 1] = Threads.@spawn _ftn!(
@view(c[hstep * Niter + 1:end]),
@view(A[vstep * m + 1:vstep * (m + 1), hstep * Niter + 1:end]),
@view(b[4 * vstep * m + 1:4 * vstep * (m + 1)]),
Expand All @@ -467,7 +451,7 @@ function _snparray_atx_tile!(c, A, b, model, μ, impute, rows_filled)
if Mrem != 0
for n in 0:Niter - 1
wait(taskarray[n + 1])
taskarray[n + 1] = @_spawn _ftn!(
taskarray[n + 1] = Threads.@spawn _ftn!(
@view(c[hstep * n + 1:hstep * (n + 1)]),
@view(A[vstep * Miter + 1:end, hstep * n + 1:hstep * (n + 1)]),
@view(b[4 * vstep * Miter + 1:end]),
Expand All @@ -477,7 +461,7 @@ function _snparray_atx_tile!(c, A, b, model, μ, impute, rows_filled)
end
if Nrem != 0
wait(taskarray[Niter + 1])
taskarray[Niter + 1] = @_spawn _ftn!(
taskarray[Niter + 1] = Threads.@spawn _ftn!(
@view(c[hstep * Niter + 1:end]),
@view(A[vstep * Miter + 1:end, hstep * Niter + 1:end]),
@view(b[4 * vstep * Miter + 1:end]),
Expand Down Expand Up @@ -527,12 +511,12 @@ function _snparray_AtX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
Prem = P & (pstep - 1)
taskarray = Array{Any}(undef, Niter + 1)
fill!(taskarray, nothing)
@_sync begin
@sync begin
GC.@preserve C A B for p in 0:Piter - 1
for m in 0:Miter - 1
for n in 0:Niter - 1
wait(taskarray[n + 1])
taskarray[n + 1] = @_spawn _ftn!(
taskarray[n + 1] = Threads.@spawn _ftn!(
@view(C[hstep * n + 1:hstep * (n + 1), pstep * p + 1:pstep * (p + 1)]),
@view(A[vstep * m + 1:vstep * (m + 1), hstep * n + 1:hstep * (n + 1)]),
@view(B[4 * vstep * m + 1:4 * vstep * (m + 1), pstep * p + 1:pstep * (p + 1)]),
Expand All @@ -544,7 +528,7 @@ function _snparray_AtX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
end
if Nrem != 0
wait(taskarray[Niter + 1])
taskarray[Niter + 1] = @_spawn _ftn!(
taskarray[Niter + 1] = Threads.@spawn _ftn!(
@view(C[hstep * Niter + 1:end, pstep * p + 1:pstep * (p + 1)]),
@view(A[vstep * m + 1:vstep * (m + 1), hstep * Niter + 1:end]),
@view(B[4 * vstep * m + 1:4 * vstep * (m + 1), pstep * p + 1:pstep * (p + 1)]),
Expand All @@ -558,7 +542,7 @@ function _snparray_AtX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
if Mrem != 0
for n in 0:Niter - 1
wait(taskarray[n + 1])
taskarray[n + 1] = @_spawn _ftn!(
taskarray[n + 1] = Threads.@spawn _ftn!(
@view(C[hstep * n + 1:hstep * (n + 1), pstep * p + 1:pstep * (p + 1)]),
@view(A[vstep * Miter + 1:end, hstep * n + 1:hstep * (n + 1)]),
@view(B[4 * vstep * Miter + 1:end, pstep * p + 1:pstep * (p + 1)]),
Expand All @@ -570,7 +554,7 @@ function _snparray_AtX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
end
if Nrem != 0
wait(taskarray[Niter + 1])
taskarray[Niter + 1] = @_spawn _ftn!(
taskarray[Niter + 1] = Threads.@spawn _ftn!(
@view(C[hstep * Niter + 1:end, pstep * p + 1:pstep * (p + 1)]),
@view(A[vstep * Miter + 1:end, hstep * Niter + 1:end]),
@view(B[4 * vstep * Miter + 1:end, pstep * p + 1:pstep * (p + 1)]),
Expand All @@ -586,7 +570,7 @@ function _snparray_AtX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
for m in 0:Miter - 1
for n in 0:Niter - 1
wait(taskarray[n + 1])
taskarray[n + 1] = @_spawn _ftn!(
taskarray[n + 1] = Threads.@spawn _ftn!(
@view(C[hstep * n + 1:hstep * (n + 1), pstep * Piter + 1:end]),
@view(A[vstep * m + 1:vstep * (m + 1), hstep * n + 1:hstep * (n + 1)]),
@view(B[4 * vstep * m + 1:4 * vstep * (m + 1), pstep * Piter + 1:end]),
Expand All @@ -598,7 +582,7 @@ function _snparray_AtX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
end
if Nrem != 0
wait(taskarray[Niter + 1])
taskarray[Niter + 1] = @_spawn _ftn!(
taskarray[Niter + 1] = Threads.@spawn _ftn!(
@view(C[hstep * Niter + 1:end, pstep * Piter + 1:end]),
@view(A[vstep * m + 1:vstep * (m + 1), hstep * Niter + 1:end]),
@view(B[4 * vstep * m + 1:4 * vstep * (m + 1), pstep * Piter + 1:end]),
Expand All @@ -612,7 +596,7 @@ function _snparray_AtX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
if Mrem != 0
for n in 0:Niter - 1
wait(taskarray[n + 1])
taskarray[n + 1] = @_spawn _ftn!(
taskarray[n + 1] = Threads.@spawn _ftn!(
@view(C[hstep * n + 1:hstep * (n + 1), pstep * Piter + 1:end]),
@view(A[vstep * Miter + 1:end, hstep * n + 1:hstep * (n + 1)]),
@view(B[4 * vstep * Miter + 1:end, pstep * Piter + 1:end]),
Expand All @@ -624,7 +608,7 @@ function _snparray_AtX_tile!(C, A, B, model, μ, μimpute, impute, rows_filled,
end
if Nrem != 0
wait(taskarray[Niter + 1])
taskarray[Niter + 1] = @_spawn _ftn!(
taskarray[Niter + 1] = Threads.@spawn _ftn!(
@view(C[hstep * Niter + 1:end, pstep * Piter + 1:end]),
@view(A[vstep * Miter + 1:end, hstep * Niter + 1:end]),
@view(B[4 * vstep * Miter + 1:end, pstep * Piter + 1:end]),
Expand Down

0 comments on commit 9bc2da3

Please sign in to comment.