Skip to content

Commit

Permalink
[Feature] Firestorm supports quorum write/read (#96)
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Refactor the client to support quorum write/read.

### Why are the changes needed?
 Without this patch, Firestorm cannot tolerate server failures before the shuffle is accomplished.

### Does this PR introduce _any_ user-facing change?
Yes.
The user is allowed to configure quorum with:
spark.rss.data.replica=N
spark.rss.data.replica.write=W
spark.rss.data.replica.read=R

The default config is 1, 1, 1

### How was this patch tested?
All previous UTs and new UTs.

Co-authored-by: frankliee <[email protected]>
  • Loading branch information
frankliee and frankliee authored Mar 25, 2022
1 parent 25bf903 commit 17c9172
Show file tree
Hide file tree
Showing 20 changed files with 864 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ public class RssClientConfig {
public static int RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE_DEFAULT_VALUE = 60;
public static String RSS_DATA_REPLICA = "spark.rss.data.replica";
public static int RSS_DATA_REPLICA_DEFAULT_VALUE = 1;
public static String RSS_DATA_REPLICA_WRITE = "spark.rss.data.replica.write";
public static int RSS_DATA_REPLICA_WRITE_DEFAULT_VALUE = 1;
public static String RSS_DATA_REPLICA_READ = "spark.rss.data.replica.read";
public static int RSS_DATA_REPLICA_READ_DEFAULT_VALUE = 1;
public static String RSS_OZONE_DFS_NAMENODE_ODFS_ENABLE = "spark.rss.ozone.dfs.namenode.odfs.enable";
public static boolean RSS_OZONE_DFS_NAMENODE_ODFS_ENABLE_DEFAULT_VALUE = false;
public static String RSS_OZONE_FS_HDFS_IMPL = "spark.rss.ozone.fs.hdfs.impl";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import com.tencent.rss.common.ShuffleBlockInfo;
import com.tencent.rss.common.ShuffleServerInfo;
import com.tencent.rss.common.util.Constants;
import com.tencent.rss.common.util.RssUtils;

public class RssShuffleManager implements ShuffleManager {

Expand All @@ -74,6 +75,9 @@ public class RssShuffleManager implements ShuffleManager {
private Map<String, Set<Long>> taskToSuccessBlockIds = Maps.newConcurrentMap();
private Map<String, Set<Long>> taskToFailedBlockIds = Maps.newConcurrentMap();
private Map<String, WriteBufferManager> taskToBufferManager = Maps.newConcurrentMap();
private final int dataReplica;
private final int dataReplicaWrite;
private final int dataReplicaRead;
private boolean heartbeatStarted = false;
private ThreadPoolExecutor threadPoolExecutor;
private EventLoop eventLoop = new EventLoop<AddBlockEvent>("ShuffleDataQueue") {
Expand Down Expand Up @@ -123,6 +127,18 @@ public void onStart() {

public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
this.sparkConf = sparkConf;

// set & check replica config
this.dataReplica = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA,
RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE);
this.dataReplicaWrite = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA_WRITE,
RssClientConfig.RSS_DATA_REPLICA_WRITE_DEFAULT_VALUE);
this.dataReplicaRead = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA_READ,
RssClientConfig.RSS_DATA_REPLICA_READ_DEFAULT_VALUE);
LOG.info("Check quorum config ["
+ dataReplica + ":" + dataReplicaWrite + ":" + dataReplicaRead + "]");
RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead);

this.clientType = sparkConf.get(RssClientConfig.RSS_CLIENT_TYPE,
RssClientConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE);
this.heartbeatInterval = sparkConf.getLong(RssClientConfig.RSS_HEARTBEAT_INTERVAL,
Expand All @@ -136,7 +152,8 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
RssClientConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM_DEFAULT_VALUE);
shuffleWriteClient = ShuffleClientFactory
.getInstance()
.createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum);
.createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum,
dataReplica, dataReplicaWrite, dataReplicaRead);
registerCoordinator();
// fetch client conf and apply them if necessary and disable ESS
if (isDriver && sparkConf.getBoolean(
Expand Down Expand Up @@ -184,12 +201,11 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, int numMaps, Shuff

int partitionNumPerRange = sparkConf.getInt(RssClientConfig.RSS_PARTITION_NUM_PER_RANGE,
RssClientConfig.RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE);
int dataReplica = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA,
RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE);

// get all register info according to coordinator's response
ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
appId, shuffleId, dependency.partitioner().numPartitions(),
partitionNumPerRange, dataReplica, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION));
partitionNumPerRange, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION));
Map<Integer, List<ShuffleServerInfo>> partitionToServers = response.getPartitionToServers();

