Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the problem of Workflow terminates after parallel tasks execution, merge node not triggered #12498

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 77 additions & 66 deletions api/core/workflow/graph_engine/graph_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,83 +439,94 @@ def _run_parallel_branches(
parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent | str, None, None]:
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id:
node_id = edge_mappings[0].target_node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:

target_nodes: dict[str, list[GraphEdge]] = {}
for edge in edge_mappings:
if edge.target_node_id not in target_nodes:
target_nodes[edge.target_node_id] = []
target_nodes[edge.target_node_id].append(edge)
lazyFrogLOL marked this conversation as resolved.
Show resolved Hide resolved

for target_node_id, edges in target_nodes.items():
parallel_id = self.graph.node_parallel_mapping.get(target_node_id)
if not parallel_id:
node_config = self.graph.node_id_config_mapping.get(target_node_id)
if not node_config:
raise GraphRunFailedError(
f"Node {target_node_id} related parallel not found"
f"or incorrectly connected to multiple parallel branches."
)

node_title = node_config.get("data", {}).get("title")
raise GraphRunFailedError(
f"Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches."
f"Node {node_title} related parallel not found"
f"or incorrectly connected to multiple parallel branches."
)

node_title = node_config.get("data", {}).get("title")
raise GraphRunFailedError(
f"Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches."
)
parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
raise GraphRunFailedError(f"Parallel {parallel_id} not found.")

parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
raise GraphRunFailedError(f"Parallel {parallel_id} not found.")
q: queue.Queue = queue.Queue()
futures = []

# run parallel nodes, run in new thread and use queue to get results
q: queue.Queue = queue.Queue()
for edge in edges:
if (
edge.target_node_id not in self.graph.node_parallel_mapping
or self.graph.node_parallel_mapping.get(edge.target_node_id, "") != parallel_id
):
continue

future = self.thread_pool.submit(
self._run_parallel_node,
**{
"flask_app": current_app._get_current_object(), # type: ignore[attr-defined]
"q": q,
"parallel_id": parallel_id,
"parallel_start_node_id": edge.target_node_id,
"parent_parallel_id": in_parallel_id,
"parent_parallel_start_node_id": parallel_start_node_id,
"handle_exceptions": handle_exceptions,
},
)

# Create a list to store the threads
futures = []
future.add_done_callback(self.thread_pool.task_done_callback)
futures.append(future)

# new thread
for edge in edge_mappings:
if (
edge.target_node_id not in self.graph.node_parallel_mapping
or self.graph.node_parallel_mapping.get(edge.target_node_id, "") != parallel_id
):
continue

future = self.thread_pool.submit(
self._run_parallel_node,
**{
"flask_app": current_app._get_current_object(), # type: ignore[attr-defined]
"q": q,
"parallel_id": parallel_id,
"parallel_start_node_id": edge.target_node_id,
"parent_parallel_id": in_parallel_id,
"parent_parallel_start_node_id": parallel_start_node_id,
"handle_exceptions": handle_exceptions,
},
)
succeeded_count = 0
branch_results = []

while True:
try:
event = q.get(timeout=1)
if event is None:
break

future.add_done_callback(self.thread_pool.task_done_callback)
yield event
if event.parallel_id == parallel_id:
if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1
branch_results.append(event)

if succeeded_count == len(futures):
q.put(None)

if len(branch_results) > 1:
for _ in range(len(branch_results) - 1):
final_node_id = parallel.end_to_node_id
if final_node_id:
yield final_node_id

futures.append(future)
continue
elif isinstance(event, ParallelBranchRunFailedEvent):
raise GraphRunFailedError(event.error)
except queue.Empty:
continue

succeeded_count = 0
while True:
try:
event = q.get(timeout=1)
if event is None:
break
wait(futures)

yield event
if event.parallel_id == parallel_id:
if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1
if succeeded_count == len(futures):
q.put(None)

continue
elif isinstance(event, ParallelBranchRunFailedEvent):
raise GraphRunFailedError(event.error)
except queue.Empty:
continue

# wait all threads
wait(futures)

# get final node id
final_node_id = parallel.end_to_node_id
if final_node_id:
yield final_node_id
final_node_id = parallel.end_to_node_id
if final_node_id:
yield final_node_id

def _run_parallel_node(
self,
Expand Down