diff --git a/train/compute/python/tools/et_replay.py b/train/compute/python/tools/et_replay.py
index 6c4a8e02..48731d00 100644
--- a/train/compute/python/tools/et_replay.py
+++ b/train/compute/python/tools/et_replay.py
@@ -3,6 +3,7 @@
 
 import json
 import logging
+import os
 import sys
 import time
 from collections import defaultdict
@@ -18,6 +19,7 @@
 from param_bench.train.compute.python.tools.et_replay_utils import (
     build_fbgemm_func,
     build_torchscript_func,
+    build_triton_func,
     fbgemm_input_args_indices,
     generate_fbgemm_tensors,
     generate_prefix,
@@ -45,6 +47,10 @@
 
 from param_bench.train.compute.python.tools.utility import trace_handler
 from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
+from torch._inductor.codecache import AsyncCompile, TritonFuture
+
+# grid and split_scan_grid are dynamically loaded
+from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
 from torch.profiler import ExecutionTraceObserver
 
 
@@ -227,8 +233,6 @@ def initBench(self):
             self.dump_path += "benchmark.py"
         # Multiple traces.
         else:
-            import os
-
             print(f"{os.getpid()} is rank{self.comms_env_params['global_rank']}")
             self.cuda_id = self.comms_env_params["local_rank"]
             self.cuda = f"cuda:{self.comms_env_params['local_rank']}"
@@ -253,6 +257,14 @@ def initBench(self):
 
             self.dump_path += f"benchmark_{self.comms_env_params['global_rank']}.py"
 
+        # base_path is used to find the generated kernel files, which live in a <trace name>_resources directory next to the trace file.
+        base_path, file_name = os.path.split(self.trace_file)
+        self.resource_dir = os.path.join(
+            base_path, os.path.splitext(file_name)[-2] + "_resources"
+        )
+        self.kernel_map = {}
+        self.async_compile = AsyncCompile()
+
         if self.cpu:
             self.device = torch.device("cpu")
         else:
@@ -364,6 +376,12 @@ def dfs_traverse(root):
         print("#Operators to execute: ", len(self.sorted_nodes))
         for node in self.sorted_nodes:
             anlayze_node(node)
+
+        # Triton kernels are compiled in parallel; wait here until
+        # all kernels have finished compiling.
+        self.async_compile.wait(globals())
+        del self.async_compile
+
         self.select_parallel_nodes()
 
     def select_parallel_nodes(self):
@@ -618,7 +636,22 @@ def build_func(self, node):
             assert self.fbgemm_backward_ops
             backward_op, forward_id = self.fbgemm_backward_ops.pop(-1)
             return backward_op, len(node.output_types)
-        func, output_count = build_torchscript_func(node)
+
+        if node.kernel_backend == "triton":
+            if node.kernel_file in self.kernel_map:
+                func = self.kernel_map[node.kernel_file]
+                # For a Triton kernel, it is the caller's responsibility to allocate
+                # the output tensors and pass them in as input arguments,
+                # so the number of output tensors is always 0.
+                output_count = 0
+            else:
+                func, output_count = build_triton_func(
+                    node, self.resource_dir, self.async_compile, self.device
+                )
+                self.kernel_map[node.kernel_file] = func
+        else:
+            func, output_count = build_torchscript_func(node)
+
         if not func:
             self.actual_skip_nodes.append(node.name)
             self.actual_skip_nodes_cnt += 1
@@ -890,6 +923,9 @@ def _generate_run_ops_str(override):
             func, output_count = self.funcs[node.id]
             if not func:
                 continue
+            if isinstance(func, TritonFuture):
+                func = func.result()
+
             func_str = f"funcs[{node.id}]"
             inputs_str = _generate_inputs_str(node)
             outputs_str = _generate_outputs_str(node, override=override)
@@ -1135,7 +1171,17 @@ def run_op(self, node, iter):
         try:
             outputs = []
             if output_count == 0:
-                func(*inputs)
+                if node.kernel_backend == "triton":
+                    # remove the trailing comma from the recorded grid string
+                    grid_info = inputs[-2]
+                    index = grid_info.rfind(",")
+                    if index >= 0:
+                        grid_info = grid_info[:index] + grid_info[index + 1 :]
+                    exec(
+                        f"func.run(*inputs[:-2], grid={grid_info}, stream={inputs[-1]})"
+                    )
+                else:
+                    func(*inputs)
             else:
                 if output_count == 1:
                     tmp = (func(*inputs),)
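Note: the grid recorded in the trace for a Triton kernel appears to arrive as a string with a trailing comma, which is why run_op strips the last comma before interpolating it into the func.run(...) call. A minimal, self-contained sketch of that normalization (the helper name and the sample value below are illustrative, not part of the patch):

    def normalize_grid(grid_info: str) -> str:
        # Drop the trailing comma so the string can be interpolated into
        # f"func.run(*inputs[:-2], grid={grid_info}, stream={inputs[-1]})".
        index = grid_info.rfind(",")
        if index >= 0:
            grid_info = grid_info[:index] + grid_info[index + 1 :]
        return grid_info

    # Assumed recorded format: "grid(1024, 1, 1)," -> "grid(1024, 1, 1)"
    print(normalize_grid("grid(1024, 1, 1),"))
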
diff --git a/train/compute/python/tools/et_replay_utils.py b/train/compute/python/tools/et_replay_utils.py
index d9ef8190..d9f8f29c 100644
--- a/train/compute/python/tools/et_replay_utils.py
+++ b/train/compute/python/tools/et_replay_utils.py
@@ -1,3 +1,4 @@
+import os
 import re
 
 import torch
@@ -139,7 +140,6 @@ def skip_op(op):
                 and "fbgemm::split_embedding_codegen_lookup_" not in op.name
             )
         )
-        or ("fused" in op.name)
         or (
             op.name
             in [
@@ -453,6 +453,14 @@ def build_torchscript_func(n):
     return func, output_count
 
 
+def build_triton_func(n, resources_dir, async_compile, device):
+    with open(os.path.join(resources_dir, n.kernel_file), "r") as f:
+        code = f.read()
+
+    func = async_compile.triton(n.name, code, device_str=device)
+    return func, 0
+
+
 def generate_prefix(label, skip_nodes, et_input, cuda, compute_only, tf32, rows):
     template_prefix = """import gc
 import argparse
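Note: build_triton_func reads each kernel source from the resources directory that initBench derives from the trace path (a <trace name>_resources directory next to the trace file). A small runnable sketch of that path resolution, using a hypothetical trace path and kernel file name:

    import os

    def kernel_path(trace_file: str, kernel_file: str) -> str:
        # Mirrors initBench: <dir>/<name>.json -> <dir>/<name>_resources/<kernel_file>
        base_path, file_name = os.path.split(trace_file)
        resource_dir = os.path.join(base_path, os.path.splitext(file_name)[-2] + "_resources")
        return os.path.join(resource_dir, kernel_file)

    print(kernel_path("/tmp/my_trace.json", "kernel_0.py"))
    # -> /tmp/my_trace_resources/kernel_0.py (on POSIX)
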
diff --git a/train/compute/python/tools/execution_trace.py b/train/compute/python/tools/execution_trace.py
index 00e23a75..991b37a8 100644
--- a/train/compute/python/tools/execution_trace.py
+++ b/train/compute/python/tools/execution_trace.py
@@ -59,6 +59,9 @@ class NodeType(Enum):
     "All2All_Pooled_Req",
     "All2All_Pooled_Wait",
     "c10d::",
+    "TorchDynamo Cache Lookup",
+    "CompiledFunction",
+    "Torch-Compiled Region",
 ]
 
 
@@ -136,6 +139,8 @@ def __init__(
         output_types: List[str],
         output_shapes: List[Any],
         rf_id: Optional[int] = None,
+        kernel_backend: Optional[str] = None,
+        kernel_file: Optional[str] = None,
     ):
         self.name: str = name
         self.parent_id: int = parent_id
@@ -143,6 +148,8 @@ def __init__(
         self.children: List[Node] = []
         self.id: int = id
         self.rf_id: Optional[int] = rf_id
+        self.kernel_backend: Optional[str] = kernel_backend
+        self.kernel_file: Optional[str] = kernel_file
         self.pid: int = pid
         self.tid: int = tid
         self.fw_tid: int = fw_tid
@@ -298,6 +305,7 @@ def __init__(self, json):
             "1.0.2-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
             # 1.0.3 expands pg name to so it use the same parser as 1.0.2
             "1.0.3-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
+            "1.0.4-chakra.0.0.4": ExecutionTrace._create_node_v1_0_4_chakra_0_0_4,
             # Add future versions here
         }
         create_node = node_creation_func.get(self.schema, None)
@@ -371,6 +379,8 @@ def _read_attrs(node: Dict[str, Any]) -> Tuple:
             "rf_id": int,
             "scope": int,
             "tid": int,
+            "kernel_backend": str,
+            "kernel_file": str,
         }
         attr_dict = {
             attr["name"]: attr_types[attr["name"]](attr["value"])
@@ -378,12 +388,9 @@
             if attr["name"] in attr_types.keys()
         }
 
-        # Ensure all keys have values
-        if attr_dict.keys() != attr_types.keys():
-            raise ValueError(
-                "Not all keys in attr_dict have updated values. Node:" + str(node)
-            )
-        return tuple(attr_dict[key] for key in attr_types.keys())
+        return tuple(
+            attr_dict[key] for key in attr_types.keys() if key in attr_dict.keys()
+        )
 
     @staticmethod
     def _create_node_v1_0_1(pid, x: Dict[str, Any]) -> Node:
@@ -439,6 +446,42 @@ def _create_node_v1_0_2_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node:
             rf_id,
         )
 
+    @staticmethod
+    def _create_node_v1_0_4_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node:
+        (
+            fw_parent,
+            seq_id,
+            fw_tid,
+            op_schema,
+            rf_id,
+            scope,
+            tid,
+            kernel_backend,
+            kernel_file,
+        ) = ExecutionTrace._read_attrs(x)
+
+        return Node(
+            x["name"],
+            x["id"],
+            x["ctrl_deps"],
+            fw_parent,
+            seq_id,
+            pid,
+            tid,
+            fw_tid,
+            op_schema,
+            scope,
+            x["inputs"]["values"],
+            x["inputs"]["types"],
+            x["inputs"]["shapes"],
+            x["outputs"]["values"],
+            x["outputs"]["types"],
+            x["outputs"]["shapes"],
+            rf_id,
+            kernel_backend,
+            kernel_file,
+        )
+
     def get_nodes(self, clean: bool = False):
         if clean:
             return self.clean_nodes
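Note: with the relaxed _read_attrs above, optional attributes such as kernel_backend and kernel_file may be absent from a node; only the attributes that are present are converted and returned, instead of raising ValueError. A runnable sketch of that behavior with a hypothetical node record and a trimmed attr_types map:

    attr_types = {"rf_id": int, "kernel_backend": str, "kernel_file": str}
    node = {"attrs": [{"name": "rf_id", "value": 42}]}  # no Triton attrs on this node

    attr_dict = {
        attr["name"]: attr_types[attr["name"]](attr["value"])
        for attr in node["attrs"]
        if attr["name"] in attr_types.keys()
    }
    print(tuple(attr_dict[key] for key in attr_types.keys() if key in attr_dict.keys()))
    # -> (42,)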