Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor IsXSpaceGrouped to verify that all device and host planes are grouped. #21824

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions xla/tsl/profiler/utils/xplane_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,15 +399,14 @@ bool IsEmpty(const XSpace& space) {

bool IsXSpaceGrouped(const XSpace& space) {
for (const auto& plane : space.planes()) {
// If any plane has been grouped, consider space as grouped.
// CreateTfXPlaneVisitor is necessary because we need check "group_id" stat
// by its type StatType::kGroupId.
if (!IsDevicePlane(plane) && !IsHostPlane(plane)) continue;
// Ensure all host and device planes have a group id stat.
XPlaneVisitor xplane = tsl::profiler::CreateTfXPlaneVisitor(&plane);
const XStatMetadata* group_id_stat =
xplane.GetStatMetadataByType(StatType::kGroupId);
if (group_id_stat) return true;
if (!group_id_stat) return false;
}
return false;
return true;
}

void AddFlowsToXplane(int32_t host_id, bool is_host_plane, bool connect_traceme,
Expand Down
60 changes: 44 additions & 16 deletions xla/tsl/profiler/utils/xplane_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include <gtest/gtest.h>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
Expand Down Expand Up @@ -732,26 +733,53 @@ TEST(XplaneutilsTest, TestEventMetadataStatsAreCopiedForRefValue) {
EXPECT_EQ(stat->StrOrRefValue(), "TestFunction");
}

TEST(XplaneutilsTest, TestIsXSpaceGrouped) {
struct IsXSpaceGroupedTestCase {
struct XPlaneParams {
std::string name;
std::vector<bool> events_grouped;
};
std::vector<XPlaneParams> planes;
bool expected_result;
std::string test_case_name;
};

using IsXSpaceGroupedTest = ::testing::TestWithParam<IsXSpaceGroupedTestCase>;
TEST_P(IsXSpaceGroupedTest, TestIsXSpaceGrouped) {
const IsXSpaceGroupedTestCase& test_case = GetParam();
XSpace space;
{
XPlaneBuilder p1(space.add_planes());
auto l1 = CreateXLine(&p1, "l1", "d1", 1, 100);
auto e1 = CreateXEvent(&p1, l1, "event1", "display1", 1, 2);
CreateXStats(&p1, &e1, "event_stat1", 2.0);
}
EXPECT_FALSE(IsXSpaceGrouped(space));

{
XPlaneBuilder p2(space.add_planes());
auto l2 = CreateXLine(&p2, "l2", "d2", 1, 100);
auto e2 = CreateXEvent(&p2, l2, "event2", "display2", 1, 2);
CreateXStats(&p2, &e2, "group_id", 1);
for (const auto& plane_params : test_case.planes) {
XPlaneBuilder plane(space.add_planes());
plane.SetName(plane_params.name);
for (bool event_grouped : plane_params.events_grouped) {
auto l1 = CreateXLine(&plane, "l1", "d1", 1, 100);
auto e1 = CreateXEvent(&plane, l1, "event1", "display1", 1, 2);
if (event_grouped) {
CreateXStats(&plane, &e1, "group_id", 1);
}
}
}
LOG(ERROR) << space.DebugString();
EXPECT_TRUE(IsXSpaceGrouped(space));
EXPECT_EQ(IsXSpaceGrouped(space), test_case.expected_result);
}

INSTANTIATE_TEST_SUITE_P(
IsXSpaceGroupedTestSuiteInstantiation, IsXSpaceGroupedTest,
::testing::ValuesIn<IsXSpaceGroupedTestCase>({
{{{"/host:CPU", {true}}, {"/device:TPU", {true}}},
true,
"HostAndDeviceGrouped"},
{{{"/host:CPU", {false}}, {"/device:TPU", {true}}},
false,
"HostNotGrouped"},
{{{"/host:CPU", {true}},
{"/device:TPU", {true}},
{"/nonhostordevice:???", {false}}},
true,
"HostAndDeviceGroupedWithNonHostOrDevice"},
}),
[](const ::testing::TestParamInfo<IsXSpaceGroupedTest::ParamType>& info) {
return info.param.test_case_name;
});

TEST(XplaneutilsTest, TestIsHostPlane) {
XSpace xspace;
auto xplane_host_thread = FindOrAddMutablePlaneWithName(&xspace, "/host:CPU");
Expand Down