startHeartbeat();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import com.tencent.rss.common.ShuffleBlockInfo;
import com.tencent.rss.common.ShuffleServerInfo;
import com.tencent.rss.common.util.Constants;
import com.tencent.rss.common.util.RssUtils;

public class RssShuffleManager implements ShuffleManager {

Expand All @@ -73,7 +74,9 @@ public class RssShuffleManager implements ShuffleManager {
private final ThreadPoolExecutor threadPoolExecutor;
private AtomicReference<String> id = new AtomicReference<>();
private SparkConf sparkConf;
private int dataReplica;
private final int dataReplica;
private final int dataReplicaWrite;
private final int dataReplicaRead;
private ShuffleWriteClient shuffleWriteClient;
private final Map<String, Set<Long>> taskToSuccessBlockIds;
private final Map<String, Set<Long>> taskToFailedBlockIds;
Expand Down Expand Up @@ -125,22 +128,34 @@ private synchronized void putBlockId(

public RssShuffleManager(SparkConf conf, boolean isDriver) {
this.sparkConf = conf;

// set & check replica config
this.dataReplica = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA,
RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE);
this.dataReplicaWrite = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA_WRITE,
RssClientConfig.RSS_DATA_REPLICA_WRITE_DEFAULT_VALUE);
this.dataReplicaRead = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA_READ,
RssClientConfig.RSS_DATA_REPLICA_READ_DEFAULT_VALUE);
LOG.info("Check quorum config ["
+ dataReplica + ":" + dataReplicaWrite + ":" + dataReplicaRead + "]");
RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead);

this.heartbeatInterval = sparkConf.getLong(RssClientConfig.RSS_HEARTBEAT_INTERVAL,
RssClientConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE);
this.heartbeatTimeout = sparkConf.getLong(RssClientConfig.RSS_HEARTBEAT_TIMEOUT, heartbeatInterval / 2);
int retryMax = sparkConf.getInt(RssClientConfig.RSS_CLIENT_RETRY_MAX,
RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE);
this.clientType = sparkConf.get(RssClientConfig.RSS_CLIENT_TYPE,
RssClientConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE);
dataReplica = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA,
RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE);

long retryIntervalMax = sparkConf.getLong(RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX,
RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE);
int heartBeatThreadNum = sparkConf.getInt(RssClientConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM,
RssClientConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM_DEFAULT_VALUE);
shuffleWriteClient = ShuffleClientFactory
.getInstance()
.createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum);
.createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum,
dataReplica, dataReplicaWrite, dataReplicaRead);
registerCoordinator();
// fetch client conf and apply them if necessary and disable ESS
if (isDriver && sparkConf.getBoolean(
Expand Down Expand Up @@ -184,17 +199,26 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) {
this.heartbeatInterval = sparkConf.getLong(RssClientConfig.RSS_HEARTBEAT_INTERVAL,
RssClientConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE);
this.heartbeatTimeout = sparkConf.getLong(RssClientConfig.RSS_HEARTBEAT_TIMEOUT, heartbeatInterval / 2);
this.dataReplica = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA,
RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE);
this.dataReplicaWrite = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA_WRITE,
RssClientConfig.RSS_DATA_REPLICA_WRITE_DEFAULT_VALUE);
this.dataReplicaRead = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA_READ,
RssClientConfig.RSS_DATA_REPLICA_READ_DEFAULT_VALUE);
LOG.info("Check quorum config ["
+ dataReplica + ":" + dataReplicaWrite + ":" + dataReplicaRead + "]");
RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead);

