Merge pull request #4 from JuliaParallel/jps/threadsafe_workerstate
Such threadsafe, much wow
JamesWrigley authored Nov 14, 2024
2 parents 76df474 + c1a3be8 commit b9a8000
Showing 14 changed files with 1,907 additions and 1,767 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -54,7 +54,7 @@ jobs:
       - uses: julia-actions/julia-buildpkg@v1
       - uses: julia-actions/julia-runtest@v1
         env:
-          JULIA_DISTRIBUTED_TESTING_STANDALONE: 1
+          JULIA_NUM_THREADS: 4
       - uses: julia-actions/julia-processcoverage@v1
       - uses: codecov/codecov-action@v4
         with:
2 changes: 2 additions & 0 deletions .gitignore
@@ -1 +1,3 @@
+docs/src/changelog.md
 Manifest.toml
+*.swp
2 changes: 2 additions & 0 deletions docs/src/_changelog.md
@@ -12,6 +12,8 @@ This documents notable changes in DistributedNext.jl. The format is based on
 ### Fixed
 - Fixed behaviour of `isempty(::RemoteChannel)`, which previously had the
   side-effect of taking an element from the channel ([#3]).
+- Improved thread-safety, such that it should be safe to start workers with
+  multiple threads and send messages between them ([#4]).

 ### Changed
 - Added a `project` argument to [`addprocs(::AbstractVector)`](@ref) to specify
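In practice, the new changelog entry describes scenarios like the following minimal sketch. It is hypothetical (the worker and thread counts are arbitrary) and assumes DistributedNext keeps the stdlib Distributed API for `addprocs`, `@everywhere`, `RemoteChannel`, and `remotecall_fetch`:

```julia
# Hypothetical usage sketch for the scenario PR #4 is meant to make safe:
# multi-threaded workers exchanging messages.
using DistributedNext

# Start two workers, each with 4 threads (exeflags assumed to work as in Distributed).
addprocs(2; exeflags="--threads=4")
@everywhere using DistributedNext

w1, w2 = workers()

# A channel hosted on w1; w2 takes the message that was put into it on w1.
ch = RemoteChannel(() -> Channel{Int}(1), w1)
remotecall_fetch(c -> put!(c, 42), w1, ch)
@assert remotecall_fetch(take!, w2, ch) == 42
```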
117 changes: 74 additions & 43 deletions src/cluster.jl
@@ -99,10 +99,10 @@ mutable struct Worker
     del_msgs::Array{Any,1} # XXX: Could del_msgs and add_msgs be Channels?
     add_msgs::Array{Any,1}
     @atomic gcflag::Bool
-    state::WorkerState
-    c_state::Condition         # wait for state changes
-    ct_time::Float64           # creation time
-    conn_func::Any             # used to setup connections lazily
+    @atomic state::WorkerState
+    c_state::Threads.Condition # wait for state changes, lock for state
+    ct_time::Float64           # creation time
+    conn_func::Any             # used to setup connections lazily

     r_stream::IO
     w_stream::IO
@@ -134,7 +134,7 @@ mutable struct Worker
         if haskey(map_pid_wrkr, id)
             return map_pid_wrkr[id]
         end
-        w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Condition(), time(), conn_func)
+        w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
         w.initialized = Event()
         register_worker(w)
         w
@@ -144,12 +144,14 @@
 end

 function set_worker_state(w, state)
-    w.state = state
-    notify(w.c_state; all=true)
+    lock(w.c_state) do
+        @atomic w.state = state
+        notify(w.c_state; all=true)
+    end
 end

 function check_worker_state(w::Worker)
-    if w.state === W_CREATED
+    if (@atomic w.state) === W_CREATED
         if !isclusterlazy()
             if PGRP.topology === :all_to_all
                 # Since higher pids connect with lower pids, the remote worker
@@ -170,6 +172,7 @@ function check_worker_state(w::Worker)
             wait_for_conn(w)
         end
     end
+    return nothing
 end

 exec_conn_func(id::Int) = exec_conn_func(worker_from_id(id)::Worker)
@@ -187,13 +190,21 @@ function exec_conn_func(w::Worker)
 end

 function wait_for_conn(w)
-    if w.state === W_CREATED
+    if (@atomic w.state) === W_CREATED
         timeout = worker_timeout() - (time() - w.ct_time)
         timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")

-        @async (sleep(timeout); notify(w.c_state; all=true))
-        wait(w.c_state)
-        w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
+        T = Threads.@spawn begin
+            sleep($timeout)
+            lock(w.c_state) do
+                notify(w.c_state; all=true)
+            end
+        end
+        errormonitor(T)
+        lock(w.c_state) do
+            wait(w.c_state)
+            (@atomic w.state) === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
+        end
     end
     nothing
 end
@@ -491,7 +502,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
     while true
         if isempty(launched)
             istaskdone(t_launch) && break
-            @async (sleep(1); notify(launch_ntfy))
+            @async begin
+                sleep(1)
+                notify(launch_ntfy)
+            end
             wait(launch_ntfy)
         end

@@ -645,7 +659,12 @@ function create_worker(manager, wconfig)
     # require the value of config.connect_at which is set only upon connection completion
     for jw in PGRP.workers
         if (jw.id != 1) && (jw.id < w.id)
-            (jw.state === W_CREATED) && wait(jw.c_state)
+            lock(jw.c_state) do
+                # wait for wl to join
+                if (@atomic jw.state) === W_CREATED
+                    wait(jw.c_state)
+                end
+            end
             push!(join_list, jw)
         end
     end
@@ -668,7 +687,12 @@
         end

         for wl in wlist
-            (wl.state === W_CREATED) && wait(wl.c_state)
+            lock(wl.c_state) do
+                if (@atomic wl.state) === W_CREATED
+                    # wait for wl to join
+                    wait(wl.c_state)
+                end
+            end
             push!(join_list, wl)
         end
     end
@@ -682,10 +706,16 @@
     join_message = JoinPGRPMsg(w.id, all_locs, PGRP.topology, enable_threaded_blas, isclusterlazy())
     send_msg_now(w, MsgHeader(RRID(0,0), ntfy_oid), join_message)

-    @async manage(w.manager, w.id, w.config, :register)
+    errormonitor(@async manage(w.manager, w.id, w.config, :register))
     # wait for rr_ntfy_join with timeout
     timedout = false
-    @async (sleep($timeout); timedout = true; put!(rr_ntfy_join, 1))
+    errormonitor(
+        @async begin
+            sleep($timeout)
+            timedout = true
+            put!(rr_ntfy_join, 1)
+        end
+    )
     wait(rr_ntfy_join)
     if timedout
         error("worker did not connect within $timeout seconds")
@@ -735,17 +765,20 @@ function check_master_connect()
     if ccall(:jl_running_on_valgrind,Cint,()) != 0
         return
     end
-    @async begin
-        start = time_ns()
-        while !haskey(map_pid_wrkr, 1) && (time_ns() - start) < timeout
-            sleep(1.0)
-        end
-
-        if !haskey(map_pid_wrkr, 1)
-            print(stderr, "Master process (id 1) could not connect within $(timeout/1e9) seconds.\nexiting.\n")
-            exit(1)
-        end
-    end
+
+    errormonitor(
+        @async begin
+            start = time_ns()
+            while !haskey(map_pid_wrkr, 1) && (time_ns() - start) < timeout
+                sleep(1.0)
+            end
+
+            if !haskey(map_pid_wrkr, 1)
+                print(stderr, "Master process (id 1) could not connect within $(timeout/1e9) seconds.\nexiting.\n")
+                exit(1)
+            end
+        end
+    )
 end


@@ -870,7 +903,7 @@ function nprocs()
     n = length(PGRP.workers)
     # filter out workers in the process of being setup/shutdown.
     for jw in PGRP.workers
-        if !isa(jw, LocalProcess) && (jw.state !== W_CONNECTED)
+        if !isa(jw, LocalProcess) && ((@atomic jw.state) !== W_CONNECTED)
             n = n - 1
         end
     end
@@ -921,7 +954,7 @@ julia> procs()
 function procs()
     if myid() == 1 || (PGRP.topology === :all_to_all && !isclusterlazy())
         # filter out workers in the process of being setup/shutdown.
-        return Int[x.id for x in PGRP.workers if isa(x, LocalProcess) || (x.state === W_CONNECTED)]
+        return Int[x.id for x in PGRP.workers if isa(x, LocalProcess) || ((@atomic x.state) === W_CONNECTED)]
     else
         return Int[x.id for x in PGRP.workers]
     end
@@ -930,7 +963,7 @@ end
 function id_in_procs(id)  # faster version of `id in procs()`
     if myid() == 1 || (PGRP.topology === :all_to_all && !isclusterlazy())
         for x in PGRP.workers
-            if (x.id::Int) == id && (isa(x, LocalProcess) || (x::Worker).state === W_CONNECTED)
+            if (x.id::Int) == id && (isa(x, LocalProcess) || (@atomic (x::Worker).state) === W_CONNECTED)
                 return true
             end
         end
@@ -952,7 +985,7 @@ Specifically all workers bound to the same ip-address as `pid` are returned.
"""
function procs(pid::Integer)
if myid() == 1
all_workers = [x for x in PGRP.workers if isa(x, LocalProcess) || (x.state === W_CONNECTED)]
all_workers = [x for x in PGRP.workers if isa(x, LocalProcess) || ((@atomic x.state) === W_CONNECTED)]
if (pid == 1) || (isa(map_pid_wrkr[pid].manager, LocalManager))
Int[x.id for x in filter(w -> (w.id==1) || (isa(w.manager, LocalManager)), all_workers)]
else
@@ -1059,11 +1092,11 @@ function _rmprocs(pids, waitfor)

     start = time_ns()
     while (time_ns() - start) < waitfor*1e9
-        all(w -> w.state === W_TERMINATED, rmprocset) && break
+        all(w -> (@atomic w.state) === W_TERMINATED, rmprocset) && break
         sleep(min(0.1, waitfor - (time_ns() - start)/1e9))
     end

-    unremoved = [wrkr.id for wrkr in filter(w -> w.state !== W_TERMINATED, rmprocset)]
+    unremoved = [wrkr.id for wrkr in filter(w -> (@atomic w.state) !== W_TERMINATED, rmprocset)]
     if length(unremoved) > 0
         estr = string("rmprocs: pids ", unremoved, " not terminated after ", waitfor, " seconds.")
         throw(ErrorException(estr))
@@ -1290,18 +1323,16 @@ end

 using Random: randstring

-let inited = false
-    # do initialization that's only needed when there is more than 1 processor
-    global function init_multi()
-        if !inited
-            inited = true
-            push!(Base.package_callbacks, _require_callback)
-            atexit(terminate_all_workers)
-            init_bind_addr()
-            cluster_cookie(randstring(HDR_COOKIE_LEN))
-        end
-        return nothing
-    end
-end
+# do initialization that's only needed when there is more than 1 processor
+const inited = Threads.Atomic{Bool}(false)
+function init_multi()
+    if !Threads.atomic_cas!(inited, false, true)
+        push!(Base.package_callbacks, _require_callback)
+        atexit(terminate_all_workers)
+        init_bind_addr()
+        cluster_cookie(randstring(HDR_COOKIE_LEN))
+    end
+    return nothing
+end

 function init_parallel()
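Two patterns recur throughout the cluster.jl changes above. First, `Worker.state` becomes an `@atomic` field so it can be read without taking a lock, while every write and every `wait` goes through the lock of a `Threads.Condition`; a waiter that checks the state under the lock can therefore never miss a notification. A minimal standalone sketch of that pattern (the `Box` type and its symbols are illustrative, not DistributedNext API; assumes Julia 1.7+ for `@atomic` fields):

```julia
mutable struct Box
    @atomic state::Symbol
    c_state::Threads.Condition   # guards notifications about `state`
    Box() = new(:created, Threads.Condition())
end

function set_state!(b::Box, s::Symbol)
    lock(b.c_state) do
        @atomic b.state = s          # publish the new state under the lock
        notify(b.c_state; all=true)  # wake every waiter
    end
end

function wait_connected(b::Box)
    lock(b.c_state) do
        while (@atomic b.state) === :created
            wait(b.c_state)          # atomically releases the lock while waiting
        end
    end
end

b = Box()
t = Threads.@spawn (sleep(0.1); set_state!(b, :connected))
wait_connected(b)
wait(t)
```

Second, the `init_multi` rewrite replaces a closure-captured `Bool` with a `Threads.Atomic{Bool}` flipped by compare-and-swap, so concurrent first calls race benignly and exactly one of them runs the setup. A sketch of the same idea (the `do_init` callback is a hypothetical stand-in for the real setup calls):

```julia
const initialized = Threads.Atomic{Bool}(false)

function init_once(do_init)
    # atomic_cas! returns the *old* value: `false` only for the first caller.
    if !Threads.atomic_cas!(initialized, false, true)
        do_init()
    end
    return nothing
end

init_once(() -> println("runs exactly once"))
init_once(() -> println("never printed"))
```

Relatedly, the diff wraps its fire-and-forget `@async`/`Threads.@spawn` tasks in `Base.errormonitor`, which logs a failed task's exception instead of letting it die silently.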
2 changes: 1 addition & 1 deletion src/managers.jl
@@ -183,7 +183,7 @@ function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy:
     # Wait for all launches to complete.
     @sync for (i, (machine, cnt)) in enumerate(manager.machines)
         let machine=machine, cnt=cnt
-            @async try
+            @async try
                 launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy)
             catch e
                 print(stderr, "exception launching on machine $(machine) : $(e)\n")
2 changes: 1 addition & 1 deletion src/messages.jl
@@ -194,7 +194,7 @@ end
 function flush_gc_msgs()
     try
         for w in (PGRP::ProcessGroup).workers
-            if isa(w,Worker) && (w.state == W_CONNECTED) && w.gcflag
+            if isa(w,Worker) && ((@atomic w.state) == W_CONNECTED) && w.gcflag
                 flush_gc_msgs(w)
             end
         end
2 changes: 1 addition & 1 deletion src/process_messages.jl
@@ -222,7 +222,7 @@ function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool)
             println(stderr, "Process($(myid())) - Unknown remote, closing connection.")
         elseif !(wpid in map_del_wrkr)
             werr = worker_from_id(wpid)
-            oldstate = werr.state
+            oldstate = @atomic werr.state
             set_worker_state(werr, W_TERMINATED)

             # If unhandleable error occurred talking to pid 1, exit
(The remaining 7 of the 14 changed files are not shown.)
