Skip to content

Commit

Permalink
add example trace with comm args metadata (#105)
Browse files Browse the repository at this point in the history
Summary:
As above
1. Adds a test trace from a 2-GPU ResNet job with 12 communication collectives.
2. Adds an optional commArgs argument to the ET node; it will be populated soon.
3. Makes minor updates to the parser and adds a new unit test that validates this test trace.

## Testing

```
(pytorch) [[email protected] /data/users/bcoutinho]$ export PYTHONPATH=/data/users/bcoutinho
(pytorch) [[email protected] /data/users/bcoutinho]$ python3 param_bench/train/compute/python/test/test_execution_trace.py
 record_param_comms, process group args = ('Tuple[String,String]', ['0', 'default_pg'], [[], []])
 record_param_comms, process group args = ('Tuple[String,String]', ['0', 'default_pg'], [[], []])
 record_param_comms, process group args = ('Tuple[String,String]', ['', ''], [[], []])
 record_param_comms, process group args = ('Tuple[String,String]', ['', ''], [[], []])
 record_param_comms, process group args = ('Tuple[String,String]', ['0', 'default_pg'], [[], []])
 record_param_comms, process group args = ('Tuple[String,String]', ['0', 'default_pg'], [[], []])
 record_param_comms, process group args = ('Tuple[String,String]', ['', ''], [[], []])
 record_param_comms, process group args = ('Tuple[String,String]', ['', ''], [[], []])
 record_param_comms, process group args = ('Tuple[String,String]', ['0', 'default_pg'], [[], []])
 record_param_comms, process group args = ('Tuple[String,String]', ['0', 'default_pg'], [[], []])
 record_param_comms, process group args = ('Tuple[String,String]', ['', ''], [[], []])
 record_param_comms, process group args = ('Tuple[String,String]', ['', ''], [[], []])
.None
None
None
None
None
None
None
None
None
None
None
None
.
----------------------------------------------------------------------
Ran 2 tests in 1.322s

OK
```

Pull Request resolved: #105

Reviewed By: shengfukevin

Differential Revision: D56739869

Pulled By: briancoutinho

fbshipit-source-id: b53a3f36eb57e637e77b988cc071136b71e96caa
  • Loading branch information
briancoutinho authored and facebook-github-bot committed May 3, 2024
1 parent c83ce84 commit 2b4cf3e
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 55 deletions.
Binary file not shown.
11 changes: 10 additions & 1 deletion train/compute/python/test/test_execution_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _test_and_validate_trace(self, trace_file):
self.assertTrue(t.validate())
return t, execution_trace

def test_trace_load_resnet_1gpu(self):
def test_trace_load_resnet_1gpu_ptorch_1_0_3(self):
et_file = os.path.join(
self.trace_base, "1.0.3-chakra.0.0.4/resnet_1gpu_et.json.gz"
)
Expand All @@ -33,6 +33,15 @@ def test_trace_load_resnet_1gpu(self):
self.assertEqual(t.num_comm_ops(), 12)
self.assertEqual(t.num_triton_ops(), 0)

def test_trace_load_resnet_2gpu_ptorch_1_1_0(self):
    """Load the 2-GPU ResNet trace (schema 1.1.0-chakra.0.0.4) and check its op counts."""
    trace_path = os.path.join(
        self.trace_base, "1.1.0-chakra.0.0.4/resnet_2gpu_et.json.gz"
    )
    # The helper returns a pair; only its first element is inspected here.
    trace_stats, _ = self._test_and_validate_trace(trace_path)

    self.assertGreater(trace_stats.num_ops(), 1000)
    self.assertEqual(trace_stats.num_comm_ops(), 12)
    self.assertEqual(trace_stats.num_triton_ops(), 0)


if __name__ == "__main__":
unittest.main()
86 changes: 37 additions & 49 deletions train/compute/python/tools/execution_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import json
import logging
import sys
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Set, TextIO, Tuple

Expand Down Expand Up @@ -103,6 +104,15 @@ def is_leaf_tensor(self):
) and self.sinks # A tensor having no sources yet having some sinks is a leaf tensor


@dataclass
class _CommArgs:
    """Metadata describing one communication collective.

    Attributes:
        collective_name: name of the collective, as a string.
        dtype: data type of the collective, as a string.
    """

    collective_name: str
    dtype: str
    # TODO: add more fields; see https://github.com/pytorch/pytorch/issues/124674


