Skip to content

Commit

Permalink
fix: reference counting (retain/release) in PerChannelBookieClient (a…
Browse files Browse the repository at this point in the history
…pache#4293)

### Motivation

This addresses the remaining gaps of apache#4289 in handling ByteBuf retain/release.
This PR will also address the concern about NioBuffer lifecycle brought up in the review of the original PR review: apache#791 (comment) .

This PR fixes several problems:
* ByteString buffer lifecycle in client, follows ByteBufList lifecycle
* ByteBufList lifecycle, moved to write promise
* Calling of write promises in AuthHandler which buffers messages while authentication is in progress. It was ignoring the promises.

### Changes

- add 2 callback parameters to writeAndFlush: cleanupActionFailedBeforeWrite and cleanupActionAfterWrite
  - use these callback actions for proper cleanup
- extract a utility class ByteStringUtil for wrapping ByteBufList or ByteBuf as concatenated zero copy ByteString
- properly handle releasing of ByteBufList in the write promise
- properly handle calling promises that are buffered while authentication is in progress
  • Loading branch information
lhotari authored and Anup Ghatage committed Jul 12, 2024
1 parent 3229970 commit a28837b
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
super.write(ctx, msg, promise);
super.flush(ctx);
} else {
waitingForAuth.add(msg);
addMsgAndPromiseToQueue(msg, promise);
}
} else if (msg instanceof BookieProtocol.Request) {
// let auth messages through, queue the rest
Expand All @@ -364,16 +364,26 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
super.write(ctx, msg, promise);
super.flush(ctx);
} else {
waitingForAuth.add(msg);
addMsgAndPromiseToQueue(msg, promise);
}
} else if (msg instanceof ByteBuf || msg instanceof ByteBufList) {
waitingForAuth.add(msg);
addMsgAndPromiseToQueue(msg, promise);
} else {
LOG.info("[{}] dropping write of message {}", ctx.channel(), msg);
}
}
}

// Add the message and the associated promise to the queue.
// The promise is added to the same queue as the message without an additional wrapper object so
// that object allocations can be avoided. A similar solution is used in Netty codebase.
private void addMsgAndPromiseToQueue(Object msg, ChannelPromise promise) {
waitingForAuth.add(msg);
if (promise != null && !promise.isVoid()) {
waitingForAuth.add(promise);
}
}

