Skip to content

Commit

Permalink
update commsReplay for 1.0.3-chakra.0.0.4 schema
Browse files Browse the repository at this point in the history
Summary:
1.0.3-chakra.0.0.4 schema (PR #124035) logs <group_name, group_desc> as the new pg_name  instead of pg uid in profiler.
- group_name remains as the unique identifier, e.g. “0”, "1"
- group_desc will be the user specified name, e.g. "fsdp".

This diff updates the commsReplay to support the new schema

Reviewed By: shengfukevin

Differential Revision: D56288398

fbshipit-source-id: 3663e45507e098cedb407609eecc2e0cec6890a1
  • Loading branch information
shengbao-zheng authored and facebook-github-bot committed Apr 18, 2024
1 parent c8e3f2f commit 7868e09
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 21 deletions.
47 changes: 26 additions & 21 deletions train/comms/pt/commsTraceParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ def _parseExecutionTrace(
Convert the Execution Trace comms metadata to the common trace format for replay.
"""
# Execution Trace PG_ID types availability
ET_PG_NAME_TUPLE = True if in_trace.schema == "1.0.3-chakra.0.0.4" else False
ET_BACKENDID = True if in_trace.schema != "1.0.3-chakra.0.0.4" else False

initOps = []
newCommsTrace = []
Expand All @@ -233,23 +236,21 @@ def _parseExecutionTrace(
break

for pg in pgObj:
backendId = pg["uid"] if "uid" in pg else pg["backend_id"]
if not pg["pg_name"].isdecimal():
# TODO support local synchronization pg
continue
pgId = int(pg["pg_name"])
ranks = pg["ranks"]
if isinstance(ranks, list):
pgId = int(pg["pg_name"])
groupCnt = pg["group_count"]
pgRanksMap[pgId] = (
ranks
if len(ranks) > 0
else list(range(pg["group_size"]))
# rank list is empty when all ranks are in a pg
)
elif isinstance(
ranks, dict
): # TODO for legacy traces: remove once all ET use the most recent pg
pgId = pg["pg_id"]
pgRanksMap[pgId] = [int(rank) for rank in ranks.keys()]
backendIdToPgid[backendId] = pgId
groupCnt = pg["group_count"]
pgRanksMap[pgId] = (
ranks
if len(ranks) > 0
else list(range(pg["group_size"]))
# rank list is empty when all ranks are in a pg
)
if ET_BACKENDID:
backendId = pg["uid"] if "uid" in pg else pg["backend_id"]
backendIdToPgid[backendId] = pgId
break # only one process_group init node per trace

# Parse comms nodes
Expand All @@ -269,12 +270,16 @@ def _parseExecutionTrace(
1 - shift
] # 2nd value of inputs is the req id of the collective

backendId = node.inputs[
pgIdentifier = node.inputs[
2 - shift
] # 3rd value of inputs is the backend id of the collective
if backendId in backendIdToPgid:
# Assign pg_id info for PGs that were created.
newComm.pgId = backendIdToPgid[backendId]
] # 3rd value of inputs is the pg identifier of the collective
# Assign pg_id info for PGs that were created.
if ET_BACKENDID and pgIdentifier in backendIdToPgid:
newComm.pgId = backendIdToPgid[pgIdentifier]
newComm.groupRanks = pgRanksMap[newComm.pgId]
newComm.worldSize = len(newComm.groupRanks)
elif ET_PG_NAME_TUPLE and pgIdentifier[0].isdecimal():
newComm.pgId = int(pgIdentifier[0])
newComm.groupRanks = pgRanksMap[newComm.pgId]
newComm.worldSize = len(newComm.groupRanks)

Expand Down
2 changes: 2 additions & 0 deletions train/compute/python/tools/execution_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ def __init__(self, json):
node_creation_func = {
"1.0.1": ExecutionTrace._create_node_v1_0_1,
"1.0.2-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
# 1.0.3 expands pg name to <pg_name, pg_desc> so it use the same parser as 1.0.2
"1.0.3-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
# Add future versions here
}
create_node = node_creation_func.get(self.schema, None)
Expand Down

0 comments on commit 7868e09

Please sign in to comment.