"""
Node
Expand Down Expand Up @@ -141,6 +151,7 @@ def __init__(
rf_id: Optional[int] = None,
kernel_backend: Optional[str] = None,
kernel_file: Optional[str] = None,
comm_args: Optional[_CommArgs] = None,
):
self.name: str = name
self.parent_id: int = parent_id
Expand All @@ -166,6 +177,7 @@ def __init__(
self.outputs: List[Any] = outputs
self.output_types: List[str] = output_types
self.output_shapes: List[Any] = output_shapes
self.commArgs: Optional[_CommArgs] = comm_args

def get_inputs(self) -> Iterable:
return zip(self.input_types, self.inputs, self.input_shapes)
Expand Down Expand Up @@ -305,7 +317,10 @@ def __init__(self, json):
"1.0.2-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
# 1.0.3 expands pg name to <pg_name, pg_desc>, so it uses the same parser as 1.0.2
"1.0.3-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
"1.0.4-chakra.0.0.4": ExecutionTrace._create_node_v1_0_4_chakra_0_0_4,
# 1.0.4 adds PT2 kernel backend and kernel file
"1.0.4-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
# 1.1.0 includes new comm args in record_param_comms
"1.1.0-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
# Add future versions here
}
create_node = node_creation_func.get(self.schema, None)
Expand Down Expand Up @@ -369,27 +384,32 @@ def schema_chakra(self) -> Tuple[int]:
return (0, 0, 0)
return self._versiontuple(self.schema.split("-")[1])

@staticmethod
def _read_attrs(node: Dict[str, Any]) -> Tuple:
attr_types = {
"fw_parent": int,
"seq_id": int,
"fw_tid": int,
"op_schema": str,
"rf_id": int,
"scope": int,
"tid": int,
"kernel_backend": str,
"kernel_file": str,
}
ATTR_TYPES = {
"fw_parent": int,
"seq_id": int,
"fw_tid": int,
"op_schema": str,
"rf_id": int,
"scope": int,
"tid": int,
"kernel_backend": str,
"kernel_file": str,
}
OPTIONAL_ATTR = ["kernel_backend", "kernel_file"]

@classmethod
def _read_attrs(cls, node: Dict[str, Any]) -> Tuple:
attr_dict = {
attr["name"]: attr_types[attr["name"]](attr["value"])
attr["name"]: cls.ATTR_TYPES[attr["name"]](attr["value"])
for attr in node["attrs"]
if attr["name"] in attr_types.keys()
if attr["name"] in cls.ATTR_TYPES.keys()
}
for opt_key in cls.OPTIONAL_ATTR:
if opt_key not in attr_dict:
attr_dict[opt_key] = None

return tuple(
attr_dict[key] for key in attr_types.keys() if key in attr_dict.keys()
attr_dict[key] for key in cls.ATTR_TYPES.keys() if key in attr_dict.keys()
)

@staticmethod
Expand All @@ -416,38 +436,6 @@ def _create_node_v1_0_1(pid, x: Dict[str, Any]) -> Node:

@staticmethod
def _create_node_v1_0_2_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node:
(
fw_parent,
seq_id,
fw_tid,
op_schema,
rf_id,
scope,
tid,
) = ExecutionTrace._read_attrs(x)

return Node(
x["name"],
x["id"],
x["ctrl_deps"],
fw_parent,
seq_id,
pid,
tid,
fw_tid,
op_schema,
scope,
x["inputs"]["values"],
x["inputs"]["types"],
x["inputs"]["shapes"],
x["outputs"]["values"],
x["outputs"]["types"],
x["outputs"]["shapes"],
rf_id,
)

@staticmethod
def _create_node_v1_0_4_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node:
(
fw_parent,
seq_id,
Expand Down
23 changes: 18 additions & 5 deletions train/compute/python/tools/validate_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@


class TraceValidator:

def __init__(self, execution_trace: ExecutionTrace):
self.et = execution_trace

Expand Down Expand Up @@ -42,14 +41,14 @@ def _validate_tree(self) -> bool:

def _validate_param_comms(self) -> bool:
"""Check if param comms has correct attributes"""
# This should use the comms parser, for now something simple will be fine
# https://github.com/facebookresearch/param/blob/main/train/comms/pt/commsTraceParser.py#L256

if self.et.schema_pytorch() < (1, 0, 2):
return True

def check_comms_node(n) -> bool:
"""TODO use comms parser"""
def check_comms_node_pre_1_1_0(n) -> bool:
"""Roughly based on commsTraceParser"""
# https://github.com/facebookresearch/param/blob/main/train/comms/pt/commsTraceParser.py#L256

has_pg_id = False
# Slightly hacky, but find an argument with tuple type
for arg in n.get_inputs():
Expand All @@ -58,6 +57,20 @@ def check_comms_node(n) -> bool:
has_pg_id = True
return has_pg_id

def check_comms_node_1_1_0(n) -> bool:
    """Check a comms node from a trace with PyTorch schema 1.1.0 or newer.

    New elements are added as per
    https://github.com/pytorch/pytorch/issues/124674

    Args:
        n: node to check; its ``commArgs`` attribute is read.

    Returns:
        Always True for now; no validation of ``commArgs`` is performed yet.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    # TODO check for node.commArgs dataclass
    # Log at debug level instead of printing, so validation does not
    # write one line per comms node to stdout.
    logging.getLogger(__name__).debug("commArgs = %s", n.commArgs)
    return True

check_comms_node = (
check_comms_node_1_1_0
if self.et.schema_pytorch() >= (1, 1, 0)
else check_comms_node_pre_1_1_0
)

return all(
check_comms_node(n)
for n in self.et.nodes.values()
Expand Down

0 comments on commit 2b4cf3e

Please sign in to comment.