int retryMax = sparkConf.getInt(RssClientConfig.RSS_CLIENT_RETRY_MAX,
RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE);
RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE);
long retryIntervalMax = sparkConf.getLong(RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX,
RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE);
RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE);
int heartBeatThreadNum = sparkConf.getInt(RssClientConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM,
RssClientConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM_DEFAULT_VALUE);
dataReplica = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA,
RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE);
shuffleWriteClient = ShuffleClientFactory
RssClientConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM_DEFAULT_VALUE);
shuffleWriteClient = ShuffleClientFactory
.getInstance()
.createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum);
.createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum,
dataReplica, dataReplicaWrite, dataReplicaRead);
this.taskToSuccessBlockIds = taskToSuccessBlockIds;
this.taskToFailedBlockIds = taskToFailedBlockIds;
if (loop != null) {
Expand Down Expand Up @@ -225,7 +249,6 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<
shuffleId,
dependency.partitioner().numPartitions(),
1,
dataReplica,
Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION));
Map<Integer, List<ShuffleServerInfo>> partitionToServers = response.getPartitionToServers();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void reportShuffleResult(
int bitmapNum);

ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId, int partitionNum,
int partitionNumPerRange, int dataReplica, Set<String> requiredTags);
int partitionNumPerRange, Set<String> requiredTags);

