Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize validity buffer concat. #2626

Open
wants to merge 1 commit into
base: branch-25.02
Choose a base branch
from

Conversation

liurenjie1024
Copy link
Collaborator

Close #2579

@liurenjie1024 liurenjie1024 requested a review from jlowe November 26, 2024 06:24
@liurenjie1024
Copy link
Collaborator Author

cc @jlowe I did some benchmark and didn't notice much performance improvement.

@liurenjie1024
Copy link
Collaborator Author

build

// Extract appendCount bits from srcByte, starting from curSrcBitIdx
byte mask = (byte) (((1 << appendCount) - 1) & 0xFF);
srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask);
int totalRowCount = toIntExact(sliceInfo.getRowCount() + sliceInfo.getValidityBufferInfo().getBeginBit());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name of this variable makes reading this confusing. It's not a total row count as the name implies. It's the end index or ending row, IIUC.

// Sets the bits in destination buffer starting from curDestBitIdx to 0
byte destByte = dest.getByte(curDestByteIdx);
destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1));
if (dest.getLength() >= (curDestOffset + Integer.BYTES)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conditionals should not be in the body of the while loop. The while loop should be as simple as possible, since that's expected to be the hotspot. IMO the code should be structured into three parts similar to the following:

if (curSrcIdx % 8 != 0) {
  // read an int from the buffer
  // mask off the unused bits
  // count the bits
  // shift and store the bits
}
while (whole_ints_left_in_buffer) {
  // read int from buffer
  // count bits
  // shift and store the bits
}
if (leftover bits) {
  // read an int from the buffer (leverage padded buffer here)
  // mask off the unused bits
  // count the bits
  // shift and store the bits
}

byte destByte = dest.getByte(curDestByteIdx);
destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1));
if (dest.getLength() >= (curDestOffset + Integer.BYTES)) {
// We have enough room to get an int
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we always have enough space to get an integer from the destination buffer because it's at least 4-byte padded?

Comment on lines +181 to +183
byte[] destBytes = new byte[4];
dest.getBytes(destBytes, 0, curDestOffset, destBufRemBytes);
int destInt = ByteBuffer.wrap(destBytes).order(ByteOrder.LITTLE_ENDIAN).getInt();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why we're doing byte-at-time here and endian stuff when we don't up above? This is still grabbing 4 bytes like the above code. Byte-at-a-time is usually a lot slower, might be faster to read the int from the buffer (we're loading 4 bytes anyway) and call Integer.reverseBytes (a HotSpot intrinsic candidate) if ByteOrder.nativeOrder == ByteOrder.LITTLE_ENDIAN.

Comment on lines +187 to +188
ByteBuffer.wrap(destBytes).order(ByteOrder.LITTLE_ENDIAN).putInt(destInt);
dest.setBytes(curDestOffset, destBytes, 0, destBufRemBytes);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above, can call setInt here and leverage ByteOrder to determine when we need to swap bytes or not.

@liurenjie1024 liurenjie1024 changed the base branch from branch-24.12 to branch-25.02 November 27, 2024 06:09
@ttnghia
Copy link
Collaborator

ttnghia commented Dec 6, 2024

Please update the PR description. It will be displayed in the commit log. Simply saying "closes XXX" will just show that words, which is difficult to track the changes through commit log.

@gerashegalov
Copy link
Collaborator

It will be displayed in the commit log

To be fair we do not have the automation (yet) of checking in PRs using the PR description. I follow the convention of copying the PR description as the commit message but it's not mandated and not followed widely in NVIDIA/spark* repos.

@res-life
Copy link
Collaborator

Could we copy by long instead of int? We can avoid to use toIntExact and use Long.bitCount. This may be more fast.

// Sets the bits in destination buffer starting from curDestBitIdx to 0
byte destByte = dest.getByte(curDestByteIdx);
destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1) & 0xFF);
while (curSrcIdx < totalRowCount) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So does this concatenate the validity buffers, one word (32 bits) at a time, in a serial manner, in Java?

Can't we use C++/CUDA for accelerating this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a separate effort by @nvdbaranec to provide GPU kernels for Kudo buffer concatenation. We need both CPU and GPU implementations. In some cases, the GPU will be the bottleneck for a stage, and CPUs will be idle while waiting for the GPU semaphore. It will be more efficient to concatenate on the CPU, even if it's slower than the GPU, when the CPU can complete the operation within the time otherwise spent waiting for the GPU.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've never seen any parallel Java code in our plugin/JNI. Assuming that it is doable, would we consider doing so for the operations like this?

// Update destination byte with the bits from source byte
destByte = (byte) ((destByte | (srcByte << curDestBitIdx)) & 0xFF);
dest.setByte(curDestByteIdx, destByte);
int curDestOffset = (curDestIdx / 32) * Integer.BYTES;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider using mnemonic constants

Suggested change
int curDestOffset = (curDestIdx / 32) * Integer.BYTES;
int curDestOffset = (curDestIdx / Integer.SIZE) * Integer.BYTES;

or even

Suggested change
int curDestOffset = (curDestIdx / 32) * Integer.BYTES;
int curDestOffset = curDestIdx / Byte.SIZE;

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[FEA] Optimize kudo when merging validity buffer.
5 participants