Support triton kernel replay in PARAM #106

Closed
54 changes: 50 additions & 4 deletions train/compute/python/tools/et_replay.py
@@ -3,6 +3,7 @@
import json

import logging
import os
import sys
import time
from collections import defaultdict
@@ -18,6 +19,7 @@
from param_bench.train.compute.python.tools.et_replay_utils import (
build_fbgemm_func,
build_torchscript_func,
build_triton_func,
fbgemm_input_args_indices,
generate_fbgemm_tensors,
generate_prefix,
@@ -45,6 +47,10 @@

from param_bench.train.compute.python.tools.utility import trace_handler
from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
from torch._inductor.codecache import AsyncCompile, TritonFuture

# grid and split_scan_grid are referenced dynamically by the exec'ed kernel launches
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
from torch.profiler import ExecutionTraceObserver


@@ -227,8 +233,6 @@ def initBench(self):
self.dump_path += "benchmark.py"
# Multiple traces.
else:
import os

print(f"{os.getpid()} is rank{self.comms_env_params['global_rank']}")
self.cuda_id = self.comms_env_params["local_rank"]
self.cuda = f"cuda:{self.comms_env_params['local_rank']}"
@@ -253,6 +257,14 @@

self.dump_path += f"benchmark_{self.comms_env_params['global_rank']}.py"

# base_path is used to find the generated kernel files, which live in a "<trace name>_resources" directory next to the trace file.
base_path, file_name = os.path.split(self.trace_file)
self.resource_dir = os.path.join(
base_path, os.path.splitext(file_name)[-2] + "_resources"
)
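# e.g., for a trace at /path/to/rank0.json the kernels are looked up
# under /path/to/rank0_resources/ (illustrative path).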
self.kernel_map = {}
self.async_compile = AsyncCompile()

if self.cpu:
self.device = torch.device("cpu")
else:
@@ -364,6 +376,12 @@ def dfs_traverse(root):
print("#Operators to execute: ", len(self.sorted_nodes))
for node in self.sorted_nodes:
analyze_node(node)

# Triton kernels are compiled in parallel; wait here until
# all of them have finished compiling.
self.async_compile.wait(globals())
del self.async_compile

self.select_parallel_nodes()

def select_parallel_nodes(self):
@@ -618,7 +636,22 @@ def build_func(self, node):
assert self.fbgemm_backward_ops
backward_op, forward_id = self.fbgemm_backward_ops.pop(-1)
return backward_op, len(node.output_types)
func, output_count = build_torchscript_func(node)

if node.kernel_backend == "triton":
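# Compiled kernels are cached per kernel file so each file is compiled only once.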
if node.kernel_file in self.kernel_map:
func = self.kernel_map[node.kernel_file]
# For a triton kernel, the caller is responsible for allocating the
# output tensors and passing them in as input arguments, so the
# output-tensor count is always 0.
output_count = 0
else:
func, output_count = build_triton_func(
node, self.resource_dir, self.async_compile, self.device
)
self.kernel_map[node.kernel_file] = func
else:
func, output_count = build_torchscript_func(node)

if not func:
self.actual_skip_nodes.append(node.name)
self.actual_skip_nodes_cnt += 1
@@ -890,6 +923,9 @@ def _generate_run_ops_str(override):
func, output_count = self.funcs[node.id]
if not func:
continue
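# Parallel compilation may have produced a TritonFuture; resolve it
# to the compiled kernel before generating its call.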
if isinstance(func, TritonFuture):
func = func.result()

func_str = f"funcs[{node.id}]"
inputs_str = _generate_inputs_str(node)
outputs_str = _generate_outputs_str(node, override=override)
@@ -1135,7 +1171,17 @@ def run_op(self, node, iter):
try:
outputs = []
if output_count == 0:
func(*inputs)
if node.kernel_backend == "triton":
# Strip the trailing comma from the recorded grid string.
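# e.g., an assumed recorded form "(1024, 1, 1,)" becomes "(1024, 1, 1)"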
grid_info = inputs[-2]
index = grid_info.rfind(",")
if index >= 0:
grid_info = grid_info[:index] + grid_info[index + 1 :]
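# The grid is recorded as a source string that may reference the
# dynamically imported grid/split_scan_grid helpers, so the launch
# is exec'ed rather than called directly.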
exec(
f"func.run(*inputs[:-2], grid={grid_info}, stream={inputs[-1]})"
)
else:
func(*inputs)
else:
if output_count == 1:
tmp = (func(*inputs),)
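Taken together, the replay path for a single triton kernel boils down to roughly the sketch below. It is a simplified, standalone illustration against the APIs this PR imports; the kernel file name, tensor shape, grid, and stream handling are illustrative placeholders rather than values taken from the PR.

import torch
from torch._inductor.codecache import AsyncCompile, TritonFuture
from torch._inductor.runtime.triton_heuristics import grid

async_compile = AsyncCompile()

# Compile a kernel source file dumped next to the trace (illustrative name).
with open("rank0_resources/triton_poi_fused_add_0.py") as f:
    kernel_src = f.read()
func = async_compile.triton("triton_poi_fused_add_0", kernel_src, device_str="cuda")

# Compilation happens in parallel worker processes; block until it finishes.
async_compile.wait(globals())
if isinstance(func, TritonFuture):
    func = func.result()

# The replayer allocates every tensor itself (output_count == 0), and the
# recorded grid string is evaluated at launch time.
buf = torch.empty(1024, device="cuda")  # illustrative kernel argument
stream = torch.cuda.current_stream().cuda_stream
func.run(buf, grid=grid(1024), stream=stream)

In the PR itself the argument list, grid string, and stream all come from the recorded trace node rather than being hard-coded as above.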
10 changes: 9 additions & 1 deletion train/compute/python/tools/et_replay_utils.py
@@ -1,3 +1,4 @@
import os
import re

import torch
@@ -139,7 +140,6 @@ def skip_op(op):
and "fbgemm::split_embedding_codegen_lookup_" not in op.name
)
)
or ("fused" in op.name)
or (
op.name
in [
@@ -453,6 +453,14 @@ def build_torchscript_func(n):
return func, output_count


def build_triton_func(n, resources_dir, async_compile, device):
with open(os.path.join(resources_dir, n.kernel_file), "r") as f:
code = f.read()

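# async_compile.triton compiles in a worker process and may return a
# TritonFuture; the replayer resolves it after async_compile.wait().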
func = async_compile.triton(n.name, code, device_str=device)
return func, 0


def generate_prefix(label, skip_nodes, et_input, cuda, compute_only, tf32, rows):
template_prefix = """import gc
import argparse
55 changes: 49 additions & 6 deletions train/compute/python/tools/execution_trace.py
@@ -59,6 +59,9 @@ class NodeType(Enum):
"All2All_Pooled_Req",
"All2All_Pooled_Wait",
"c10d::",
"TorchDynamo Cache Lookup",
"CompiledFunction",
"Torch-Compiled Region",
]


@@ -136,13 +139,17 @@ def __init__(
output_types: List[str],
output_shapes: List[Any],
rf_id: Optional[int] = None,
kernel_backend: Optional[str] = None,
kernel_file: Optional[str] = None,
):
self.name: str = name
self.parent_id: int = parent_id
self.parent: Optional[Node] = None
self.children: List[Node] = []
self.id: int = id
self.rf_id: Optional[int] = rf_id
self.kernel_backend: Optional[str] = kernel_backend
self.kernel_file: Optional[str] = kernel_file
self.pid: int = pid
self.tid: int = tid
self.fw_tid: int = fw_tid
@@ -298,6 +305,7 @@ def __init__(self, json):
"1.0.2-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
# 1.0.3 expands pg name to <pg_name, pg_desc> so it uses the same parser as 1.0.2
"1.0.3-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
"1.0.4-chakra.0.0.4": ExecutionTrace._create_node_v1_0_4_chakra_0_0_4,
# Add future versions here
}
create_node = node_creation_func.get(self.schema, None)
@@ -371,19 +379,18 @@ def _read_attrs(node: Dict[str, Any]) -> Tuple:
"rf_id": int,
"scope": int,
"tid": int,
"kernel_backend": str,
"kernel_file": str,
}
attr_dict = {
attr["name"]: attr_types[attr["name"]](attr["value"])
for attr in node["attrs"]
if attr["name"] in attr_types.keys()
}

# Ensure all keys have values
if attr_dict.keys() != attr_types.keys():
raise ValueError(
"Not all keys in attr_dict have updated values. Node:" + str(node)
)
return tuple(attr_dict[key] for key in attr_types.keys())
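# Return only the attrs that are actually present; kernel_backend and
# kernel_file are optional (typically absent on nodes that launch no kernel).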
return tuple(
attr_dict[key] for key in attr_types.keys() if key in attr_dict.keys()
)

@staticmethod
def _create_node_v1_0_1(pid, x: Dict[str, Any]) -> Node:
Expand Down Expand Up @@ -439,6 +446,42 @@ def _create_node_v1_0_2_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node:
rf_id,
)

@staticmethod
def _create_node_v1_0_4_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node:
(
fw_parent,
seq_id,
fw_tid,
op_schema,
rf_id,
scope,
tid,
kernel_backend,
kernel_file,
) = ExecutionTrace._read_attrs(x)

return Node(
x["name"],
x["id"],
x["ctrl_deps"],
fw_parent,
seq_id,
pid,
tid,
fw_tid,
op_schema,
scope,
x["inputs"]["values"],
x["inputs"]["types"],
x["inputs"]["shapes"],
x["outputs"]["values"],
x["outputs"]["types"],
x["outputs"]["shapes"],
rf_id,
kernel_backend,
kernel_file,
)

def get_nodes(self, clean: bool = False):
if clean:
return self.clean_nodes
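For reference, a node in a 1.0.4-chakra.0.0.4 trace is expected to carry the two new attributes alongside the existing ones. A minimal sketch of such an attrs list follows (all values are illustrative; _read_attrs only consumes the "name" and "value" fields):

attrs = [
    {"name": "fw_parent", "value": 0},
    {"name": "seq_id", "value": -1},
    {"name": "fw_tid", "value": 0},
    {"name": "op_schema", "value": ""},
    {"name": "rf_id", "value": 211},
    {"name": "scope", "value": 7},
    {"name": "tid", "value": 1},
    {"name": "kernel_backend", "value": "triton"},                  # new in 1.0.4
    {"name": "kernel_file", "value": "triton_poi_fused_add_0.py"},  # new in 1.0.4
]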