Skip to content

Commit

Permalink
use benchmark::Counter to get average flop and throughput
Browse files Browse the repository at this point in the history
  • Loading branch information
AD2605 committed Aug 16, 2024
1 parent 9056e3f commit f784ded
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions test/bench/portfft/launch_bench.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ void bench_dft_average_host_time_impl(benchmark::State& state, sycl::queue q, po
#endif // PORTFFT_VERIFY_BENCHMARKS
std::vector<sycl::event> dependencies;
dependencies.reserve(1);
state.counters["flops"] = 0;
state.counters["throughput"] = 0;
double flop_sum = 0;
double throughput_sum = 0;
for (auto _ : state) {
// we need to manually measure time, so as to have it available here for the
// calculation of flops
Expand Down Expand Up @@ -136,12 +136,12 @@ void bench_dft_average_host_time_impl(benchmark::State& state, sycl::queue q, po
}
double elapsed_seconds =
std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count() / static_cast<double>(runs);
state.counters["flops"] += ops / elapsed_seconds;
state.counters["throughput"] += static_cast<double>(bytes_transferred) / elapsed_seconds;
flop_sum += ops / elapsed_seconds;
throughput_sum += static_cast<double>(bytes_transferred) / elapsed_seconds;
state.SetIterationTime(elapsed_seconds);
}
state.counters["flops"] /= static_cast<double>(state.iterations());
state.counters["throughput"] /= static_cast<double>(state.iterations());
state.counters["flops"] = benchmark::Counter(flop_sum, benchmark::Counter::kAvgIterations);
state.counters["throughput"] = benchmark::Counter(throughput_sum, benchmark::Counter::kAvgIterations);
}

/**
Expand Down Expand Up @@ -215,8 +215,8 @@ void bench_dft_device_time_impl(benchmark::State& state, sycl::queue q, portfft:
verify_dft<portfft::direction::FORWARD, portfft::complex_storage::INTERLEAVED_COMPLEX>(desc, backward_data,
host_output, 1e-2);
#endif // PORTFFT_VERIFY_BENCHMARKS
state.counters["flops"] = 0;
state.counters["throughput"] = 0;
double flop_sum = 0;
double throughput_sum = 0;
for (auto _ : state) {
// Write to the input to invalidate cache
q.copy(host_forward_data.data(), in_dev.get(), num_elements).wait();
Expand All @@ -225,12 +225,12 @@ void bench_dft_device_time_impl(benchmark::State& state, sycl::queue q, portfft:
auto start = e.get_profiling_info<sycl::info::event_profiling::command_start>();
auto end = e.get_profiling_info<sycl::info::event_profiling::command_end>();
double elapsed_seconds = static_cast<double>(end - start) / 1e9;
state.counters["flops"] += ops / elapsed_seconds;
state.counters["throughput"] += static_cast<double>(bytes_transferred) / elapsed_seconds;
flop_sum += ops / elapsed_seconds;
throughput_sum += static_cast<double>(bytes_transferred) / elapsed_seconds;
state.SetIterationTime(elapsed_seconds);
}
state.counters["flops"] /= static_cast<double>(state.iterations());
state.counters["throughput"] /= static_cast<double>(state.iterations());
state.counters["flops"] = benchmark::Counter(flop_sum, benchmark::Counter::kAvgIterations);
state.counters["throughput"] = benchmark::Counter(throughput_sum, benchmark::Counter::kAvgIterations);
}

/**
Expand Down

0 comments on commit f784ded

Please sign in to comment.