Skip to content

Commit

Permalink
using listener for compile times in triton
Browse files Browse the repository at this point in the history
Summary: Taking advantage of internal tooling to add commit hooks to triton on compile. This will allow us to get in depth stats directly from Triton on compile tims

Reviewed By: FindHao

Differential Revision: D67547087

fbshipit-source-id: f4576f59319589886a467ea2b17130b501160f5c
  • Loading branch information
adamomainz authored and facebook-github-bot committed Jan 6, 2025
1 parent 196f5f1 commit 717ac3f
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 21 deletions.
2 changes: 1 addition & 1 deletion tritonbench/components/compile_time/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .trace import do_compile_time_in_task
from .trace import do_compile_time_in_task, fbcode_do_compile_time_in_task # noqa F401
22 changes: 21 additions & 1 deletion tritonbench/components/compile_time/trace.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,29 @@
from typing import Callable
from typing import Callable, Dict

import torch
from triton.fb.triton_util import triton_add_listener, TritonHook
from tritonbench.utils.env_utils import fresh_triton_cache


def fbcode_do_compile_time_in_task(fn: Callable) -> Dict[str, float]:
# not yet getting results that make sense to me
detailed_data = {}
with fresh_triton_cache():

def _inner(**kwargs):
stats = kwargs.get("stats", {})
if not stats:
return
if "compile_time_stats" in stats:
detailed_data["compile_time_stats"] = stats["compile_time_stats"]

triton_add_listener(TritonHook.POST_COMPILE, _inner)
fn()
if "compile_time_stats" in detailed_data:
return detailed_data["compile_time_stats"]
return None


def do_compile_time_in_task(fn: Callable) -> float:
with fresh_triton_cache():
torch.cuda.synchronize()
Expand Down
74 changes: 55 additions & 19 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ class BenchmarkOperatorMetrics:
walltime: Optional[float] = None
# compile time
compile_time: Optional[float] = None
# stage breakdown of compile times
compile_time_by_stage: Optional[Dict[str, float]] = None
# ncu trace file
ncu_trace: Optional[str] = None
# ncu replay file
Expand Down Expand Up @@ -395,18 +397,24 @@ def userbenchmark_dict(self) -> Dict[str, Any]:
for row in table:
x_val = row[0]

for ind, value in enumerate(row[1:]):
for ind, v in enumerate(row[1:]):
header = headers[ind + 1]
provider, _dash, metrics = header.partition("-")
metric_name = f"tritonbench_{self.op_name}_{self.op_mode}[x_{x_val}-{provider}]_{metrics}"
userbenchmark_metrics_dict[metric_name] = value
agg_metric_name = f"tritonbench_{self.op_name}_{self.op_mode}[{provider}]-{metrics}-avg"
if value is None:
continue
if isinstance(value, (int, float)):
agg_data[agg_metric_name] = agg_data.get(agg_metric_name, []) + [
value
]
provider, _, metrics_name = header.partition("-")
metrics_dict = {metrics_name: v}
if v and isinstance(v, dict):
metrics_dict = {
f"{metrics_name}-{k}": value for k, value in v.items()
}
for metrics, value in metrics_dict.items():
metric_name = f"tritonbench_{self.op_name}_{self.op_mode}[x_{x_val}-{provider}]_{metrics}"
userbenchmark_metrics_dict[metric_name] = value
agg_metric_name = f"tritonbench_{self.op_name}_{self.op_mode}[{provider}]-{metrics}-avg"
if value is None:
continue
if isinstance(value, (int, float)):
agg_data[agg_metric_name] = agg_data.get(
agg_metric_name, []
) + [value]
final_agg_data = {k: sum(v) / len(v) for k, v in agg_data.items()}
userbenchmark_metrics_dict.update(final_agg_data)

Expand Down Expand Up @@ -644,6 +652,8 @@ def __init__(
if tb_args.metrics
else self.DEFAULT_METRICS
)
if "compile_time" in self.required_metrics and IS_FBCODE:
self.required_metrics.append("compile_time_by_stage")
self.extra_args = extra_args
if self.name not in REGISTERED_X_VALS:
REGISTERED_X_VALS[self.name] = "x_val"
Expand Down Expand Up @@ -1102,7 +1112,12 @@ def _init_extra_metrics() -> Dict[str, Any]:
if "gbps" in self.required_metrics:
metrics.gbps = self.gbps(fn, self.example_inputs, metrics)
if "compile_time" in self.required_metrics:
metrics.compile_time = self.compile_time(input_id, fn_name, metrics)
compile_time, compile_time_by_stage = self.compile_time(
input_id, fn_name, metrics
)
metrics.compile_time = compile_time
if compile_time_by_stage:
metrics.compile_time_by_stage = compile_time_by_stage
if "ncu_trace" in self.required_metrics:
metrics.ncu_trace = self.ncu_trace(input_id, fn_name)
# Collect NCU metrics if any required metrics match the ncu analyzer
Expand Down Expand Up @@ -1205,12 +1220,29 @@ def _init_extra_metrics() -> Dict[str, Any]:
)
from tritonbench.components.compile_time import do_compile_time_in_task

metrics.extra_metrics["_compile_time_in_task"] = (
do_compile_time_in_task(fn)
)
self._latency_with_compile_in_task = metrics.extra_metrics[
"_compile_time_in_task"
]
if IS_FBCODE:
from tritonbench.components.compile_time import (
fbcode_do_compile_time_in_task,
)

compile_times = fbcode_do_compile_time_in_task(fn)
if compile_times is not None:
metrics.extra_metrics["compile_times"] = compile_times
self.compile_time_by_stage = {
k: v / 1_000_000
for k, v in compile_times.items()
if k != "total"
}
self.triton_hook_latency = (
compile_times["total"] / 1_000_000
) # converting from ms to s
if "compile_times" not in metrics.extra_metrics:
metrics.extra_metrics["_compile_time_in_task"] = (
do_compile_time_in_task(fn)
)
self._latency_with_compile_in_task = metrics.extra_metrics[
"_compile_time_in_task"
]
if "_ncu_trace_in_task" in self.required_metrics:
assert (
self.required_metrics == ["_ncu_trace_in_task"]
Expand Down Expand Up @@ -1558,10 +1590,14 @@ def compile_time(
op_task = OpTask(name=self.name)
op_task.make_operator_instance(args=op_task_args)
op_task.run()
if op_task.get_attribute("triton_hook_latency") is not None:
compiled_time = op_task.get_attribute("triton_hook_latency")
compile_time_by_stage = op_task.get_attribute("compile_time_by_stage")
return compiled_time, compile_time_by_stage
latency_with_compile = op_task.get_attribute("_latency_with_compile_in_task")
del op_task
latency_without_compile = metrics.latency
return latency_with_compile - latency_without_compile
return latency_with_compile - latency_without_compile, None

def hw_roofline(self) -> float:
"""Hardware roofline in tflops."""
Expand Down

0 comments on commit 717ac3f

Please sign in to comment.