Support triton kernel replay in PARAM
Summary: This diff imports the captured Triton kernels into et_replay: it loads each kernel file, compiles it into a CUDA binary, and replays it.
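At a high level, the replay path added here works as follows. This is a minimal sketch, not the code from the diff itself; the file name, arguments, grid, and stream values are illustrative, and the AsyncCompile/TritonFuture API is the one imported below from torch._inductor.codecache.

import torch
from torch._inductor.codecache import AsyncCompile

async_compile = AsyncCompile()

# 1. Load a captured kernel source from the trace's _resources directory
#    (hypothetical file name).
with open("rank0_resources/triton_poi_fused_add_0.py") as f:
    code = f.read()

# 2. Start compilation; with compile workers enabled this returns a
#    TritonFuture placeholder immediately.
future = async_compile.triton("triton_poi_fused_add_0", code, device_str="cuda")

# 3. Block until every pending kernel has finished compiling.
async_compile.wait(globals())

# 4. Replay: resolve the future and launch with the recorded grid and stream.
kernel = future.result()
buf = torch.empty(1024, device="cuda")          # illustrative argument
kernel.run(buf, grid=(1024, 1, 1), stream=0)    # illustrative launch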

Reviewed By: briancoutinho

Differential Revision: D56320143
shengfukevin authored and facebook-github-bot committed Apr 30, 2024
1 parent c42dd9e commit 7c377a0
Showing 3 changed files with 108 additions and 11 deletions.
54 changes: 50 additions & 4 deletions train/compute/python/tools/et_replay.py
@@ -3,6 +3,7 @@
import json

import logging
import os
import sys
import time
from collections import defaultdict
@@ -18,6 +19,7 @@
from param_bench.train.compute.python.tools.et_replay_utils import (
build_fbgemm_func,
build_torchscript_func,
build_triton_func,
fbgemm_input_args_indices,
generate_fbgemm_tensors,
generate_prefix,
@@ -45,6 +47,10 @@

from param_bench.train.compute.python.tools.utility import trace_handler
from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
from torch._inductor.codecache import AsyncCompile, TritonFuture

# grid and split_scan_grid are referenced dynamically by the exec'ed kernel launches
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
from torch.profiler import ExecutionTraceObserver


@@ -227,8 +233,6 @@ def initBench(self):
self.dump_path += "benchmark.py"
# Multiple traces.
else:
import os

print(f"{os.getpid()} is rank{self.comms_env_params['global_rank']}")
self.cuda_id = self.comms_env_params["local_rank"]
self.cuda = f"cuda:{self.comms_env_params['local_rank']}"
@@ -253,6 +257,14 @@

self.dump_path += f"benchmark_{self.comms_env_params['global_rank']}.py"

# base_path is used to locate the generated kernel files, which live in a _resources directory next to the trace file.
base_path, file_name = os.path.split(self.trace_file)
self.resource_dir = os.path.join(
base_path, os.path.splitext(file_name)[-2] + "_resources"
)
self.kernel_map = {}
self.async_compile = AsyncCompile()

if self.cpu:
self.device = torch.device("cpu")
else:
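For instance (illustrative paths), a trace at /tmp/rank0.json resolves to the kernel directory /tmp/rank0_resources:

import os

trace_file = "/tmp/rank0.json"                    # hypothetical trace path
base_path, file_name = os.path.split(trace_file)  # "/tmp", "rank0.json"
resource_dir = os.path.join(
    base_path, os.path.splitext(file_name)[-2] + "_resources"
)
print(resource_dir)                               # /tmp/rank0_resources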
@@ -364,6 +376,12 @@ def dfs_traverse(root):
print("#Operators to execute: ", len(self.sorted_nodes))
for node in self.sorted_nodes:
anlayze_node(node)

# Triton kernels are compiled in parallel; wait until all of them
# have finished compiling.
self.async_compile.wait(globals())
del self.async_compile

self.select_parallel_nodes()

def select_parallel_nodes(self):
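A rough model of the AsyncCompile contract assumed by the wait above: triton() returns placeholders while compilation proceeds on worker processes, wait() blocks until everything pending has finished (resolving any futures stored in the passed scope), and futures held elsewhere, such as in kernel_map here, are resolved later with .result():

from torch._inductor.codecache import AsyncCompile

async_compile = AsyncCompile()
kernel_map = {}  # populate with async_compile.triton(name, source) futures
# ... submit kernels here ...
async_compile.wait(globals())  # blocks until all pending compiles complete
# Any future kept in kernel_map is resolved on first use via .result().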
@@ -618,7 +636,22 @@ def build_func(self, node):
assert self.fbgemm_backward_ops
backward_op, forward_id = self.fbgemm_backward_ops.pop(-1)
return backward_op, len(node.output_types)
func, output_count = build_torchscript_func(node)

if node.kernel_backend == "triton":
if node.kernel_file in self.kernel_map:
func = self.kernel_map[node.kernel_file]
# For a Triton kernel, the caller is responsible for allocating the
# output tensors and passing them in as input arguments, so the
# output count is always 0.
output_count = 0
else:
func, output_count = build_triton_func(
node, self.resource_dir, self.async_compile, self.device
)
self.kernel_map[node.kernel_file] = func
else:
func, output_count = build_torchscript_func(node)

if not func:
self.actual_skip_nodes.append(node.name)
self.actual_skip_nodes_cnt += 1
@@ -890,6 +923,9 @@ def _generate_run_ops_str(override):
func, output_count = self.funcs[node.id]
if not func:
continue
if isinstance(func, TritonFuture):
func = func.result()

func_str = f"funcs[{node.id}]"
inputs_str = _generate_inputs_str(node)
outputs_str = _generate_outputs_str(node, override=override)
@@ -1135,7 +1171,17 @@ def run_op(self, node, iter):
try:
outputs = []
if output_count == 0:
func(*inputs)
if node.kernel_backend == "triton":
# strip the trailing comma from the captured grid string
grid_info = inputs[-2]
index = grid_info.rfind(",")
if index >= 0:
grid_info = grid_info[:index] + grid_info[index + 1 :]
exec(
f"func.run(*inputs[:-2], grid={grid_info}, stream={inputs[-1]})"
)
else:
func(*inputs)
else:
if output_count == 1:
tmp = (func(*inputs),)
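To make the grid handling in run_op concrete: the grid argument arrives as a captured string with a trailing comma. The exact format below is an assumption, chosen to be consistent with the module-level grid import above.

# Worked example of the grid-string cleanup (captured value is illustrative).
grid_info = "grid(1024),"                 # assumed captured trace input
index = grid_info.rfind(",")
if index >= 0:
    grid_info = grid_info[:index] + grid_info[index + 1:]
print(grid_info)                          # grid(1024)
# The exec'ed launch then reads:
#   func.run(*inputs[:-2], grid=grid(1024), stream=inputs[-1])
# which is why grid and split_scan_grid must be importable at module scope.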
10 changes: 9 additions & 1 deletion train/compute/python/tools/et_replay_utils.py
@@ -1,3 +1,4 @@
import os
import re

import torch
@@ -139,7 +140,6 @@ def skip_op(op):
and "fbgemm::split_embedding_codegen_lookup_" not in op.name
)
)
or ("fused" in op.name)
or (
op.name
in [
@@ -453,6 +453,14 @@ def build_torchscript_func(n):
return func, output_count


def build_triton_func(n, resources_dir, async_compile, device):
with open(os.path.join(resources_dir, n.kernel_file), "r") as f:
code = f.read()

func = async_compile.triton(n.name, code, device_str=device)
return func, 0


def generate_prefix(label, skip_nodes, et_input, cuda, compute_only, tf32, rows):
template_prefix = """import gc
import argparse
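A hedged usage sketch of the helper above (node, resource directory, and device are placeholders): given a node whose kernel_file names a captured source under the trace's _resources directory, build_triton_func hands the source to AsyncCompile and always reports zero outputs, since the caller allocates the output tensors:

# Illustrative call site; assumes a node with kernel_file set, as in build_func.
func, output_count = build_triton_func(
    node, "/tmp/rank0_resources", async_compile, "cuda"
)
assert output_count == 0
# func may still be a TritonFuture here; it is resolved later via .result().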
55 changes: 49 additions & 6 deletions train/compute/python/tools/execution_trace.py
@@ -59,6 +59,9 @@ class NodeType(Enum):
"All2All_Pooled_Req",
"All2All_Pooled_Wait",
"c10d::",
"TorchDynamo Cache Lookup",
"CompiledFunction",
"Torch-Compiled Region",
]


@@ -136,13 +139,17 @@ def __init__(
output_types: List[str],
output_shapes: List[Any],
rf_id: Optional[int] = None,
kernel_backend: Optional[str] = None,
kernel_file: Optional[str] = None,
):
self.name: str = name
self.parent_id: int = parent_id
self.parent: Optional[Node] = None
self.children: List[Node] = []
self.id: int = id
self.rf_id: Optional[int] = rf_id
self.kernel_backend: Optional[str] = kernel_backend
self.kernel_file: Optional[str] = kernel_file
self.pid: int = pid
self.tid: int = tid
self.fw_tid: int = fw_tid
@@ -298,6 +305,7 @@ def __init__(self, json):
"1.0.2-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
# 1.0.3 expands pg name to <pg_name, pg_desc> so it uses the same parser as 1.0.2
"1.0.3-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
"1.0.4-chakra.0.0.4": ExecutionTrace._create_node_v1_0_4_chakra_0_0_4,
# Add future versions here
}
create_node = node_creation_func.get(self.schema, None)
@@ -371,19 +379,18 @@ def _read_attrs(node: Dict[str, Any]) -> Tuple:
"rf_id": int,
"scope": int,
"tid": int,
"kernel_backend": str,
"kernel_file": str,
}
attr_dict = {
attr["name"]: attr_types[attr["name"]](attr["value"])
for attr in node["attrs"]
if attr["name"] in attr_types.keys()
}

# Ensure all keys have values
if attr_dict.keys() != attr_types.keys():
raise ValueError(
"Not all keys in attr_dict have updated values. Node:" + str(node)
)
return tuple(attr_dict[key] for key in attr_types.keys())
return tuple(
attr_dict[key] for key in attr_types.keys() if key in attr_dict.keys()
)

@staticmethod
def _create_node_v1_0_1(pid, x: Dict[str, Any]) -> Node:
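The change to _read_attrs relaxes the old all-keys-present check so that nodes lacking the new optional attributes still parse; a self-contained illustration (attribute values made up):

# Old behavior: raise ValueError because "kernel_backend" has no value.
# New behavior: return a tuple of only the attributes that are present.
attrs = [{"name": "rf_id", "value": 7}, {"name": "tid", "value": 1}]
attr_types = {"rf_id": int, "tid": int, "kernel_backend": str}
attr_dict = {
    a["name"]: attr_types[a["name"]](a["value"])
    for a in attrs
    if a["name"] in attr_types
}
result = tuple(attr_dict[k] for k in attr_types if k in attr_dict)  # (7, 1)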
@@ -439,6 +446,42 @@ def _create_node_v1_0_2_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node:
rf_id,
)

@staticmethod
def _create_node_v1_0_4_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node:
(
fw_parent,
seq_id,
fw_tid,
op_schema,
rf_id,
scope,
tid,
kernel_backend,
kernel_file,
) = ExecutionTrace._read_attrs(x)

return Node(
x["name"],
x["id"],
x["ctrl_deps"],
fw_parent,
seq_id,
pid,
tid,
fw_tid,
op_schema,
scope,
x["inputs"]["values"],
x["inputs"]["types"],
x["inputs"]["shapes"],
x["outputs"]["values"],
x["outputs"]["types"],
x["outputs"]["shapes"],
rf_id,
kernel_backend,
kernel_file,
)

def get_nodes(self, clean: bool = False):
if clean:
return self.clean_nodes
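For reference, a sketch of how the two new attributes might appear on a 1.0.4-chakra.0.0.4 node; the values are made up, and only kernel_backend and kernel_file are new in this schema:

node_attrs = [
    {"name": "rf_id", "value": 42},
    {"name": "fw_parent", "value": 0},
    {"name": "kernel_backend", "value": "triton"},
    {"name": "kernel_file", "value": "triton_poi_fused_add_0.py"},
]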
