diff --git a/python/graphstorm/gconstruct/utils.py b/python/graphstorm/gconstruct/utils.py index d2077f632..9653dcb32 100644 --- a/python/graphstorm/gconstruct/utils.py +++ b/python/graphstorm/gconstruct/utils.py @@ -295,6 +295,8 @@ def multiprocessing_data_read(in_files, num_processes, user_parser, ext_mem_work a dict : key is the file index, the value is processed data. """ if num_processes > 1 and len(in_files) > 1: + if th.cuda.is_available(): + multiprocessing.set_start_method("spawn", force=True) processes = [] manager = multiprocessing.Manager() task_queue = manager.Queue()