Update resource adaptor for rmm
Signed-off-by: Paul Mattione <[email protected]>
pmattione-nvidia committed May 22, 2024
1 parent 79253a9 commit 634b051
Showing 1 changed file with 50 additions and 11 deletions.
src/main/cpp/src/SparkResourceAdaptorJni.cpp: 50 additions, 11 deletions
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include <rmm/mr/device/device_memory_resource.hpp>
+#include <rmm/resource_ref.hpp>
 
 #include <cudf_jni_apis.hpp>
 #include <pthread.h>
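Context for the new header: rmm::device_async_resource_ref is a lightweight, copyable, non-owning handle to any resource satisfying the cuda::mr::async_resource concept. A minimal sketch of obtaining and using one, not taken from this commit (the cuda_memory_resource, the 256-byte size, and the 256-byte alignment are illustrative):

#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

int main()
{
  rmm::mr::cuda_memory_resource cuda_mr;        // concrete, owning resource
  rmm::device_async_resource_ref ref{cuda_mr};  // non-owning view; cuda_mr must outlive it

  // The ref forwards to the wrapped resource's allocate_async/deallocate_async.
  void* p = ref.allocate_async(256, 256, rmm::cuda_stream_default);
  ref.deallocate_async(p, 256, 256, rmm::cuda_stream_default);
  return 0;
}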
@@ -384,10 +384,10 @@ class full_thread_state {
  * mitigation we might want to do to avoid killing a task with an out of
  * memory error.
  */
-class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
+class spark_resource_adaptor final {
  public:
   spark_resource_adaptor(JNIEnv* env,
-                         rmm::mr::device_memory_resource* mr,
+                         rmm::device_async_resource_ref mr,
                          std::shared_ptr<spdlog::logger>& logger,
                          bool const is_log_enabled)
     : resource{mr}, logger{logger}, is_log_enabled{is_log_enabled}
@@ -399,7 +399,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
     logger->set_pattern("%H:%M:%S.%f,%v");
   }
 
-  rmm::mr::device_memory_resource* get_wrapped_resource() { return resource; }
+  rmm::device_async_resource_ref get_wrapped_resource() { return resource; }
 
   /**
    * Update the internal state so that a specific thread is dedicated to a task.
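With the base class gone, nothing is inherited: the adaptor instead satisfies the cuda::mr::resource and cuda::mr::async_resource concepts structurally, through the allocate/deallocate/allocate_async/deallocate_async members and equality operators added further down in this diff. A stripped-down sketch of that pattern; the forwarding_resource name and its pass-through behavior are illustrative, and on older CCCL releases <cuda/memory_resource> may be gated behind LIBCUDACXX_ENABLE_EXPERIMENTAL_MEMORY_RESOURCE:

#include <cuda/memory_resource>
#include <cuda/stream_ref>

#include <rmm/resource_ref.hpp>

#include <cstddef>

class forwarding_resource {
 public:
  explicit forwarding_resource(rmm::device_async_resource_ref upstream) : upstream_{upstream} {}

  // Sync interface, required by cuda::mr::resource.
  void* allocate(std::size_t bytes, std::size_t alignment) { return upstream_.allocate(bytes, alignment); }
  void deallocate(void* p, std::size_t bytes, std::size_t alignment)
  {
    upstream_.deallocate(p, bytes, alignment);
  }

  // Async interface, additionally required by cuda::mr::async_resource.
  void* allocate_async(std::size_t bytes, std::size_t alignment, cuda::stream_ref stream)
  {
    return upstream_.allocate_async(bytes, alignment, stream);
  }
  void deallocate_async(void* p, std::size_t bytes, std::size_t alignment, cuda::stream_ref stream)
  {
    upstream_.deallocate_async(p, bytes, alignment, stream);
  }

  // Both concepts also require equality comparison.
  friend bool operator==(forwarding_resource const& lhs, forwarding_resource const& rhs)
  {
    return lhs.upstream_ == rhs.upstream_;
  }
  friend bool operator!=(forwarding_resource const& lhs, forwarding_resource const& rhs)
  {
    return !(lhs == rhs);
  }

 private:
  rmm::device_async_resource_ref upstream_;
};

static_assert(cuda::mr::async_resource<forwarding_resource>);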
@@ -870,7 +870,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
   }
 
  private:
-  rmm::mr::device_memory_resource* const resource;
+  rmm::device_async_resource_ref resource;
   std::shared_ptr<spdlog::logger> logger;  ///< spdlog logger object
   bool const is_log_enabled;
 
@@ -1728,13 +1728,46 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
     return ret;
   }
 
-  void* do_allocate(std::size_t const num_bytes, rmm::cuda_stream_view stream) override
+  /**
+   * Sync allocation method required to satisfy the cuda::mr::resource concept.
+   * Synchronous memory allocations are not supported.
+   */
+  void* allocate(std::size_t, std::size_t) { return nullptr; }
+
+  /**
+   * Sync deallocation method required to satisfy the cuda::mr::resource concept.
+   * Synchronous memory deallocations are not supported.
+   */
+  void deallocate(void*, std::size_t, std::size_t) {}
+
+  /**
+   * Equality comparison method required to satisfy the cuda::mr::resource concept.
+   */
+  friend bool operator==(const spark_resource_adaptor& lhs, const spark_resource_adaptor& rhs)
+  {
+    return (lhs.resource == rhs.resource) && (lhs.jvm == rhs.jvm);
+  }
+
+  /**
+   * Inequality comparison method required to satisfy the cuda::mr::resource concept.
+   */
+  friend bool operator!=(const spark_resource_adaptor& lhs, const spark_resource_adaptor& rhs)
+  {
+    return !(lhs == rhs);
+  }
+
+  /**
+   * Async allocation method required to satisfy the cuda::mr::async_resource concept.
+   */
+  void* allocate_async(std::size_t const num_bytes,
+                       std::size_t const alignment,
+                       rmm::cuda_stream_view stream)
   {
     auto const tid = static_cast<long>(pthread_self());
     while (true) {
       bool const likely_spill = pre_alloc(tid);
       try {
-        void* ret = resource->allocate(num_bytes, stream);
+        void* ret = resource.allocate_async(num_bytes, alignment, stream);
         post_alloc_success(tid, likely_spill);
         return ret;
       } catch (rmm::out_of_memory const& e) {
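The allocation path keeps the existing control flow: pre_alloc/post_alloc_success do the Spark thread bookkeeping, and an rmm::out_of_memory from upstream sends the loop around again so the task can block, spill, or fail. A reduced sketch of that retry shape with the bookkeeping swapped for a simple attempt budget (allocate_with_retry and max_attempts are illustrative names, not from this file):

#include <rmm/cuda_stream_view.hpp>
#include <rmm/error.hpp>
#include <rmm/resource_ref.hpp>

#include <cstddef>

void* allocate_with_retry(rmm::device_async_resource_ref upstream,
                          std::size_t bytes,
                          std::size_t alignment,
                          rmm::cuda_stream_view stream,
                          int max_attempts = 3)
{
  for (int attempt = 1; attempt <= max_attempts; ++attempt) {
    try {
      return upstream.allocate_async(bytes, alignment, stream);
    } catch (rmm::out_of_memory const&) {
      // The real adaptor decides here whether to block this thread, spill,
      // or kill the task; this sketch just retries until the budget is spent.
      if (attempt == max_attempts) { throw; }
    }
  }
  return nullptr;  // unreachable
}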
@@ -1787,9 +1820,15 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
     wake_next_highest_priority_blocked(lock, true, is_for_cpu);
   }
 
-  void do_deallocate(void* p, std::size_t size, rmm::cuda_stream_view stream) override
+  /**
+   * Async deallocation method required to satisfy the cuda::mr::async_resource concept.
+   */
+  void deallocate_async(void* p,
+                        std::size_t size,
+                        std::size_t const alignment,
+                        rmm::cuda_stream_view stream)
   {
-    resource->deallocate(p, size, stream);
+    resource.deallocate_async(p, size, alignment, stream);
     // deallocate success
     if (size > 0) {
       std::unique_lock<std::mutex> lock(state_mutex);
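After the free is forwarded, the adaptor updates shared scheduling state under state_mutex so a blocked thread can be woken (wake_next_highest_priority_blocked in the surrounding code). A sketch of that free-then-notify shape; the dealloc_bookkeeping class and its members are illustrative:

#include <condition_variable>
#include <cstddef>
#include <mutex>

class dealloc_bookkeeping {
 public:
  // Called after the underlying deallocate_async succeeds.
  void on_deallocated(std::size_t size)
  {
    if (size > 0) {  // zero-byte frees cannot unblock anyone
      std::unique_lock<std::mutex> lock(state_mutex_);
      freed_bytes_ += size;
      // Stand-in for wake_next_highest_priority_blocked(...) in the adaptor.
      memory_freed_.notify_all();
    }
  }

 private:
  std::mutex state_mutex_;
  std::condition_variable memory_freed_;
  std::size_t freed_bytes_{0};
};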
@@ -1818,7 +1857,7 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_cr
   JNI_NULL_CHECK(env, child, "child is null", 0);
   try {
     cudf::jni::auto_set_device(env);
-    auto wrapped = reinterpret_cast<rmm::mr::device_memory_resource*>(child);
+    auto wrapped = reinterpret_cast<rmm::device_async_resource_ref*>(child);
     cudf::jni::native_jstring nlogloc(env, log_loc);
     std::shared_ptr<spdlog::logger> logger;
     bool is_log_enabled;
@@ -1837,7 +1876,7 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_SparkResourceAdaptor_cr
       }
     }
 
-    auto ret = new spark_resource_adaptor(env, wrapped, logger, is_log_enabled);
+    auto ret = new spark_resource_adaptor(env, *wrapped, logger, is_log_enabled);
     return cudf::jni::ptr_as_jlong(ret);
   }
   CATCH_STD(env, 0)
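Note the knock-on effect in this hunk: the jlong from Java is now cast to a pointer to a device_async_resource_ref rather than to the resource itself, so the address handed across JNI must reference a resource_ref object that outlives the adaptor. A minimal sketch of such a handle round trip; the Java_com_example_Handles_* entry points are hypothetical, not part of this repository:

#include <jni.h>

#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

// Heap-allocate a ref to a long-lived resource and hand its address to Java.
extern "C" JNIEXPORT jlong JNICALL Java_com_example_Handles_createResourceRef(JNIEnv*, jclass)
{
  static rmm::mr::cuda_memory_resource cuda_mr;  // must outlive every ref to it
  auto* ref = new rmm::device_async_resource_ref{cuda_mr};
  return reinterpret_cast<jlong>(ref);
}

// Counterpart that turns the jlong back into the ref, mirroring the
// reinterpret_cast in the hunk above, and releases it.
extern "C" JNIEXPORT void JNICALL
Java_com_example_Handles_destroyResourceRef(JNIEnv*, jclass, jlong handle)
{
  delete reinterpret_cast<rmm::device_async_resource_ref*>(handle);
}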