Skip to content

Commit

Permalink
PR #20494: Update slop_factor flag desc in debug_options_flags.cc
Browse files Browse the repository at this point in the history
Imported from GitHub PR #20494

Copybara import of the project:

--
04a8e94 by Sevin Varoglu <[email protected]>:

Update slop_factor flag desc in debug_options_flags.cc

--
4a5d4fe by Sevin Varoglu <[email protected]>:

Fix error

--
0347b54 by Sevin Varoglu <[email protected]>:

Add default value

Merging this change closes #20494

FUTURE_COPYBARA_INTEGRATE_REVIEW=#20494 from sfvaroglu:sevin/update_comment 0347b54
PiperOrigin-RevId: 713773831
  • Loading branch information
sfvaroglu authored and Google-ML-Automation committed Jan 11, 2025
1 parent 266599d commit 6e8204f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
21 changes: 20 additions & 1 deletion xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1645,7 +1645,26 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"xla_gpu_memory_limit_slop_factor",
int32_setter_for(&DebugOptions::set_xla_gpu_memory_limit_slop_factor),
debug_options->xla_gpu_memory_limit_slop_factor(),
"Slop factor for memory limits in XLA:GPU"));
"Slop factor for memory limits in XLA:GPU. This flag serves as a "
"multiplier "
"applied to the total available memory, creating a threshold that guides "
"the "
"Latency Hiding Scheduler (LHS) in balancing memory reduction and "
"latency "
"hiding optimizations. This factor effectively establishes a memory "
"limit "
"for compiler passes, determining when the scheduler should prioritize: "
" 1. Memory reduction: When memory usage approaches or exceeds the "
"calculated "
" threshold. "
" 2. Latency hiding: When memory usage is below the threshold, allowing "
"for "
" more aggressive optimizations that may temporarily increase memory "
"usage "
" but improve overall performance. "
"By adjusting this factor, users can fine-tune the trade-off between "
"memory "
"efficiency and performance optimizations. The default value is 95."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_highest_priority_async_stream",
bool_setter_for(
Expand Down
5 changes: 4 additions & 1 deletion xla/python/pjit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,10 @@ void PjitFunction::InitExecutables() {
}
}

PjitFunction::~PjitFunction() = default;
PjitFunction::~PjitFunction() {
nb::ft_object_guard lock(cache_);
executables_ = nullptr;
}

void CallShardArgFallback(
nb::handle arg, nb::handle sharding, nb::handle layout,
Expand Down
11 changes: 11 additions & 0 deletions xla/python/profiler/internal/python_hooks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ void AddEventToXLine(const PythonTraceEntry& event,
xevent.SetEndTimestampNs(event.end_time_ns);
}

#if PY_VERSION_HEX < 0x030C0000
template <typename ForEachThreadFunc>
void ForEachThread(PyThreadState* curr_thread, ForEachThreadFunc&& callback) {
// Note: PyThreadState's interp is not accessible in open source due to
Expand Down Expand Up @@ -118,6 +119,8 @@ void ForEachThread(PyThreadState* curr_thread, ForEachThreadFunc&& callback) {
#endif
}

#endif // PY_VERSION_HEX

} // namespace

/*static*/ PythonHookContext* PythonHooks::e2e_context_ = nullptr;
Expand Down Expand Up @@ -371,21 +374,29 @@ void PythonHookContext::ProfileFast(PyFrameObject* frame, int what,

// NOTE: This must be after `threading.setprofile` otherwise we
// end up recording that in our trace.
#if PY_VERSION_HEX < 0x030C0000
PyThreadState* curr_thread = PyThreadState_Get();
ForEachThread(curr_thread, [](PyThreadState* thread) {
VLOG(1) << "Setting profiler in " << thread->thread_id;
PyEval_SetProfile(&PythonHooks::ProfileFunction, nullptr);
});
PyThreadState_Swap(curr_thread);
#else // PY_VERSION_HEX >= 0x030C0000
PyEval_SetProfileAllThreads(&PythonHooks::ProfileFunction, nullptr);
#endif // PY_VERSION_HEX >= 0x030C0000
}

/*static*/ void PythonHookContext::ClearProfilerInAllThreads() {
#if PY_VERSION_HEX < 0x030C0000
PyThreadState* curr_thread = PyThreadState_Get();
ForEachThread(curr_thread, [](PyThreadState* thread) {
VLOG(1) << "Clearing profiler in " << thread->thread_id;
PyEval_SetProfile(nullptr, nullptr);
});
PyThreadState_Swap(curr_thread);
#else // PY_VERSION_HEX >= 0x030C0000
PyEval_SetProfileAllThreads(nullptr, nullptr);
#endif // PY_VERSION_HEX >= 0x030C0000

// And notify the threading library that we're done.
ThreadingSetProfile(py::none());
Expand Down

0 comments on commit 6e8204f

Please sign in to comment.