Roaring64NavigableMap getShuffleResult(String clientType, Set<ShuffleServerInfo> shuffleServerInfoSet,
String appId, int shuffleId, int partitionId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ public static ShuffleClientFactory getInstance() {
}

public ShuffleWriteClient createShuffleWriteClient(
String clientType, int retryMax, long retryIntervalMax, int heartBeatThreadNum) {
return new ShuffleWriteClientImpl(clientType, retryMax, retryIntervalMax, heartBeatThreadNum);
String clientType, int retryMax, long retryIntervalMax, int heartBeatThreadNum,
int replica, int replicaWrite, int replicaRead) {
return new ShuffleWriteClientImpl(clientType, retryMax, retryIntervalMax, heartBeatThreadNum,
replica, replicaWrite, replicaRead);
}

public ShuffleReadClient createShuffleReadClient(CreateShuffleReadClientRequest request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,21 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient {
private Set<ShuffleServerInfo> shuffleServerInfoSet = Sets.newConcurrentHashSet();
private CoordinatorClientFactory coordinatorClientFactory;
private ExecutorService heartBeatExecutorService;
private int replica;
private int replicaWrite;
private int replicaRead;

public ShuffleWriteClientImpl(String clientType, int retryMax, long retryIntervalMax, int heartBeatThreadNum) {
public ShuffleWriteClientImpl(String clientType, int retryMax, long retryIntervalMax, int heartBeatThreadNum,
int replica, int replicaWrite, int replicaRead) {
this.clientType = clientType;
this.retryMax = retryMax;
this.retryIntervalMax = retryIntervalMax;
coordinatorClientFactory = new CoordinatorClientFactory(clientType);
heartBeatExecutorService = Executors.newFixedThreadPool(heartBeatThreadNum,
new ThreadFactoryBuilder().setDaemon(true).setNameFormat("client-heartbeat-%d").build());
this.replica = replica;
this.replicaWrite = replicaWrite;
this.replicaRead = replicaRead;
}

private void sendShuffleDataAsync(
Expand All @@ -95,30 +102,50 @@ private void sendShuffleDataAsync(
Map<ShuffleServerInfo, List<Long>> serverToBlockIds,
Set<Long> successBlockIds,
Set<Long> tempFailedBlockIds) {

// maintain the count of blocks that have been sent to the server
Map<Long, AtomicInteger> blockIdsTracker = Maps.newConcurrentMap();
serverToBlockIds.values().forEach(
blockList -> blockList.forEach(block -> blockIdsTracker.put(block, new AtomicInteger(0)))
);

if (serverToBlocks != null) {
serverToBlocks.entrySet().parallelStream().forEach(entry -> {
ShuffleServerInfo ssi = entry.getKey();
try {
Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks = entry.getValue();
// todo: compact unnecessary blocks that reach replicaWrite
RssSendShuffleDataRequest request = new RssSendShuffleDataRequest(
appId, retryMax, retryIntervalMax, entry.getValue());
appId, retryMax, retryIntervalMax, shuffleIdToBlocks);
long s = System.currentTimeMillis();
RssSendShuffleDataResponse response = getShuffleServerClient(ssi).sendShuffleData(request);
LOG.info("ShuffleWriteClientImpl sendShuffleData cost:" + (System.currentTimeMillis() - s));

if (response.getStatusCode() == ResponseStatusCode.SUCCESS) {
successBlockIds.addAll(serverToBlockIds.get(ssi));
// mark a replica of block that has been sent
serverToBlockIds.get(ssi).forEach(block -> blockIdsTracker.get(block).incrementAndGet());
LOG.info("Send: " + serverToBlockIds.get(ssi).size()
+ " blocks to [" + ssi.getId() + "] successfully");
} else {
tempFailedBlockIds.addAll(serverToBlockIds.get(ssi));
LOG.error("Send: " + serverToBlockIds.get(ssi).size() + " blocks to [" + ssi.getId()
LOG.warn("Send: " + serverToBlockIds.get(ssi).size() + " blocks to [" + ssi.getId()
+ "] failed with statusCode[" + response.getStatusCode() + "], ");
}
} catch (Exception e) {
tempFailedBlockIds.addAll(serverToBlockIds.get(ssi));
LOG.error("Send: " + serverToBlockIds.get(ssi).size() + " blocks to [" + ssi.getId() + "] failed.", e);
LOG.warn("Send: " + serverToBlockIds.get(ssi).size() + " blocks to [" + ssi.getId() + "] failed.", e);
}
});

// check success and failed blocks according to the replicaWrite
blockIdsTracker.entrySet().forEach(blockCt -> {
long blockId = blockCt.getKey();
int count = blockCt.getValue().get();
if (count >= replicaWrite) {
successBlockIds.add(blockId);
} else {
tempFailedBlockIds.add(blockId);
}
}
);
}
}

Expand Down Expand Up @@ -162,7 +189,6 @@ public SendShuffleDataResult sendShuffleData(String appId, List<ShuffleBlockInfo
// if send block failed, the task will fail
// todo: better to have fallback solution when send to multiple servers
sendShuffleDataAsync(appId, serverToBlocks, serverToBlockIds, successBlockIds, failedBlockIds);

return new SendShuffleDataResult(successBlockIds, failedBlockIds);
}

Expand Down Expand Up @@ -243,9 +269,9 @@ public Map<String, String> fetchClientConf(int timeoutMs) {

@Override
public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId, int partitionNum,
int partitionNumPerRange, int dataReplica, Set<String> requiredTags) {
int partitionNumPerRange, Set<String> requiredTags) {
RssGetShuffleAssignmentsRequest request = new RssGetShuffleAssignmentsRequest(
appId, shuffleId, partitionNum, partitionNumPerRange, dataReplica, requiredTags);
appId, shuffleId, partitionNum, partitionNumPerRange, replica, requiredTags);

RssGetShuffleAssignmentsResponse response = new RssGetShuffleAssignmentsResponse(ResponseStatusCode.INTERNAL_ERROR);
for (CoordinatorClient coordinatorClient : coordinatorClients) {
Expand Down Expand Up @@ -285,6 +311,7 @@ public void reportShuffleResult(
groupedPartitions.get(ssi).add(entry.getKey());
}
}
int successCnt = 0;
for (Map.Entry<ShuffleServerInfo, List<Integer>> entry : groupedPartitions.entrySet()) {
Map<Integer, List<Long>> requestBlockIds = Maps.newHashMap();
for (Integer partitionId : entry.getValue()) {
Expand All @@ -298,20 +325,17 @@ public void reportShuffleResult(
if (response.getStatusCode() == ResponseStatusCode.SUCCESS) {
LOG.info("Report shuffle result to " + ssi + " for appId[" + appId
+ "], shuffleId[" + shuffleId + "] successfully");
successCnt++;
} else {
isSuccessful = false;
LOG.warn("Report shuffle result to " + ssi + " for appId[" + appId
+ "], shuffleId[" + shuffleId + "] failed with " + response.getStatusCode());
break;
}
} catch (Exception e) {
isSuccessful = false;
LOG.warn("Report shuffle result is failed to " + ssi
+ " for appId[" + appId + "], shuffleId[" + shuffleId + "]");
break;
}
}
if (!isSuccessful) {
if (successCnt < replicaWrite) {
throw new RssException("Report shuffle result is failed for appId["
+ appId + "], shuffleId[" + shuffleId + "]");
}
Expand All @@ -324,14 +348,20 @@ public Roaring64NavigableMap getShuffleResult(String clientType, Set<ShuffleServ
appId, shuffleId, partitionId);
boolean isSuccessful = false;
Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
int successCnt = 0;
for (ShuffleServerInfo ssi : shuffleServerInfoSet) {
try {
RssGetShuffleResultResponse response = ShuffleServerClientFactory
.getInstance().getShuffleServerClient(clientType, ssi).getShuffleResult(request);
if (response.getStatusCode() == ResponseStatusCode.SUCCESS) {
blockIdBitmap = response.getBlockIdBitmap();
isSuccessful = true;
break;
// merge into blockIds from multiple servers.
Roaring64NavigableMap blockIdBitmapOfServer = response.getBlockIdBitmap();
blockIdBitmap.or(blockIdBitmapOfServer);
successCnt++;
if (successCnt >= replicaRead) {
isSuccessful = true;
break;
}
}
} catch (Exception e) {
LOG.warn("Get shuffle result is failed from " + ssi
Expand Down Expand Up @@ -407,7 +437,8 @@ private void throwExceptionIfNecessary(ClientResponse response, String errorMsg)
}

@VisibleForTesting
protected ShuffleServerClient getShuffleServerClient(ShuffleServerInfo shuffleServerInfo) {
public ShuffleServerClient getShuffleServerClient(ShuffleServerInfo shuffleServerInfo) {
return ShuffleServerClientFactory.getInstance().getShuffleServerClient(clientType, shuffleServerInfo);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class ShuffleWriteClientImplTest {
@Test
public void testSendData() {
ShuffleWriteClientImpl shuffleWriteClient =
new ShuffleWriteClientImpl("GRPC", 3, 2000, 4);
new ShuffleWriteClientImpl("GRPC", 3, 2000, 4, 1, 1, 1);
ShuffleServerClient mockShuffleServerClient = mock(ShuffleServerClient.class);
ShuffleWriteClientImpl spyClient = spy(shuffleWriteClient);
doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,4 +240,13 @@ public static <T> List<T> loadExtensions(
}
return extensions;
}

public static void checkQuorumSetting(int replica, int replicaWrite, int replicaRead) {
if (replica < 1 || replicaWrite > replica || replicaRead > replica) {
throw new RuntimeException("Replica config is invalid, recommend replica.write + replica.read > replica");
}
if (replicaWrite + replicaRead <= replica) {
throw new RuntimeException("Replica config is unsafe, recommend replica.write + replica.read > replica");
}
}
}
Loading

0 comments on commit 17c9172

Please sign in to comment.