Skip to content

Commit

Permalink
[XLA:Python] Use PyEval_SetProfileAllThreads to install the python pr…
Browse files Browse the repository at this point in the history
…ofiler in all threads under Python 3.12+.

This API is thread-safe under Python 3.13 free-threading, not to mention simpler.

PiperOrigin-RevId: 714148968
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Jan 11, 2025
1 parent a5ba283 commit 3264fef
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions xla/python/profiler/internal/python_hooks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ void AddEventToXLine(const PythonTraceEntry& event,
xevent.SetEndTimestampNs(event.end_time_ns);
}

#if PY_VERSION_HEX < 0x030C0000
template <typename ForEachThreadFunc>
void ForEachThread(PyThreadState* curr_thread, ForEachThreadFunc&& callback) {
// Note: PyThreadState's interp is not accessible in open source due to
Expand Down Expand Up @@ -118,6 +119,8 @@ void ForEachThread(PyThreadState* curr_thread, ForEachThreadFunc&& callback) {
#endif
}

#endif // PY_VERSION_HEX

} // namespace

/*static*/ PythonHookContext* PythonHooks::e2e_context_ = nullptr;
Expand Down Expand Up @@ -371,21 +374,29 @@ void PythonHookContext::ProfileFast(PyFrameObject* frame, int what,

// NOTE: This must be after `threading.setprofile` otherwise we
// end up recording that in our trace.
#if PY_VERSION_HEX < 0x030C0000
PyThreadState* curr_thread = PyThreadState_Get();
ForEachThread(curr_thread, [](PyThreadState* thread) {
VLOG(1) << "Setting profiler in " << thread->thread_id;
PyEval_SetProfile(&PythonHooks::ProfileFunction, nullptr);
});
PyThreadState_Swap(curr_thread);
#else // PY_VERSION_HEX >= 0x030C0000
PyEval_SetProfileAllThreads(&PythonHooks::ProfileFunction, nullptr);
#endif // PY_VERSION_HEX >= 0x030C0000
}

/*static*/ void PythonHookContext::ClearProfilerInAllThreads() {
#if PY_VERSION_HEX < 0x030C0000
PyThreadState* curr_thread = PyThreadState_Get();
ForEachThread(curr_thread, [](PyThreadState* thread) {
VLOG(1) << "Clearing profiler in " << thread->thread_id;
PyEval_SetProfile(nullptr, nullptr);
});
PyThreadState_Swap(curr_thread);
#else // PY_VERSION_HEX >= 0x030C0000
PyEval_SetProfileAllThreads(nullptr, nullptr);
#endif // PY_VERSION_HEX >= 0x030C0000

// And notify the threading library that we're done.
ThreadingSetProfile(py::none());
Expand Down

0 comments on commit 3264fef

Please sign in to comment.