[MultiHostHloRunner] Add GPU profiler support to multihost_hlo_runner
PiperOrigin-RevId: 719349363
juliagmt-google authored and Google-ML-Automation committed Jan 24, 2025
1 parent e5dba5b commit d2baef9
Showing 2 changed files with 55 additions and 1 deletion.
6 changes: 6 additions & 0 deletions xla/tools/multihost_hlo_runner/BUILD
@@ -35,8 +35,12 @@ cc_library(
    deps = [
        ":create_client",
        ":functional_hlo_runner",
        "//tensorflow/core/profiler/lib:profiler_factory_impl",
        "//tensorflow/core/profiler/lib:profiler_session_impl",
        "//xla:debug_options_flags",
        "//xla:status_macros",
        "//xla/backends/profiler:profiler_backends",  # To register the Host Tracers for GPU Plugin.
        "//xla/backends/profiler/plugin:plugin_tracer",  # To register the GPU Tracers with the GPU Plugin.
        "//xla/pjrt:pjrt_client",
        "//xla/pjrt/distributed",
        "//xla/pjrt/distributed:client",
@@ -57,6 +61,8 @@ cc_library(
        "@tsl//tsl/platform:platform_port",
        "@tsl//tsl/platform:status",
        "@tsl//tsl/platform:statusor",
        "@tsl//tsl/profiler/lib:profiler_session",
        "@tsl//tsl/profiler/protobuf:xplane_proto_cc",
    ] + if_cuda_or_rocm([
        "//xla/service:gpu_plugin",
    ]) + if_cuda([
50 changes: 49 additions & 1 deletion xla/tools/multihost_hlo_runner/hlo_runner_main.cc
@@ -39,6 +39,7 @@ limitations under the License.
#include "tsl/platform/init_main.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/profiler_session.h"

namespace {
const char* const kUsage = R"(
@@ -101,6 +102,7 @@ struct HloRunnerConfig {
  std::string execution_options_path = "";
  int64_t gpu_client_initialization_timeout_sec = 300;
  float gpu_client_mem_fraction = xla::GpuAllocatorConfig{}.memory_fraction;
  bool enable_gpu_profiler = false;
};

} // namespace
@@ -202,6 +204,22 @@ RawCompileOptionsFromFlags(const HloRunnerConfig& opts) {
  return out;
}

static absl::StatusOr<std::unique_ptr<tsl::ProfilerSession>>
CreateProfilerSession(const HloRunnerConfig& opts) {
  if (!opts.enable_gpu_profiler) {
    return nullptr;
  }

  // Trace the GPU only: host tracing is disabled (level 0), the device tracer
  // is enabled (level 1), and HLO protos are included in the profile.
  tensorflow::ProfileOptions profile_options;
  profile_options.set_host_tracer_level(0);
  profile_options.set_device_type(tensorflow::ProfileOptions::GPU);
  profile_options.set_enable_hlo_proto(true);
  profile_options.set_device_tracer_level(1);

  // Creating the session starts profiling with these options.
  return tsl::ProfilerSession::Create(profile_options);
}

static absl::Status RunMultihostHloRunner(int argc, char** argv,
                                          HloRunnerConfig& opts) {
  if (std::string error;
@@ -249,6 +267,18 @@ static absl::Status RunMultihostHloRunner(int argc, char** argv,
  }
  CHECK(env.client != nullptr);

  TF_ASSIGN_OR_RETURN(std::unique_ptr<tsl::ProfilerSession> profiler_session,
                      CreateProfilerSession(opts));
  if (profiler_session) {
    // Profiling is already running; verify that the session started cleanly.
    absl::Status status = profiler_session->Status();
    if (!status.ok()) {
      LOG(ERROR) << "Failed to start profiler session: " << status;
      return absl::InternalError("Failed to start profiler session");
    }
  }

  for (int c = 1; c < argc; c++) {
    const char* filename = argv[c];
    std::cout << "\n** Running " << filename << " **\n";
@@ -263,6 +293,22 @@ static absl::Status RunMultihostHloRunner(int argc, char** argv,
          raw_compile_options, argv[c], opts.input_format, opts.task_id));
    }
  }
  if (profiler_session) {
    // Stop profiling and collect the profiling data.
    tensorflow::profiler::XSpace xspace;
    absl::Status collect_data_status = profiler_session->CollectData(&xspace);
    if (!collect_data_status.ok()) {
      LOG(ERROR) << "Profiler data collection failed: "
                 << collect_data_status.message();
      return absl::InternalError("Failed to collect profiler data.");
    }
    LOG(INFO) << "Profiler data collected.";
    LOG(INFO) << "XSpace debug: " << xspace.DebugString();
    LOG(INFO) << "Profiling complete.";
  }

  return absl::OkStatus();
}

@@ -337,7 +383,9 @@ int main(int argc, char** argv) {
      tsl::Flag("gpu_client_mem_fraction", &opts.gpu_client_mem_fraction,
                "The maximum fraction of available memory to allocate in range "
                "of (0.0, 1.0). Same as XLA_CLIENT_MEM_FRACTION in the Python "
                "client. Only used with the BFC allocator.")};
                "client. Only used with the BFC allocator."),
      tsl::Flag("enable_gpu_profiler", &opts.enable_gpu_profiler,
                "Whether to enable the GPU profiler.")};

  xla::AppendDebugOptionsFlags(&flag_list);

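For reference, the profiler wiring above follows the usual tsl::ProfilerSession lifecycle: creating the session starts tracing, Status() reports whether tracing started successfully, and CollectData() stops tracing and fills an XSpace proto. When --enable_gpu_profiler is passed to the runner, that lifecycle wraps the HLO executions. The minimal standalone sketch below mirrors those calls from hlo_runner_main.cc; the RunHloWorkload() placeholder and the xplane.pb.h include path are assumptions for illustration, not part of this commit.

#include <memory>

#include "absl/status/status.h"
#include "tsl/profiler/lib/profiler_session.h"
#include "tsl/profiler/protobuf/xplane.pb.h"  // Assumed path for XSpace.

void RunHloWorkload();  // Hypothetical placeholder for the HLO executions.

absl::Status ProfileGpuWorkload() {
  // Same options as CreateProfilerSession() above: GPU-only tracing.
  tensorflow::ProfileOptions options;
  options.set_host_tracer_level(0);
  options.set_device_type(tensorflow::ProfileOptions::GPU);
  options.set_enable_hlo_proto(true);
  options.set_device_tracer_level(1);

  // Creating the session starts profiling.
  std::unique_ptr<tsl::ProfilerSession> session =
      tsl::ProfilerSession::Create(options);
  if (!session->Status().ok()) return session->Status();

  RunHloWorkload();

  // CollectData() stops profiling and returns the trace as an XSpace proto.
  tensorflow::profiler::XSpace xspace;
  return session->CollectData(&xspace);
}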
