Skip to content

Commit

Permalink
PR #20911: [XLA:GPU] Update cudnn frontend version to 1.9
Browse files Browse the repository at this point in the history
Imported from GitHub PR #20911

cudnn frontend 1.9 is released, there are some new features that cudnn flash attention will incorporate, hence this PR.
* flex attention with arbitrary pointwise operations after softmax in cudnn flash attention graph.
* [sequence packing](#20861) enhancement with reduced workspace size.

Release note: https://github.com/NVIDIA/cudnn-frontend/releases/tag/v1.9.0
Copybara import of the project:

--
07a0d7a by cjkkkk <[email protected]>:

update

Merging this change closes #20911

FUTURE_COPYBARA_INTEGRATE_REVIEW=#20911 from Cjkkkk:update_cudnn_fe_1.9 07a0d7a
PiperOrigin-RevId: 713680032
  • Loading branch information
Cjkkkk authored and Google-ML-Automation committed Jan 11, 2025
1 parent 266599d commit 3ce1143
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
6 changes: 3 additions & 3 deletions workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def _tf_repositories():
name = "cudnn_frontend_archive",
build_file = "//third_party:cudnn_frontend.BUILD",
patch_file = ["//third_party:cudnn_frontend_header_fix.patch"],
sha256 = "5f77784dc3ccbca7aca5ea0b5a6e31b95aa85023c5942d22be5fa8dd6c339d81",
strip_prefix = "cudnn-frontend-1.8.0",
urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.8.0.zip"),
sha256 = "7be8afebc693f0ef75bbc673ce5c1cf422673e84ea7d53e488201756c046496e",
strip_prefix = "cudnn-frontend-1.9.0",
urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.9.0.zip"),
)

tf_http_archive(
Expand Down
5 changes: 4 additions & 1 deletion xla/python/pjit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,10 @@ void PjitFunction::InitExecutables() {
}
}

PjitFunction::~PjitFunction() = default;
PjitFunction::~PjitFunction() {
nb::ft_object_guard lock(cache_);
executables_ = nullptr;
}

void CallShardArgFallback(
nb::handle arg, nb::handle sharding, nb::handle layout,
Expand Down
11 changes: 11 additions & 0 deletions xla/python/profiler/internal/python_hooks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ void AddEventToXLine(const PythonTraceEntry& event,
xevent.SetEndTimestampNs(event.end_time_ns);
}

#if PY_VERSION_HEX < 0x030C0000
template <typename ForEachThreadFunc>
void ForEachThread(PyThreadState* curr_thread, ForEachThreadFunc&& callback) {
// Note: PyThreadState's interp is not accessible in open source due to
Expand Down Expand Up @@ -118,6 +119,8 @@ void ForEachThread(PyThreadState* curr_thread, ForEachThreadFunc&& callback) {
#endif
}

#endif // PY_VERSION_HEX

} // namespace

/*static*/ PythonHookContext* PythonHooks::e2e_context_ = nullptr;
Expand Down Expand Up @@ -371,21 +374,29 @@ void PythonHookContext::ProfileFast(PyFrameObject* frame, int what,

// NOTE: This must be after `threading.setprofile` otherwise we
// end up recording that in our trace.
#if PY_VERSION_HEX < 0x030C0000
PyThreadState* curr_thread = PyThreadState_Get();
ForEachThread(curr_thread, [](PyThreadState* thread) {
VLOG(1) << "Setting profiler in " << thread->thread_id;
PyEval_SetProfile(&PythonHooks::ProfileFunction, nullptr);
});
PyThreadState_Swap(curr_thread);
#else // PY_VERSION_HEX >= 0x030C0000
PyEval_SetProfileAllThreads(&PythonHooks::ProfileFunction, nullptr);
#endif // PY_VERSION_HEX >= 0x030C0000
}

/*static*/ void PythonHookContext::ClearProfilerInAllThreads() {
#if PY_VERSION_HEX < 0x030C0000
PyThreadState* curr_thread = PyThreadState_Get();
ForEachThread(curr_thread, [](PyThreadState* thread) {
VLOG(1) << "Clearing profiler in " << thread->thread_id;
PyEval_SetProfile(nullptr, nullptr);
});
PyThreadState_Swap(curr_thread);
#else // PY_VERSION_HEX >= 0x030C0000
PyEval_SetProfileAllThreads(nullptr, nullptr);
#endif // PY_VERSION_HEX >= 0x030C0000

// And notify the threading library that we're done.
ThreadingSetProfile(py::none());
Expand Down

0 comments on commit 3ce1143

Please sign in to comment.