Skip to content

Commit

Permalink
Optimize date/time truncation for scalar format (#2687)
Browse files Browse the repository at this point in the history
Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia authored Dec 13, 2024
1 parent daf34f1 commit 8805168
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 90 deletions.
30 changes: 22 additions & 8 deletions src/main/cpp/src/DateTimeUtilsJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,34 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_DateTimeUtils_rebaseJul
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_DateTimeUtils_truncate(JNIEnv* env,
jclass,
jlong datetime,
jlong format)
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_DateTimeUtils_truncateWithColumnFormat(
JNIEnv* env, jclass, jlong datetime, jlong format)
{
JNI_NULL_CHECK(env, datetime, "datetime column is null", 0);
JNI_NULL_CHECK(env, format, "format column is null", 0);
JNI_NULL_CHECK(env, datetime, "input datetime is null", 0);
JNI_NULL_CHECK(env, format, "input format is null", 0);

try {
cudf::jni::auto_set_device(env);

auto const datetime_cv = reinterpret_cast<cudf::column_view const*>(datetime);
auto const format_cv = reinterpret_cast<cudf::column_view const*>(format);
auto output = spark_rapids_jni::truncate(*datetime_cv, *format_cv);
return reinterpret_cast<jlong>(output.release());
return reinterpret_cast<jlong>(spark_rapids_jni::truncate(*datetime_cv, *format_cv).release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_DateTimeUtils_truncateWithScalarFormat(
JNIEnv* env, jclass, jlong datetime, jstring format)
{
JNI_NULL_CHECK(env, datetime, "input datetime is null", 0);

try {
cudf::jni::auto_set_device(env);

auto const datetime_cv = reinterpret_cast<cudf::column_view const*>(datetime);
auto const format_jstr = cudf::jni::native_jstring(env, format);
auto const format = std::string(format_jstr.get(), format_jstr.size_bytes());
return reinterpret_cast<jlong>(spark_rapids_jni::truncate(*datetime_cv, format).release());
}
CATCH_STD(env, 0);
}
Expand Down
Loading

0 comments on commit 8805168

Please sign in to comment.