Binding for prctl(PR_SET_PDEATHSIG) (pytorch#14491)
Summary:
If torch.multiprocessing.spawn is used to launch non-daemonic
processes (the default since pytorch#14391), the spawned children won't be
automatically terminated when the parent terminates.

On Linux, we can address this by setting PR_SET_PDEATHSIG, which
delivers a configurable signal to child processes when their parent
terminates.
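
As an illustration of the mechanism only (not the binding added in this commit), a minimal standalone sketch: on Linux, a child process can register a "parent death signal" for itself by calling prctl(2) with PR_SET_PDEATHSIG. The ctypes call below assumes glibc ("libc.so.6") and that PR_SET_PDEATHSIG has the value 1 from <sys/prctl.h>.

import ctypes
import signal

PR_SET_PDEATHSIG = 1  # assumed value from <sys/prctl.h>

def set_pdeathsig(sig=signal.SIGTERM):
    # Ask the kernel to deliver `sig` to this process when its parent exits.
    libc = ctypes.CDLL("libc.so.6", use_errno=True)
    if libc.prctl(PR_SET_PDEATHSIG, sig) != 0:
        raise OSError(ctypes.get_errno(), "prctl(PR_SET_PDEATHSIG) failed")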

Fixes pytorch#14394.
Pull Request resolved: pytorch#14491

Differential Revision: D13270374

Pulled By: pietern

fbshipit-source-id: 092c9d3c3cea2622c3766b467957bc27a1bd500c
pietern authored and facebook-github-bot committed Nov 30, 2018
1 parent 9127ab3 commit 220ce80
Showing 8 changed files with 152 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .clang-tidy
@@ -10,7 +10,9 @@ Checks: '
,-cppcoreguidelines-pro-type-static-cast-downcast
,-cppcoreguidelines-pro-bounds-pointer-arithmetic
,-cppcoreguidelines-pro-bounds-constant-array-index
,-cppcoreguidelines-pro-type-cstyle-cast
,-cppcoreguidelines-pro-type-reinterpret-cast
,-cppcoreguidelines-pro-type-vararg
,-cppcoreguidelines-special-member-functions
,-cppcoreguidelines-interfaces-global-init
,-cppcoreguidelines-owning-memory
61 changes: 61 additions & 0 deletions test/test_multiprocessing_spawn.py
@@ -50,6 +50,33 @@ def test_success_first_then_exception_func(i, arg):
    raise ValueError("legitimate exception")


def test_nested_child_body(i, ready_queue, nested_child_sleep):
    ready_queue.put(None)
    time.sleep(nested_child_sleep)


def test_nested_spawn(i, pids_queue, nested_child_sleep):
    context = mp.get_context("spawn")
    nested_child_ready_queue = context.Queue()
    nprocs = 2
    spawn_context = mp.spawn(
        fn=test_nested_child_body,
        args=(nested_child_ready_queue, nested_child_sleep),
        nprocs=nprocs,
        join=False,
        daemon=False,
    )
    pids_queue.put(spawn_context.pids())

    # Wait for both children to have spawned, to ensure that they
    # have called prctl(2) to register a parent death signal.
    for _ in range(nprocs):
        nested_child_ready_queue.get()

    # Kill self. This should take down the child processes as well.
    os.kill(os.getpid(), signal.SIGTERM)


@unittest.skipIf(
    NO_MULTIPROCESSING_SPAWN,
    "Disabled for environments that don't support the spawn start method")
@@ -118,6 +145,40 @@ def test_success_first_then_exception(self):
        ):
            mp.spawn(test_success_first_then_exception_func, args=(exitcode,), nprocs=2)

    @unittest.skipIf(
        sys.platform != "linux",
        "Only runs on Linux; requires prctl(2)",
    )
    def test_nested_spawn(self):
        context = mp.get_context("spawn")
        pids_queue = context.Queue()
        nested_child_sleep = 20.0
        spawn_context = mp.spawn(
            fn=test_nested_spawn,
            args=(pids_queue, nested_child_sleep),
            nprocs=1,
            join=False,
            daemon=False,
        )

        # Wait for nested children to terminate in time
        pids = pids_queue.get()
        start = time.time()
        while len(pids) > 0:
            for pid in pids:
                try:
                    os.kill(pid, 0)
                except ProcessLookupError:
                    pids.remove(pid)
                    break

            # This assert fails if any nested child process is still
            # alive after (nested_child_sleep / 2) seconds. By
            # extension, this test times out with an assertion error
            # after (nested_child_sleep / 2) seconds.
            self.assertLess(time.time() - start, nested_child_sleep / 2)
            time.sleep(0.1)


if __name__ == '__main__':
run_tests()
1 change: 1 addition & 0 deletions torch/CMakeLists.txt
@@ -539,6 +539,7 @@ if (BUILD_PYTHON)
${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
${TORCH_SRC_DIR}/csrc/jit/script/python_tree_views.cpp
${TORCH_SRC_DIR}/csrc/multiprocessing/init.cpp
${TORCH_SRC_DIR}/csrc/nn/THNN.cpp
${TORCH_SRC_DIR}/csrc/onnx/init.cpp
${TORCH_SRC_DIR}/csrc/serialization.cpp
2 changes: 2 additions & 0 deletions torch/csrc/Module.cpp
@@ -29,6 +29,7 @@
#include "torch/csrc/autograd/generated/python_nn_functions.h"
#include "torch/csrc/autograd/python_legacy_variable.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/multiprocessing/init.h"
#include "torch/csrc/tensor/python_tensor.h"
#include "torch/csrc/utils/tensor_dtypes.h"
#include "torch/csrc/utils/python_strings.h"
@@ -534,6 +535,7 @@ PyObject* initModule() {
  THPUtils_addPyMethodDefs(methods, TorchMethods);
  THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
  THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
  THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());
#ifdef USE_CUDA
  THPUtils_addPyMethodDefs(methods, THCPModule_methods());
#endif
58 changes: 58 additions & 0 deletions torch/csrc/multiprocessing/init.cpp
@@ -0,0 +1,58 @@
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>

#include <stdexcept>

#if defined(__linux__)
#include <sys/prctl.h>
#endif

#define SYSASSERT(rv, ...)                                                 \
  if ((rv) < 0) {                                                          \
    throw std::system_error(errno, std::system_category(), ##__VA_ARGS__); \
  }

namespace torch {
namespace multiprocessing {

namespace {

PyObject* multiprocessing_init(PyObject* _unused) {
  auto multiprocessing_module =
      THPObjectPtr(PyImport_ImportModule("torch.multiprocessing"));
  if (!multiprocessing_module) {
    throw python_error();
  }

  auto module = py::handle(multiprocessing_module).cast<py::module>();

  module.def("_prctl_pr_set_pdeathsig", [](int signal) {
#if defined(__linux__)
    auto rv = prctl(PR_SET_PDEATHSIG, signal);
    SYSASSERT(rv, "prctl");
#endif
  });

  Py_RETURN_TRUE;
}

} // namespace

// multiprocessing methods on torch._C
static PyMethodDef methods[] = {
    {
        "_multiprocessing_init",
        (PyCFunction)multiprocessing_init,
        METH_NOARGS,
        nullptr,
    },
    {nullptr, nullptr, 0, nullptr},
};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace multiprocessing
} // namespace torch
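
Once torch._C._multiprocessing_init() has run (it is invoked from torch/multiprocessing/__init__.py below), the prctl wrapper is available as an attribute of torch.multiprocessing. A hedged usage sketch, mirroring what spawn.py does in this commit:

import signal
import torch.multiprocessing as mp

# Inside a spawned child: ask for SIGINT when the parent dies.
# On non-Linux platforms this call is a no-op, per the
# #if defined(__linux__) guard in init.cpp above.
mp._prctl_pr_set_pdeathsig(signal.SIGINT)
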
11 changes: 11 additions & 0 deletions torch/csrc/multiprocessing/init.h
@@ -0,0 +1,11 @@
#pragma once

#include <torch/csrc/python_headers.h>

namespace torch {
namespace multiprocessing {

PyMethodDef* python_functions();

} // namespace multiprocessing
} // namespace torch
6 changes: 6 additions & 0 deletions torch/multiprocessing/__init__.py
@@ -13,6 +13,7 @@
Because of the similarity of APIs we do not document most of this package
contents, and we recommend referring to very good docs of the original module.
"""
import torch
import sys
from .reductions import init_reductions
import multiprocessing
@@ -27,6 +28,11 @@
__all__ += multiprocessing.__all__


# This call adds a Linux specific prctl(2) wrapper function to this module.
# See https://github.com/pytorch/pytorch/pull/14391 for more information.
torch._C._multiprocessing_init()


if sys.version_info < (3, 3):
    """Override basic classes in Python 2.7 and Python 3.3 to use ForkingPickler
    for serialization. Later versions of Python already use ForkingPickler."""
11 changes: 11 additions & 0 deletions torch/multiprocessing/spawn.py
@@ -5,8 +5,16 @@
import signal
import sys

from . import _prctl_pr_set_pdeathsig


def _wrap(fn, i, args, error_queue):
    # prctl(2) is a Linux specific system call.
    # On other systems the following function call has no effect.
    # This is set to ensure that non-daemonic child processes can
    # terminate if their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(i, *args)
    except KeyboardInterrupt:
@@ -39,6 +47,9 @@ def __init__(self, processes, error_queues):
            for index, process in enumerate(processes)
        }

    def pids(self):
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this spawn context.
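
For context, a short hedged sketch of the non-blocking spawn path this commit extends: launching with join=False returns a spawn context whose new pids() method exposes the child PIDs, and join() can then be polled until all children have exited. The worker function name below is illustrative only.

import torch.multiprocessing as mp


def worker(i):
    print("child", i, "running")


if __name__ == "__main__":
    context = mp.spawn(worker, nprocs=2, join=False)
    print("spawned children:", context.pids())
    # join() returns True only once every child process has exited.
    while not context.join():
        pass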