long newTxnId() {
return transactionIdGenerator.incrementAndGet();
}
Expand Down Expand Up @@ -433,10 +443,19 @@ public void operationComplete(int rc, Void v) {
if (rc == BKException.Code.OK) {
synchronized (this) {
authenticated = true;
Object msg = waitingForAuth.poll();
while (msg != null) {
NettyChannelUtil.writeAndFlushWithVoidPromise(ctx, msg);
msg = waitingForAuth.poll();
while (true) {
Object msg = waitingForAuth.poll();
if (msg == null) {
break;
}
ChannelPromise promise;
// check if the message has an associated promise as the next element in the queue
if (waitingForAuth.peek() instanceof ChannelPromise) {
promise = (ChannelPromise) waitingForAuth.poll();
} else {
promise = ctx.voidPromise();
}
ctx.writeAndFlush(msg, promise);
}
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,18 +261,20 @@ public void writeLac(final BookieId addr, final long ledgerId, final byte[] mast

toSend.retain();
client.obtain((rc, pcbc) -> {
if (rc != BKException.Code.OK) {
try {
executor.executeOrdered(ledgerId,
() -> cb.writeLacComplete(rc, ledgerId, addr, ctx));
} catch (RejectedExecutionException re) {
cb.writeLacComplete(getRc(BKException.Code.InterruptedException), ledgerId, addr, ctx);
try {
if (rc != BKException.Code.OK) {
try {
executor.executeOrdered(ledgerId,
() -> cb.writeLacComplete(rc, ledgerId, addr, ctx));
} catch (RejectedExecutionException re) {
cb.writeLacComplete(getRc(BKException.Code.InterruptedException), ledgerId, addr, ctx);
}
} else {
pcbc.writeLac(ledgerId, masterKey, lac, toSend, cb, ctx);
}
} else {
pcbc.writeLac(ledgerId, masterKey, lac, toSend, cb, ctx);
} finally {
ReferenceCountUtil.release(toSend);
}

ReferenceCountUtil.release(toSend);
}, ledgerId, useV3Enforced);
}

Expand Down Expand Up @@ -407,14 +409,16 @@ static ChannelReadyForAddEntryCallback create(
@Override
public void operationComplete(final int rc,
PerChannelBookieClient pcbc) {
if (rc != BKException.Code.OK) {
bookieClient.completeAdd(rc, ledgerId, entryId, addr, cb, ctx);
} else {
pcbc.addEntry(ledgerId, masterKey, entryId,
toSend, cb, ctx, options, allowFastFail, writeFlags);
try {
if (rc != BKException.Code.OK) {
bookieClient.completeAdd(rc, ledgerId, entryId, addr, cb, ctx);
} else {
pcbc.addEntry(ledgerId, masterKey, entryId,
toSend, cb, ctx, options, allowFastFail, writeFlags);
}
} finally {
ReferenceCountUtil.release(toSend);
}

ReferenceCountUtil.release(toSend);
recycle();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.bookkeeper.proto;

import com.google.protobuf.ByteString;
import com.google.protobuf.UnsafeByteOperations;
import io.netty.buffer.ByteBuf;
import java.nio.ByteBuffer;
import org.apache.bookkeeper.util.ByteBufList;

public class ByteStringUtil {

/**
* Wrap the internal buffers of a ByteBufList into a single ByteString.
* The lifecycle of the wrapped ByteString is tied to the ByteBufList.
*
* @param bufList ByteBufList to wrap
* @return ByteString wrapping the internal buffers of the ByteBufList
*/
public static ByteString byteBufListToByteString(ByteBufList bufList) {
ByteString aggregated = null;
for (int i = 0; i < bufList.size(); i++) {
ByteBuf buffer = bufList.getBuffer(i);
if (buffer.readableBytes() > 0) {
aggregated = byteBufToByteString(aggregated, buffer);
}
}
return aggregated != null ? aggregated : ByteString.EMPTY;
}

/**
* Wrap the internal buffers of a ByteBuf into a single ByteString.
* The lifecycle of the wrapped ByteString is tied to the ByteBuf.
*
* @param byteBuf ByteBuf to wrap
* @return ByteString wrapping the internal buffers of the ByteBuf
*/
public static ByteString byteBufToByteString(ByteBuf byteBuf) {
return byteBufToByteString(null, byteBuf);
}

// internal method to aggregate a ByteBuf into a single aggregated ByteString
private static ByteString byteBufToByteString(ByteString aggregated, ByteBuf byteBuf) {
if (byteBuf.readableBytes() == 0) {
return ByteString.EMPTY;
}
if (byteBuf.nioBufferCount() > 1) {
for (ByteBuffer nioBuffer : byteBuf.nioBuffers()) {
ByteString piece = UnsafeByteOperations.unsafeWrap(nioBuffer);
aggregated = (aggregated == null) ? piece : aggregated.concat(piece);
}
} else {
ByteString piece;
if (byteBuf.hasArray()) {
piece = UnsafeByteOperations.unsafeWrap(byteBuf.array(), byteBuf.arrayOffset() + byteBuf.readerIndex(),
byteBuf.readableBytes());
} else {
piece = UnsafeByteOperations.unsafeWrap(byteBuf.nioBuffer());
}
aggregated = (aggregated == null) ? piece : aggregated.concat(piece);
}
return aggregated;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -700,14 +700,10 @@ void writeLac(final long ledgerId, final byte[] masterKey, final long lac, ByteB
.setVersion(ProtocolVersion.VERSION_THREE)
.setOperation(OperationType.WRITE_LAC)
.setTxnId(txnId);
ByteString body;
if (toSend.hasArray()) {
body = UnsafeByteOperations.unsafeWrap(toSend.array(), toSend.arrayOffset(), toSend.readableBytes());
} else if (toSend.size() == 1) {
body = UnsafeByteOperations.unsafeWrap(toSend.getBuffer(0).nioBuffer());
} else {
body = UnsafeByteOperations.unsafeWrap(toSend.toArray());
}
ByteString body = ByteStringUtil.byteBufListToByteString(toSend);
toSend.retain();
Runnable cleanupActionFailedBeforeWrite = toSend::release;
Runnable cleanupActionAfterWrite = cleanupActionFailedBeforeWrite;
WriteLacRequest.Builder writeLacBuilder = WriteLacRequest.newBuilder()
.setLedgerId(ledgerId)
.setLac(lac)
Expand All @@ -718,7 +714,8 @@ void writeLac(final long ledgerId, final byte[] masterKey, final long lac, ByteB
.setHeader(headerBuilder)
.setWriteLacRequest(writeLacBuilder)
.build();
writeAndFlush(channel, completionKey, writeLacRequest);
writeAndFlush(channel, completionKey, writeLacRequest, false, cleanupActionFailedBeforeWrite,
cleanupActionAfterWrite);
}

void forceLedger(final long ledgerId, ForceLedgerCallback cb, Object ctx) {
Expand Down Expand Up @@ -777,6 +774,8 @@ void addEntry(final long ledgerId, byte[] masterKey, final long entryId, Referen
Object ctx, final int options, boolean allowFastFail, final EnumSet<WriteFlag> writeFlags) {
Object request = null;
CompletionKey completionKey = null;
Runnable cleanupActionFailedBeforeWrite = null;
Runnable cleanupActionAfterWrite = null;
if (useV2WireProtocol) {
if (writeFlags.contains(WriteFlag.DEFERRED_SYNC)) {
LOG.error("invalid writeflags {} for v2 protocol", writeFlags);
Expand All @@ -786,9 +785,14 @@ void addEntry(final long ledgerId, byte[] masterKey, final long entryId, Referen
completionKey = acquireV2Key(ledgerId, entryId, OperationType.ADD_ENTRY);

if (toSend instanceof ByteBuf) {
request = ((ByteBuf) toSend).retainedDuplicate();
ByteBuf byteBuf = ((ByteBuf) toSend).retainedDuplicate();
request = byteBuf;
cleanupActionFailedBeforeWrite = byteBuf::release;
} else {
request = ByteBufList.clone((ByteBufList) toSend);
ByteBufList byteBufList = (ByteBufList) toSend;
byteBufList.retain();
request = byteBufList;
cleanupActionFailedBeforeWrite = byteBufList::release;
}
} else {
final long txnId = getTxnId();
Expand All @@ -803,19 +807,11 @@ void addEntry(final long ledgerId, byte[] masterKey, final long entryId, Referen
headerBuilder.setPriority(DEFAULT_HIGH_PRIORITY_VALUE);
}

ByteString body = null;
ByteBufList bufToSend = (ByteBufList) toSend;

if (bufToSend.hasArray()) {
body = UnsafeByteOperations.unsafeWrap(bufToSend.array(), bufToSend.arrayOffset(),
bufToSend.readableBytes());
} else {
for (int i = 0; i < bufToSend.size(); i++) {
ByteString piece = UnsafeByteOperations.unsafeWrap(bufToSend.getBuffer(i).nioBuffer());
// use ByteString.concat to avoid byte[] allocation when toSend has multiple ByteBufs
body = (body == null) ? piece : body.concat(piece);
}
}
ByteString body = ByteStringUtil.byteBufListToByteString(bufToSend);
bufToSend.retain();
cleanupActionFailedBeforeWrite = bufToSend::release;
cleanupActionAfterWrite = cleanupActionFailedBeforeWrite;
AddRequest.Builder addBuilder = AddRequest.newBuilder()
.setLedgerId(ledgerId)
.setEntryId(entryId)
Expand All @@ -840,17 +836,9 @@ void addEntry(final long ledgerId, byte[] masterKey, final long entryId, Referen
putCompletionKeyValue(completionKey,
acquireAddCompletion(completionKey,
cb, ctx, ledgerId, entryId));
final Channel c = channel;
if (c == null) {
// Manually release the binary data(variable "request") that we manually created when it can not be sent out
// because the channel is switching.
errorOut(completionKey);
ReferenceCountUtil.release(request);
return;
} else {
// addEntry times out on backpressure
writeAndFlush(c, completionKey, request, allowFastFail);
}
// addEntry times out on backpressure
writeAndFlush(channel, completionKey, request, allowFastFail, cleanupActionFailedBeforeWrite,
cleanupActionAfterWrite);
}

public void readLac(final long ledgerId, ReadLacCallback cb, Object ctx) {
Expand Down Expand Up @@ -1005,7 +993,7 @@ private void readEntryInternal(final long ledgerId,
ReadCompletion readCompletion = new ReadCompletion(completionKey, cb, ctx, ledgerId, entryId);
putCompletionKeyValue(completionKey, readCompletion);

writeAndFlush(channel, completionKey, request, allowFastFail);
writeAndFlush(channel, completionKey, request, allowFastFail, null, null);
}

public void batchReadEntries(final long ledgerId,
Expand Down Expand Up @@ -1048,7 +1036,7 @@ private void batchReadEntriesInternal(final long ledgerId,
completionKey, cb, ctx, ledgerId, startEntryId);
putCompletionKeyValue(completionKey, readCompletion);

writeAndFlush(channel, completionKey, request, allowFastFail);
writeAndFlush(channel, completionKey, request, allowFastFail, null, null);
}

public void getBookieInfo(final long requested, GetBookieInfoCallback cb, Object ctx) {
Expand Down Expand Up @@ -1170,17 +1158,20 @@ public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exceptio
private void writeAndFlush(final Channel channel,
final CompletionKey key,
final Object request) {
writeAndFlush(channel, key, request, false);
writeAndFlush(channel, key, request, false, null, null);
}

private void writeAndFlush(final Channel channel,
final CompletionKey key,
final Object request,
final boolean allowFastFail) {
final boolean allowFastFail, final Runnable cleanupActionFailedBeforeWrite,
final Runnable cleanupActionAfterWrite) {
if (channel == null) {
LOG.warn("Operation {} failed: channel == null", StringUtils.requestToString(request));
errorOut(key);
ReferenceCountUtil.release(request);
if (cleanupActionFailedBeforeWrite != null) {
cleanupActionFailedBeforeWrite.run();
}
return;
}

Expand All @@ -1195,31 +1186,39 @@ private void writeAndFlush(final Channel channel,
StringUtils.requestToString(request));

errorOut(key, BKException.Code.TooManyRequestsException);
ReferenceCountUtil.release(request);
if (cleanupActionFailedBeforeWrite != null) {
cleanupActionFailedBeforeWrite.run();
}
return;
}

try {
final long startTime = MathUtils.nowInNano();

ChannelPromise promise = channel.newPromise().addListener(future -> {
if (future.isSuccess()) {
nettyOpLogger.registerSuccessfulEvent(MathUtils.elapsedNanos(startTime), TimeUnit.NANOSECONDS);
CompletionValue completion = completionObjects.get(key);
if (completion != null) {
completion.setOutstanding();
try {
if (future.isSuccess()) {
nettyOpLogger.registerSuccessfulEvent(MathUtils.elapsedNanos(startTime), TimeUnit.NANOSECONDS);
CompletionValue completion = completionObjects.get(key);
if (completion != null) {
completion.setOutstanding();
}
} else {
nettyOpLogger.registerFailedEvent(MathUtils.elapsedNanos(startTime), TimeUnit.NANOSECONDS);
}
} finally {
if (cleanupActionAfterWrite != null) {
cleanupActionAfterWrite.run();
}
} else {
nettyOpLogger.registerFailedEvent(MathUtils.elapsedNanos(startTime), TimeUnit.NANOSECONDS);
}
});
channel.writeAndFlush(request, promise);
} catch (Throwable e) {
LOG.warn("Operation {} failed", StringUtils.requestToString(request), e);
errorOut(key);
// If the request goes into the writeAndFlush, it should be handled well by Netty. So all the exceptions we
// get here, we can release the request.
ReferenceCountUtil.release(request);
if (cleanupActionFailedBeforeWrite != null) {
cleanupActionFailedBeforeWrite.run();
}
}
}

Expand Down
Loading

0 comments on commit a28837b

Please sign in to comment.