diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssClientConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssClientConfig.java index 53b405a0..0f0736af 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssClientConfig.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssClientConfig.java @@ -73,6 +73,10 @@ public class RssClientConfig { public static int RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE_DEFAULT_VALUE = 60; public static String RSS_DATA_REPLICA = "spark.rss.data.replica"; public static int RSS_DATA_REPLICA_DEFAULT_VALUE = 1; + public static String RSS_DATA_REPLICA_WRITE = "spark.rss.data.replica.write"; + public static int RSS_DATA_REPLICA_WRITE_DEFAULT_VALUE = 1; + public static String RSS_DATA_REPLICA_READ = "spark.rss.data.replica.read"; + public static int RSS_DATA_REPLICA_READ_DEFAULT_VALUE = 1; public static String RSS_OZONE_DFS_NAMENODE_ODFS_ENABLE = "spark.rss.ozone.dfs.namenode.odfs.enable"; public static boolean RSS_OZONE_DFS_NAMENODE_ODFS_ENABLE_DEFAULT_VALUE = false; public static String RSS_OZONE_FS_HDFS_IMPL = "spark.rss.ozone.fs.hdfs.impl"; diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 2efc0bc1..6bf229c2 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -60,6 +60,7 @@ import com.tencent.rss.common.ShuffleBlockInfo; import com.tencent.rss.common.ShuffleServerInfo; import com.tencent.rss.common.util.Constants; +import com.tencent.rss.common.util.RssUtils; public class RssShuffleManager implements ShuffleManager { @@ -74,6 +75,9 @@ public class RssShuffleManager implements ShuffleManager { private Map> taskToSuccessBlockIds = Maps.newConcurrentMap(); private Map> taskToFailedBlockIds = Maps.newConcurrentMap(); private Map taskToBufferManager = Maps.newConcurrentMap(); + private final int dataReplica; + private final int dataReplicaWrite; + private final int dataReplicaRead; private boolean heartbeatStarted = false; private ThreadPoolExecutor threadPoolExecutor; private EventLoop eventLoop = new EventLoop("ShuffleDataQueue") { @@ -123,6 +127,18 @@ public void onStart() { public RssShuffleManager(SparkConf sparkConf, boolean isDriver) { this.sparkConf = sparkConf; + + // set & check replica config + this.dataReplica = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA, + RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE); + this.dataReplicaWrite = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA_WRITE, + RssClientConfig.RSS_DATA_REPLICA_WRITE_DEFAULT_VALUE); + this.dataReplicaRead = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA_READ, + RssClientConfig.RSS_DATA_REPLICA_READ_DEFAULT_VALUE); + LOG.info("Check quorum config [" + + dataReplica + ":" + dataReplicaWrite + ":" + dataReplicaRead + "]"); + RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead); + this.clientType = sparkConf.get(RssClientConfig.RSS_CLIENT_TYPE, RssClientConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE); this.heartbeatInterval = sparkConf.getLong(RssClientConfig.RSS_HEARTBEAT_INTERVAL, @@ -136,7 +152,8 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) { RssClientConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM_DEFAULT_VALUE); shuffleWriteClient = ShuffleClientFactory .getInstance() - .createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum); + .createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum, + dataReplica, dataReplicaWrite, dataReplicaRead); registerCoordinator(); // fetch client conf and apply them if necessary and disable ESS if (isDriver && sparkConf.getBoolean( @@ -184,12 +201,11 @@ public ShuffleHandle registerShuffle(int shuffleId, int numMaps, Shuff int partitionNumPerRange = sparkConf.getInt(RssClientConfig.RSS_PARTITION_NUM_PER_RANGE, RssClientConfig.RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE); - int dataReplica = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA, - RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE); + // get all register info according to coordinator's response ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments( appId, shuffleId, dependency.partitioner().numPartitions(), - partitionNumPerRange, dataReplica, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION)); + partitionNumPerRange, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION)); Map> partitionToServers = response.getPartitionToServers(); startHeartbeat(); diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 00dc4773..6e544a01 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -63,6 +63,7 @@ import com.tencent.rss.common.ShuffleBlockInfo; import com.tencent.rss.common.ShuffleServerInfo; import com.tencent.rss.common.util.Constants; +import com.tencent.rss.common.util.RssUtils; public class RssShuffleManager implements ShuffleManager { @@ -73,7 +74,9 @@ public class RssShuffleManager implements ShuffleManager { private final ThreadPoolExecutor threadPoolExecutor; private AtomicReference id = new AtomicReference<>(); private SparkConf sparkConf; - private int dataReplica; + private final int dataReplica; + private final int dataReplicaWrite; + private final int dataReplicaRead; private ShuffleWriteClient shuffleWriteClient; private final Map> taskToSuccessBlockIds; private final Map> taskToFailedBlockIds; @@ -125,6 +128,18 @@ private synchronized void putBlockId( public RssShuffleManager(SparkConf conf, boolean isDriver) { this.sparkConf = conf; + + // set & check replica config + this.dataReplica = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA, + RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE); + this.dataReplicaWrite = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA_WRITE, + RssClientConfig.RSS_DATA_REPLICA_WRITE_DEFAULT_VALUE); + this.dataReplicaRead = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA_READ, + RssClientConfig.RSS_DATA_REPLICA_READ_DEFAULT_VALUE); + LOG.info("Check quorum config [" + + dataReplica + ":" + dataReplicaWrite + ":" + dataReplicaRead + "]"); + RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead); + this.heartbeatInterval = sparkConf.getLong(RssClientConfig.RSS_HEARTBEAT_INTERVAL, RssClientConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE); this.heartbeatTimeout = sparkConf.getLong(RssClientConfig.RSS_HEARTBEAT_TIMEOUT, heartbeatInterval / 2); @@ -132,15 +147,15 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) { RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE); this.clientType = sparkConf.get(RssClientConfig.RSS_CLIENT_TYPE, RssClientConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE); - dataReplica = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA, - RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE); + long retryIntervalMax = sparkConf.getLong(RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX, RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE); int heartBeatThreadNum = sparkConf.getInt(RssClientConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM, RssClientConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM_DEFAULT_VALUE); shuffleWriteClient = ShuffleClientFactory .getInstance() - .createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum); + .createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum, + dataReplica, dataReplicaWrite, dataReplicaRead); registerCoordinator(); // fetch client conf and apply them if necessary and disable ESS if (isDriver && sparkConf.getBoolean( @@ -184,17 +199,26 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) { this.heartbeatInterval = sparkConf.getLong(RssClientConfig.RSS_HEARTBEAT_INTERVAL, RssClientConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE); this.heartbeatTimeout = sparkConf.getLong(RssClientConfig.RSS_HEARTBEAT_TIMEOUT, heartbeatInterval / 2); + this.dataReplica = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA, + RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE); + this.dataReplicaWrite = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA_WRITE, + RssClientConfig.RSS_DATA_REPLICA_WRITE_DEFAULT_VALUE); + this.dataReplicaRead = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA_READ, + RssClientConfig.RSS_DATA_REPLICA_READ_DEFAULT_VALUE); + LOG.info("Check quorum config [" + + dataReplica + ":" + dataReplicaWrite + ":" + dataReplicaRead + "]"); + RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead); + int retryMax = sparkConf.getInt(RssClientConfig.RSS_CLIENT_RETRY_MAX, - RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE); + RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE); long retryIntervalMax = sparkConf.getLong(RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX, - RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE); + RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE); int heartBeatThreadNum = sparkConf.getInt(RssClientConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM, - RssClientConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM_DEFAULT_VALUE); - dataReplica = sparkConf.getInt(RssClientConfig.RSS_DATA_REPLICA, - RssClientConfig.RSS_DATA_REPLICA_DEFAULT_VALUE); - shuffleWriteClient = ShuffleClientFactory + RssClientConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM_DEFAULT_VALUE); + shuffleWriteClient = ShuffleClientFactory .getInstance() - .createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum); + .createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum, + dataReplica, dataReplicaWrite, dataReplicaRead); this.taskToSuccessBlockIds = taskToSuccessBlockIds; this.taskToFailedBlockIds = taskToFailedBlockIds; if (loop != null) { @@ -225,7 +249,6 @@ public ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency< shuffleId, dependency.partitioner().numPartitions(), 1, - dataReplica, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION)); Map> partitionToServers = response.getPartitionToServers(); diff --git a/client/src/main/java/com/tencent/rss/client/api/ShuffleWriteClient.java b/client/src/main/java/com/tencent/rss/client/api/ShuffleWriteClient.java index b9c79778..e79bfdff 100644 --- a/client/src/main/java/com/tencent/rss/client/api/ShuffleWriteClient.java +++ b/client/src/main/java/com/tencent/rss/client/api/ShuffleWriteClient.java @@ -54,7 +54,7 @@ void reportShuffleResult( int bitmapNum); ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId, int partitionNum, - int partitionNumPerRange, int dataReplica, Set requiredTags); + int partitionNumPerRange, Set requiredTags); Roaring64NavigableMap getShuffleResult(String clientType, Set shuffleServerInfoSet, String appId, int shuffleId, int partitionId); diff --git a/client/src/main/java/com/tencent/rss/client/factory/ShuffleClientFactory.java b/client/src/main/java/com/tencent/rss/client/factory/ShuffleClientFactory.java index 70843ab3..0b488b05 100644 --- a/client/src/main/java/com/tencent/rss/client/factory/ShuffleClientFactory.java +++ b/client/src/main/java/com/tencent/rss/client/factory/ShuffleClientFactory.java @@ -36,8 +36,10 @@ public static ShuffleClientFactory getInstance() { } public ShuffleWriteClient createShuffleWriteClient( - String clientType, int retryMax, long retryIntervalMax, int heartBeatThreadNum) { - return new ShuffleWriteClientImpl(clientType, retryMax, retryIntervalMax, heartBeatThreadNum); + String clientType, int retryMax, long retryIntervalMax, int heartBeatThreadNum, + int replica, int replicaWrite, int replicaRead) { + return new ShuffleWriteClientImpl(clientType, retryMax, retryIntervalMax, heartBeatThreadNum, + replica, replicaWrite, replicaRead); } public ShuffleReadClient createShuffleReadClient(CreateShuffleReadClientRequest request) { diff --git a/client/src/main/java/com/tencent/rss/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/com/tencent/rss/client/impl/ShuffleWriteClientImpl.java index 3c232bae..3effe6a1 100644 --- a/client/src/main/java/com/tencent/rss/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/com/tencent/rss/client/impl/ShuffleWriteClientImpl.java @@ -79,14 +79,21 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient { private Set shuffleServerInfoSet = Sets.newConcurrentHashSet(); private CoordinatorClientFactory coordinatorClientFactory; private ExecutorService heartBeatExecutorService; + private int replica; + private int replicaWrite; + private int replicaRead; - public ShuffleWriteClientImpl(String clientType, int retryMax, long retryIntervalMax, int heartBeatThreadNum) { + public ShuffleWriteClientImpl(String clientType, int retryMax, long retryIntervalMax, int heartBeatThreadNum, + int replica, int replicaWrite, int replicaRead) { this.clientType = clientType; this.retryMax = retryMax; this.retryIntervalMax = retryIntervalMax; coordinatorClientFactory = new CoordinatorClientFactory(clientType); heartBeatExecutorService = Executors.newFixedThreadPool(heartBeatThreadNum, new ThreadFactoryBuilder().setDaemon(true).setNameFormat("client-heartbeat-%d").build()); + this.replica = replica; + this.replicaWrite = replicaWrite; + this.replicaRead = replicaRead; } private void sendShuffleDataAsync( @@ -95,30 +102,50 @@ private void sendShuffleDataAsync( Map> serverToBlockIds, Set successBlockIds, Set tempFailedBlockIds) { + + // maintain the count of blocks that have been sent to the server + Map blockIdsTracker = Maps.newConcurrentMap(); + serverToBlockIds.values().forEach( + blockList -> blockList.forEach(block -> blockIdsTracker.put(block, new AtomicInteger(0))) + ); + if (serverToBlocks != null) { serverToBlocks.entrySet().parallelStream().forEach(entry -> { ShuffleServerInfo ssi = entry.getKey(); try { + Map>> shuffleIdToBlocks = entry.getValue(); + // todo: compact unnecessary blocks that reach replicaWrite RssSendShuffleDataRequest request = new RssSendShuffleDataRequest( - appId, retryMax, retryIntervalMax, entry.getValue()); + appId, retryMax, retryIntervalMax, shuffleIdToBlocks); long s = System.currentTimeMillis(); RssSendShuffleDataResponse response = getShuffleServerClient(ssi).sendShuffleData(request); LOG.info("ShuffleWriteClientImpl sendShuffleData cost:" + (System.currentTimeMillis() - s)); if (response.getStatusCode() == ResponseStatusCode.SUCCESS) { - successBlockIds.addAll(serverToBlockIds.get(ssi)); + // mark a replica of block that has been sent + serverToBlockIds.get(ssi).forEach(block -> blockIdsTracker.get(block).incrementAndGet()); LOG.info("Send: " + serverToBlockIds.get(ssi).size() + " blocks to [" + ssi.getId() + "] successfully"); } else { - tempFailedBlockIds.addAll(serverToBlockIds.get(ssi)); - LOG.error("Send: " + serverToBlockIds.get(ssi).size() + " blocks to [" + ssi.getId() + LOG.warn("Send: " + serverToBlockIds.get(ssi).size() + " blocks to [" + ssi.getId() + "] failed with statusCode[" + response.getStatusCode() + "], "); } } catch (Exception e) { - tempFailedBlockIds.addAll(serverToBlockIds.get(ssi)); - LOG.error("Send: " + serverToBlockIds.get(ssi).size() + " blocks to [" + ssi.getId() + "] failed.", e); + LOG.warn("Send: " + serverToBlockIds.get(ssi).size() + " blocks to [" + ssi.getId() + "] failed.", e); } }); + + // check success and failed blocks according to the replicaWrite + blockIdsTracker.entrySet().forEach(blockCt -> { + long blockId = blockCt.getKey(); + int count = blockCt.getValue().get(); + if (count >= replicaWrite) { + successBlockIds.add(blockId); + } else { + tempFailedBlockIds.add(blockId); + } + } + ); } } @@ -162,7 +189,6 @@ public SendShuffleDataResult sendShuffleData(String appId, List fetchClientConf(int timeoutMs) { @Override public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId, int partitionNum, - int partitionNumPerRange, int dataReplica, Set requiredTags) { + int partitionNumPerRange, Set requiredTags) { RssGetShuffleAssignmentsRequest request = new RssGetShuffleAssignmentsRequest( - appId, shuffleId, partitionNum, partitionNumPerRange, dataReplica, requiredTags); + appId, shuffleId, partitionNum, partitionNumPerRange, replica, requiredTags); RssGetShuffleAssignmentsResponse response = new RssGetShuffleAssignmentsResponse(ResponseStatusCode.INTERNAL_ERROR); for (CoordinatorClient coordinatorClient : coordinatorClients) { @@ -285,6 +311,7 @@ public void reportShuffleResult( groupedPartitions.get(ssi).add(entry.getKey()); } } + int successCnt = 0; for (Map.Entry> entry : groupedPartitions.entrySet()) { Map> requestBlockIds = Maps.newHashMap(); for (Integer partitionId : entry.getValue()) { @@ -298,20 +325,17 @@ public void reportShuffleResult( if (response.getStatusCode() == ResponseStatusCode.SUCCESS) { LOG.info("Report shuffle result to " + ssi + " for appId[" + appId + "], shuffleId[" + shuffleId + "] successfully"); + successCnt++; } else { - isSuccessful = false; LOG.warn("Report shuffle result to " + ssi + " for appId[" + appId + "], shuffleId[" + shuffleId + "] failed with " + response.getStatusCode()); - break; } } catch (Exception e) { - isSuccessful = false; LOG.warn("Report shuffle result is failed to " + ssi + " for appId[" + appId + "], shuffleId[" + shuffleId + "]"); - break; } } - if (!isSuccessful) { + if (successCnt < replicaWrite) { throw new RssException("Report shuffle result is failed for appId[" + appId + "], shuffleId[" + shuffleId + "]"); } @@ -324,14 +348,20 @@ public Roaring64NavigableMap getShuffleResult(String clientType, Set= replicaRead) { + isSuccessful = true; + break; + } } } catch (Exception e) { LOG.warn("Get shuffle result is failed from " + ssi @@ -407,7 +437,8 @@ private void throwExceptionIfNecessary(ClientResponse response, String errorMsg) } @VisibleForTesting - protected ShuffleServerClient getShuffleServerClient(ShuffleServerInfo shuffleServerInfo) { + public ShuffleServerClient getShuffleServerClient(ShuffleServerInfo shuffleServerInfo) { return ShuffleServerClientFactory.getInstance().getShuffleServerClient(clientType, shuffleServerInfo); } + } diff --git a/client/src/test/java/com/tencent/rss/client/impl/ShuffleWriteClientImplTest.java b/client/src/test/java/com/tencent/rss/client/impl/ShuffleWriteClientImplTest.java index d54f68ab..eeede10d 100644 --- a/client/src/test/java/com/tencent/rss/client/impl/ShuffleWriteClientImplTest.java +++ b/client/src/test/java/com/tencent/rss/client/impl/ShuffleWriteClientImplTest.java @@ -42,7 +42,7 @@ public class ShuffleWriteClientImplTest { @Test public void testSendData() { ShuffleWriteClientImpl shuffleWriteClient = - new ShuffleWriteClientImpl("GRPC", 3, 2000, 4); + new ShuffleWriteClientImpl("GRPC", 3, 2000, 4, 1, 1, 1); ShuffleServerClient mockShuffleServerClient = mock(ShuffleServerClient.class); ShuffleWriteClientImpl spyClient = spy(shuffleWriteClient); doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any()); diff --git a/common/src/main/java/com/tencent/rss/common/util/RssUtils.java b/common/src/main/java/com/tencent/rss/common/util/RssUtils.java index 39e7b5ac..57897730 100644 --- a/common/src/main/java/com/tencent/rss/common/util/RssUtils.java +++ b/common/src/main/java/com/tencent/rss/common/util/RssUtils.java @@ -240,4 +240,13 @@ public static List loadExtensions( } return extensions; } + + public static void checkQuorumSetting(int replica, int replicaWrite, int replicaRead) { + if (replica < 1 || replicaWrite > replica || replicaRead > replica) { + throw new RuntimeException("Replica config is invalid, recommend replica.write + replica.read > replica"); + } + if (replicaWrite + replicaRead <= replica) { + throw new RuntimeException("Replica config is unsafe, recommend replica.write + replica.read > replica"); + } + } } diff --git a/integration-test/common/pom.xml b/integration-test/common/pom.xml index c2ea4ca6..0fcb7f16 100644 --- a/integration-test/common/pom.xml +++ b/integration-test/common/pom.xml @@ -44,6 +44,13 @@ shuffle-server test + + com.tencent.rss + shuffle-server + test-jar + ${project.version} + test + com.tencent.rss coordinator diff --git a/integration-test/common/src/test/java/com/tencent/rss/test/IntegrationTestBase.java b/integration-test/common/src/test/java/com/tencent/rss/test/IntegrationTestBase.java index 3f377a31..c1483d6c 100644 --- a/integration-test/common/src/test/java/com/tencent/rss/test/IntegrationTestBase.java +++ b/integration-test/common/src/test/java/com/tencent/rss/test/IntegrationTestBase.java @@ -21,6 +21,7 @@ import com.google.common.collect.Lists; import com.tencent.rss.coordinator.CoordinatorConf; import com.tencent.rss.coordinator.CoordinatorServer; +import com.tencent.rss.server.MockedShuffleServer; import com.tencent.rss.server.ShuffleServer; import com.tencent.rss.server.ShuffleServerConf; import com.tencent.rss.storage.HdfsTestBase; @@ -100,6 +101,10 @@ protected static void createShuffleServer(ShuffleServerConf serverConf) throws E shuffleServers.add(new ShuffleServer(serverConf)); } + protected static void createMockedShuffleServer(ShuffleServerConf serverConf) throws Exception { + shuffleServers.add(new MockedShuffleServer(serverConf)); + } + protected static void createAndStartServers( ShuffleServerConf shuffleServerConf, CoordinatorConf coordinatorConf) throws Exception { diff --git a/integration-test/common/src/test/java/com/tencent/rss/test/QuorumTest.java b/integration-test/common/src/test/java/com/tencent/rss/test/QuorumTest.java new file mode 100644 index 00000000..a1527270 --- /dev/null +++ b/integration-test/common/src/test/java/com/tencent/rss/test/QuorumTest.java @@ -0,0 +1,498 @@ +/* + * Tencent is pleased to support the open source community by making + * Firestorm-Spark remote shuffle server available. + * + * Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * https://opensource.org/licenses/Apache-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.tencent.rss.test; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import com.google.common.io.Files; +import com.tencent.rss.client.factory.ShuffleServerClientFactory; +import com.tencent.rss.client.impl.ShuffleReadClientImpl; +import com.tencent.rss.client.impl.ShuffleWriteClientImpl; +import com.tencent.rss.client.impl.grpc.ShuffleServerGrpcClient; +import com.tencent.rss.client.response.CompressedShuffleBlock; +import com.tencent.rss.client.response.SendShuffleDataResult; +import com.tencent.rss.client.util.ClientType; +import com.tencent.rss.common.PartitionRange; +import com.tencent.rss.common.ShuffleBlockInfo; +import com.tencent.rss.common.ShuffleServerInfo; +import com.tencent.rss.common.util.RssUtils; +import com.tencent.rss.coordinator.CoordinatorConf; +import com.tencent.rss.coordinator.CoordinatorServer; +import com.tencent.rss.server.MockedGrpcServer; +import com.tencent.rss.server.MockedShuffleServer; +import com.tencent.rss.server.ShuffleServer; +import com.tencent.rss.server.ShuffleServerConf; +import com.tencent.rss.storage.util.StorageType; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.roaringbitmap.longlong.Roaring64NavigableMap; + +import java.io.File; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.*; + +public class QuorumTest extends ShuffleReadWriteBase { + + private static final String EXPECTED_EXCEPTION_MESSAGE = "Exception should be thrown"; + private static ShuffleServerInfo shuffleServerInfo0; + private static ShuffleServerInfo shuffleServerInfo1; + private static ShuffleServerInfo shuffleServerInfo2; + private static ShuffleServerInfo fakedShuffleServerInfo0; + private static ShuffleServerInfo fakedShuffleServerInfo1; + private static ShuffleServerInfo fakedShuffleServerInfo2; + private ShuffleWriteClientImpl shuffleWriteClientImpl; + + + public static MockedShuffleServer createServer(int id) throws Exception { + ShuffleServerConf shuffleServerConf = getShuffleServerConf(); + shuffleServerConf.setLong("rss.server.app.expired.withoutHeartbeat", 4000); + shuffleServerConf.setLong("rss.server.heartbeat.interval", 5000); + File tmpDir = Files.createTempDir(); + tmpDir.deleteOnExit(); + File dataDir1 = new File(tmpDir, id + "_1"); + File dataDir2 = new File(tmpDir, id + "_2"); + String basePath = dataDir1.getAbsolutePath() + "," + dataDir2.getAbsolutePath(); + shuffleServerConf.setString("rss.storage.type", StorageType.MEMORY_LOCALFILE.name()); + shuffleServerConf.setInteger("rss.rpc.server.port", SHUFFLE_SERVER_PORT + id); + shuffleServerConf.setInteger("rss.jetty.http.port", 19081 + id * 100); + shuffleServerConf.setString("rss.storage.basePath", basePath); + return new MockedShuffleServer(shuffleServerConf); + } + + @BeforeClass + public static void initCluster() throws Exception { + CoordinatorConf coordinatorConf = getCoordinatorConf(); + createCoordinatorServer(coordinatorConf); + + ShuffleServerConf shuffleServerConf = getShuffleServerConf(); + shuffleServerConf.setLong("rss.server.app.expired.withoutHeartbeat", 1000); + shuffleServerConf.setLong("rss.server.app.expired.withoutHeartbeat", 1000); + File tmpDir = Files.createTempDir(); + tmpDir.deleteOnExit(); + + shuffleServers.add(createServer(0)); + shuffleServers.add(createServer(1)); + shuffleServers.add(createServer(2)); + + shuffleServerInfo0 = + new ShuffleServerInfo("127.0.0.1-20001", shuffleServers.get(0).getIp(), SHUFFLE_SERVER_PORT); + shuffleServerInfo1 = + new ShuffleServerInfo("127.0.0.1-20002", shuffleServers.get(1).getIp(), SHUFFLE_SERVER_PORT + 1); + shuffleServerInfo2 = + new ShuffleServerInfo("127.0.0.1-20003", shuffleServers.get(1).getIp(), SHUFFLE_SERVER_PORT + 2); + for (CoordinatorServer coordinator : coordinators) { + coordinator.start(); + } + for (ShuffleServer shuffleServer : shuffleServers) { + shuffleServer.start(); + } + Thread.sleep(2000); + } + + public static void cleanCluster() throws Exception { + for (CoordinatorServer coordinator : coordinators) { + coordinator.stopServer(); + } + for (ShuffleServer shuffleServer : shuffleServers) { + shuffleServer.stopServer(); + } + shuffleServers = Lists.newArrayList(); + coordinators = Lists.newArrayList(); + } + + @Before + public void InitEnv() throws Exception { + // spark.rss.data.replica=3 + // spark.rss.data.replica.write=2 + // spark.rss.data.replica.read=2 + ((ShuffleServerGrpcClient)ShuffleServerClientFactory + .getInstance().getShuffleServerClient("GRPC", shuffleServerInfo0)).adjustTimeout(10); + ((ShuffleServerGrpcClient)ShuffleServerClientFactory + .getInstance().getShuffleServerClient("GRPC", shuffleServerInfo1)).adjustTimeout(10); + ((ShuffleServerGrpcClient)ShuffleServerClientFactory + .getInstance().getShuffleServerClient("GRPC", shuffleServerInfo2)).adjustTimeout(10); + } + + @After + public void cleanEnv() throws Exception { + if (shuffleWriteClientImpl != null) { + shuffleWriteClientImpl.close(); + } + cleanCluster(); + initCluster(); + } + + + @Test + public void QuorumConfigTest() throws Exception { + try { + RssUtils.checkQuorumSetting(3, 1, 1); + fail(EXPECTED_EXCEPTION_MESSAGE); + } catch (Exception e) { + assertTrue(e.getMessage().startsWith("Replica config is unsafe")); + } + try { + RssUtils.checkQuorumSetting(3, 4, 1); + fail(EXPECTED_EXCEPTION_MESSAGE); + } catch (Exception e) { + assertTrue(e.getMessage().startsWith("Replica config is invalid")); + } + try { + RssUtils.checkQuorumSetting(0, 0, 0); + fail(EXPECTED_EXCEPTION_MESSAGE); + } catch (Exception e) { + assertTrue(e.getMessage().startsWith("Replica config is invalid")); + } + } + + @Test + public void rpcFailedTest() throws Exception { + String testAppId = "rpcFailedTest"; + registerShuffleServer(testAppId); + Map expectedData = Maps.newHashMap(); + Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); + + // simulator of failed servers + fakedShuffleServerInfo0 = + new ShuffleServerInfo("127.0.0.1-20001", shuffleServers.get(0).getIp(), SHUFFLE_SERVER_PORT + 100); + fakedShuffleServerInfo1 = + new ShuffleServerInfo("127.0.0.1-20002", shuffleServers.get(1).getIp(), SHUFFLE_SERVER_PORT + 200); + fakedShuffleServerInfo2 = + new ShuffleServerInfo("127.0.0.1-20003", shuffleServers.get(2).getIp(), SHUFFLE_SERVER_PORT + 300); + + // case1: When only 1 server is failed, the block sending should success + List blocks = createShuffleBlockList( + 0, 0, 0, 3, 25, blockIdBitmap, + expectedData, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, fakedShuffleServerInfo2)); + + Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0); + SendShuffleDataResult result = shuffleWriteClientImpl.sendShuffleData(testAppId, blocks); + Roaring64NavigableMap failedBlockIdBitmap = Roaring64NavigableMap.bitmapOf(); + Roaring64NavigableMap succBlockIdBitmap = Roaring64NavigableMap.bitmapOf(); + for (Long blockId : result.getSuccessBlockIds()) { + succBlockIdBitmap.addLong(blockId); + } + for (Long blockId : result.getFailedBlockIds()) { + failedBlockIdBitmap.addLong(blockId); + } + assertEquals(0, failedBlockIdBitmap.getLongCardinality()); + assertEquals(blockIdBitmap, succBlockIdBitmap); + + ShuffleReadClientImpl readClient = new ShuffleReadClientImpl(StorageType.MEMORY_LOCALFILE.name(), + testAppId, 0, 0, 100, 1, + 10, 1000, "", blockIdBitmap, taskIdBitmap, + Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, fakedShuffleServerInfo2), null); + // The data should be read + validateResult(readClient, expectedData); + + // case2: When 2 servers are failed, the block sending should fail + blockIdBitmap = Roaring64NavigableMap.bitmapOf(); + blocks = createShuffleBlockList( + 0, 0, 0, 3, 25, blockIdBitmap, + expectedData, Lists.newArrayList(shuffleServerInfo0, fakedShuffleServerInfo1, fakedShuffleServerInfo2)); + result = shuffleWriteClientImpl.sendShuffleData(testAppId, blocks); + failedBlockIdBitmap = Roaring64NavigableMap.bitmapOf(); + succBlockIdBitmap = Roaring64NavigableMap.bitmapOf(); + for (Long blockId : result.getSuccessBlockIds()) { + succBlockIdBitmap.addLong(blockId); + } + for (Long blockId : result.getFailedBlockIds()) { + failedBlockIdBitmap.addLong(blockId); + } + assertEquals(blockIdBitmap, failedBlockIdBitmap); + assertEquals(0, succBlockIdBitmap.getLongCardinality()); + + // The client should not read any data, because write is failed + assertEquals(readClient.readShuffleBlockData(), null); + } + + private void enableTimeout(MockedShuffleServer server, long timeout) { + ((MockedGrpcServer)server.getServer()).getService() + .enableMockedTimeout(timeout); + } + + private void disableTimeout(MockedShuffleServer server) { + ((MockedGrpcServer)server.getServer()).getService() + .disableMockedTimeout(); + } + + private void registerShuffleServer(String testAppId){ + + shuffleWriteClientImpl = new ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1, + 3, 2, 2); + shuffleWriteClientImpl.registerShuffle(shuffleServerInfo0, + testAppId, 0, Lists.newArrayList(new PartitionRange(0, 0))); + shuffleWriteClientImpl.registerShuffle(shuffleServerInfo1, + testAppId, 0, Lists.newArrayList(new PartitionRange(0, 0))); + shuffleWriteClientImpl.registerShuffle(shuffleServerInfo2, + testAppId, 0, Lists.newArrayList(new PartitionRange(0, 0))); + } + + @Test + public void case1() throws Exception { + String testAppId = "case1"; + registerShuffleServer(testAppId); + Map expectedData = Maps.newHashMap(); + Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); + + // only 1 server is timout, the block sending should success + enableTimeout((MockedShuffleServer)shuffleServers.get(2), 100); + + List blocks = createShuffleBlockList( + 0, 0, 0, 3, 25, blockIdBitmap, + expectedData, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2)); + + // report result should success + Map> partitionToBlockIds = Maps.newHashMap(); + partitionToBlockIds.put(0, Lists.newArrayList(blockIdBitmap.stream().iterator())); + Map> partitionToServers = Maps.newHashMap(); + partitionToServers.put(0, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2)); + shuffleWriteClientImpl.reportShuffleResult(partitionToServers, testAppId, 0, 0L, + partitionToBlockIds, 1); + Roaring64NavigableMap report = shuffleWriteClientImpl.getShuffleResult("GRPC", + Sets.newHashSet(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2), + testAppId, 0, 0); + assertEquals(report, blockIdBitmap); + + // data read should success + Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0); + SendShuffleDataResult result = shuffleWriteClientImpl.sendShuffleData(testAppId, blocks); + Roaring64NavigableMap succBlockIdBitmap = Roaring64NavigableMap.bitmapOf(); + for (Long blockId : result.getSuccessBlockIds()) { + succBlockIdBitmap.addLong(blockId); + } + assertEquals(0, result.getFailedBlockIds().size()); + assertEquals(blockIdBitmap, succBlockIdBitmap); + + ShuffleReadClientImpl readClient = new ShuffleReadClientImpl(StorageType.MEMORY_LOCALFILE.name(), + testAppId, 0, 0, 100, 1, + 10, 1000, "", blockIdBitmap, taskIdBitmap, + Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2), null); + validateResult(readClient, expectedData); + } + + @Test + public void case2() throws Exception { + String testAppId = "case2"; + registerShuffleServer(testAppId); + + Map expectedData = Maps.newHashMap(); + Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); + Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0); + // When 2 servers are timeout, the block sending should fail + enableTimeout((MockedShuffleServer)shuffleServers.get(1), 100); + enableTimeout((MockedShuffleServer)shuffleServers.get(2), 100); + + List blocks = createShuffleBlockList( + 0, 0, 0, 3, 25, blockIdBitmap, + expectedData, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2)); + SendShuffleDataResult result = shuffleWriteClientImpl.sendShuffleData(testAppId, blocks); + Roaring64NavigableMap failedBlockIdBitmap = Roaring64NavigableMap.bitmapOf(); + for (Long blockId : result.getFailedBlockIds()) { + failedBlockIdBitmap.addLong(blockId); + } + assertEquals(blockIdBitmap, failedBlockIdBitmap); + assertEquals(0, result.getSuccessBlockIds().size()); + + // report result should fail + Map> partitionToBlockIds = Maps.newHashMap(); + Map> partitionToServers = Maps.newHashMap(); + partitionToBlockIds.put(0, Lists.newArrayList(blockIdBitmap.stream().iterator())); + partitionToServers.put(0, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2)); + try { + shuffleWriteClientImpl.reportShuffleResult(partitionToServers, testAppId, 0, 0L, + partitionToBlockIds, 1); + fail(EXPECTED_EXCEPTION_MESSAGE); + } catch (Exception e){ + assertTrue(e.getMessage().startsWith("Report shuffle result is failed")); + } + // get result should also fail + try { + shuffleWriteClientImpl.getShuffleResult("GRPC", + Sets.newHashSet(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2), + testAppId, 0, 0); + fail(EXPECTED_EXCEPTION_MESSAGE); + } catch (Exception e) { + assertTrue(e.getMessage().startsWith("Get shuffle result is failed")); + } + } + + @Test + public void case3() throws Exception { + String testAppId = "case3"; + registerShuffleServer(testAppId); + disableTimeout((MockedShuffleServer)shuffleServers.get(0)); + disableTimeout((MockedShuffleServer)shuffleServers.get(1)); + disableTimeout((MockedShuffleServer)shuffleServers.get(2)); + + // When 1 server is timeout and 1 server is failed after sending, the block sending should fail + enableTimeout((MockedShuffleServer)shuffleServers.get(2), 100); + + Map expectedData = Maps.newHashMap(); + Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); + List blocks = createShuffleBlockList( + 0, 0, 0, 3, 25, blockIdBitmap, + expectedData, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2)); + SendShuffleDataResult result = shuffleWriteClientImpl.sendShuffleData(testAppId, blocks); + Roaring64NavigableMap failedBlockIdBitmap = Roaring64NavigableMap.bitmapOf(); + Roaring64NavigableMap succBlockIdBitmap = Roaring64NavigableMap.bitmapOf(); + for (Long blockId : result.getSuccessBlockIds()) { + succBlockIdBitmap.addLong(blockId); + } + for (Long blockId : result.getFailedBlockIds()) { + failedBlockIdBitmap.addLong(blockId); + } + assertEquals(blockIdBitmap, succBlockIdBitmap); + assertEquals(0, failedBlockIdBitmap.getLongCardinality()); + + Map> partitionToBlockIds = Maps.newHashMap(); + partitionToBlockIds.put(0, Lists.newArrayList(blockIdBitmap.stream().iterator())); + Map> partitionToServers = Maps.newHashMap(); + partitionToServers.put(0, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2)); + shuffleWriteClientImpl.reportShuffleResult(partitionToServers, testAppId, 0, 0L, + partitionToBlockIds, 1); + + Roaring64NavigableMap report = shuffleWriteClientImpl.getShuffleResult("GRPC", + Sets.newHashSet(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2), + testAppId, 0, 0); + assertEquals(report, blockIdBitmap); + + // let this server be failed, the reading will be also be failed + shuffleServers.get(1).stopServer(); + try { + report = shuffleWriteClientImpl.getShuffleResult("GRPC", + Sets.newHashSet(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2), + testAppId, 0, 0); + fail(EXPECTED_EXCEPTION_MESSAGE); + } catch (Exception e) { + assertTrue(e.getMessage().startsWith("Get shuffle result is failed")); + } + + // When the timeout of one server is recovered, the block sending should success + disableTimeout((MockedShuffleServer)shuffleServers.get(2)); + report = shuffleWriteClientImpl.getShuffleResult("GRPC", + Sets.newHashSet(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2), + testAppId, 0, 0); + assertEquals(report, blockIdBitmap); + } + + @Test + public void case4() throws Exception { + String testAppId = "case4"; + registerShuffleServer(testAppId); + // when 1 server is timeout, the sending multiple blocks should success + enableTimeout((MockedShuffleServer)shuffleServers.get(2), 100); + + Map expectedData = Maps.newHashMap(); + Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); + Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0); + for (int i = 0; i < 5; i++) { + List blocks = createShuffleBlockList( + 0, 0, 0, 3, 25, blockIdBitmap, + expectedData, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2)); + SendShuffleDataResult result = shuffleWriteClientImpl.sendShuffleData(testAppId, blocks); + assertTrue(result.getSuccessBlockIds().size() == 3); + assertTrue(result.getFailedBlockIds().size() == 0); + } + + ShuffleReadClientImpl readClient = new ShuffleReadClientImpl(StorageType.MEMORY_LOCALFILE.name(), + testAppId, 0, 0, 100, 1, + 10, 1000, "", blockIdBitmap, taskIdBitmap, + Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2), null); + validateResult(readClient, expectedData); + } + + @Test + public void case5() throws Exception { + // this case is to simulate server restarting. + String testAppId = "case5"; + registerShuffleServer(testAppId); + + Map expectedData = Maps.newHashMap(); + Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); + List blocks = createShuffleBlockList( + 0, 0, 0, 3, 25, blockIdBitmap, + expectedData, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2)); + + // report result should success + Map> partitionToBlockIds = Maps.newHashMap(); + partitionToBlockIds.put(0, Lists.newArrayList(blockIdBitmap.stream().iterator())); + Map> partitionToServers = Maps.newHashMap(); + partitionToServers.put(0, Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2)); + shuffleWriteClientImpl.reportShuffleResult(partitionToServers, testAppId, 0, 0L, + partitionToBlockIds, 1); + Roaring64NavigableMap report = shuffleWriteClientImpl.getShuffleResult("GRPC", + Sets.newHashSet(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2), + testAppId, 0, 0); + assertEquals(report, blockIdBitmap); + + // data read should success + Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0); + SendShuffleDataResult result = shuffleWriteClientImpl.sendShuffleData(testAppId, blocks); + Roaring64NavigableMap succBlockIdBitmap = Roaring64NavigableMap.bitmapOf(); + for (Long blockId : result.getSuccessBlockIds()) { + succBlockIdBitmap.addLong(blockId); + } + assertEquals(0, result.getFailedBlockIds().size()); + assertEquals(blockIdBitmap, succBlockIdBitmap); + + // when one server is restarted, getShuffleResult should success + shuffleServers.get(1).stopServer(); + shuffleServers.set(1, createServer(1)); + shuffleServers.get(1).start(); + report = shuffleWriteClientImpl.getShuffleResult("GRPC", + Sets.newHashSet(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2), + testAppId, 0, 0); + assertEquals(report, blockIdBitmap); + + // when two servers are restarted, getShuffleResult should fail + shuffleServers.get(2).stopServer(); + shuffleServers.set(2, createServer(2)); + shuffleServers.get(2).start(); + try { + report = shuffleWriteClientImpl.getShuffleResult("GRPC", + Sets.newHashSet(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2), + testAppId, 0, 0); + fail(EXPECTED_EXCEPTION_MESSAGE); + } catch (Exception e) { + assertTrue(e.getMessage().startsWith("Get shuffle result is failed")); + } + } + + protected void validateResult(ShuffleReadClientImpl readClient, Map expectedData, + Roaring64NavigableMap blockIdBitmap) { + CompressedShuffleBlock csb = readClient.readShuffleBlockData(); + Roaring64NavigableMap matched = Roaring64NavigableMap.bitmapOf(); + while (csb != null && csb.getByteBuffer() != null) { + for (Map.Entry entry : expectedData.entrySet()) { + if (compareByte(entry.getValue(), csb.getByteBuffer())) { + matched.addLong(entry.getKey()); + break; + } + } + csb = readClient.readShuffleBlockData(); + } + assertTrue(blockIdBitmap.equals(matched)); + } +} diff --git a/integration-test/common/src/test/java/com/tencent/rss/test/ShuffleServerGrpcTest.java b/integration-test/common/src/test/java/com/tencent/rss/test/ShuffleServerGrpcTest.java index 1e73d7ae..1ce926c4 100644 --- a/integration-test/common/src/test/java/com/tencent/rss/test/ShuffleServerGrpcTest.java +++ b/integration-test/common/src/test/java/com/tencent/rss/test/ShuffleServerGrpcTest.java @@ -96,7 +96,7 @@ public void createClient() { public void clearResourceTest() throws Exception { final ShuffleWriteClient shuffleWriteClient = ShuffleClientFactory.getInstance().createShuffleWriteClient( - "GRPC", 2, 10000L, 4); + "GRPC", 2, 10000L, 4, 1, 1, 1); shuffleWriteClient.registerCoordinators("127.0.0.1:19999"); shuffleWriteClient.registerShuffle( new ShuffleServerInfo("127.0.0.1-20001", "127.0.0.1", 20001), diff --git a/integration-test/common/src/test/java/com/tencent/rss/test/ShuffleWithRssClientTest.java b/integration-test/common/src/test/java/com/tencent/rss/test/ShuffleWithRssClientTest.java index 1c74a51c..6e56a89f 100644 --- a/integration-test/common/src/test/java/com/tencent/rss/test/ShuffleWithRssClientTest.java +++ b/integration-test/common/src/test/java/com/tencent/rss/test/ShuffleWithRssClientTest.java @@ -85,7 +85,8 @@ public static void setupServers() throws Exception { @Before public void createClient() { - shuffleWriteClientImpl = new ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1); + shuffleWriteClientImpl = new ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1, + 1, 1, 1); } @After @@ -116,24 +117,25 @@ public void rpcFailTest() throws Exception { for (Long blockId : result.getSuccessBlockIds()) { succBlockIdBitmap.addLong(blockId); } - assertEquals(blockIdBitmap, failedBlockIdBitmap); + // There will no failed blocks when replica=2 + assertEquals(failedBlockIdBitmap.getLongCardinality(), 0); assertEquals(blockIdBitmap, succBlockIdBitmap); boolean commitResult = shuffleWriteClientImpl.sendCommit(Sets.newHashSet( shuffleServerInfo1, fakeShuffleServerInfo), testAppId, 0, 2); assertFalse(commitResult); + // Report will success when replica=2 Map> ptb = Maps.newHashMap(); - ptb.put(1, Lists.newArrayList(1L)); - try { - Map> partitionToServers = Maps.newHashMap(); - partitionToServers.put(1, Lists.newArrayList( - shuffleServerInfo1, fakeShuffleServerInfo)); - shuffleWriteClientImpl.reportShuffleResult(partitionToServers, testAppId, 0, 0, ptb, 2); - fail(EXPECTED_EXCEPTION_MESSAGE); - } catch (Exception e) { - assertTrue(e.getMessage().contains("Report shuffle result is failed for")); - } + ptb.put(0, Lists.newArrayList(blockIdBitmap.stream().iterator())); + Map> partitionToServers = Maps.newHashMap(); + partitionToServers.put(0, Lists.newArrayList( + shuffleServerInfo1, fakeShuffleServerInfo)); + shuffleWriteClientImpl.reportShuffleResult(partitionToServers, testAppId, 0, 0, ptb, 2); + Roaring64NavigableMap report = shuffleWriteClientImpl.getShuffleResult("GRPC", + Sets.newHashSet(shuffleServerInfo1, fakeShuffleServerInfo), + testAppId, 0, 0); + assertEquals(blockIdBitmap, report); } @Test diff --git a/internal-client/src/main/java/com/tencent/rss/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/com/tencent/rss/client/impl/grpc/ShuffleServerGrpcClient.java index a6fcd60e..46857630 100644 --- a/internal-client/src/main/java/com/tencent/rss/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/com/tencent/rss/client/impl/grpc/ShuffleServerGrpcClient.java @@ -23,6 +23,7 @@ import java.util.Random; import java.util.concurrent.TimeUnit; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; import com.google.protobuf.ByteString; import org.slf4j.Logger; @@ -90,6 +91,7 @@ public class ShuffleServerGrpcClient extends GrpcClient implements ShuffleServer private static final Logger LOG = LoggerFactory.getLogger(ShuffleServerGrpcClient.class); private static final long FAILED_REQUIRE_ID = -1; private static final long RPC_TIMEOUT_DEFAULT_MS = 60000; + private long rpcTimeout = RPC_TIMEOUT_DEFAULT_MS; private ShuffleServerBlockingStub blockingStub; public ShuffleServerGrpcClient(String host, int port) { @@ -267,7 +269,7 @@ private SendShuffleDataResponse doSendData(SendShuffleDataRequest rpcRequest) { while (retryNum < maxRetryAttempts) { try { SendShuffleDataResponse response = blockingStub.withDeadlineAfter( - RPC_TIMEOUT_DEFAULT_MS, TimeUnit.MILLISECONDS).sendShuffleData(rpcRequest); + rpcTimeout, TimeUnit.MILLISECONDS).sendShuffleData(rpcRequest); return response; } catch (Exception e) { retryNum++; @@ -373,7 +375,7 @@ private ReportShuffleResultResponse doReportShuffleResult(ReportShuffleResultReq while (retryNum < maxRetryAttempts) { try { ReportShuffleResultResponse response = blockingStub.withDeadlineAfter( - RPC_TIMEOUT_DEFAULT_MS, TimeUnit.MILLISECONDS).reportShuffleResult(rpcRequest); + rpcTimeout, TimeUnit.MILLISECONDS).reportShuffleResult(rpcRequest); return response; } catch (Exception e) { retryNum++; @@ -392,7 +394,9 @@ public RssGetShuffleResultResponse getShuffleResult(RssGetShuffleResultRequest r .setShuffleId(request.getShuffleId()) .setPartitionId(request.getPartitionId()) .build(); - GetShuffleResultResponse rpcResponse = blockingStub.getShuffleResult(rpcRequest); + GetShuffleResultResponse rpcResponse = blockingStub + .withDeadlineAfter(rpcTimeout, TimeUnit.MILLISECONDS) + .getShuffleResult(rpcRequest); StatusCode statusCode = rpcResponse.getStatus(); RssGetShuffleResultResponse response; @@ -547,4 +551,9 @@ private List toBufferSegments(List block } return ret; } + + @VisibleForTesting + public void adjustTimeout(long timeout) { + rpcTimeout = timeout; + } } diff --git a/server/src/main/java/com/tencent/rss/server/ShuffleServer.java b/server/src/main/java/com/tencent/rss/server/ShuffleServer.java index 3a6a6ffd..384c2f73 100644 --- a/server/src/main/java/com/tencent/rss/server/ShuffleServer.java +++ b/server/src/main/java/com/tencent/rss/server/ShuffleServer.java @@ -22,6 +22,7 @@ import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; import com.google.common.collect.Sets; import io.prometheus.client.CollectorRegistry; @@ -155,8 +156,7 @@ private void initialization() throws Exception { shuffleTaskManager = new ShuffleTaskManager(shuffleServerConf, shuffleFlushManager, shuffleBufferManager, storageManager); - ShuffleServerFactory shuffleServerFactory = new ShuffleServerFactory(this); - server = shuffleServerFactory.getServer(); + setServer(); // it's the system tag for server's version tags.add(Constants.SHUFFLE_SERVER_VERSION); @@ -220,6 +220,12 @@ public ServerInterface getServer() { return server; } + @VisibleForTesting + public void setServer() { + ShuffleServerFactory shuffleServerFactory = new ShuffleServerFactory(this); + server = shuffleServerFactory.getServer(); + } + public void setServer(ServerInterface server) { this.server = server; } diff --git a/server/src/main/java/com/tencent/rss/server/ShuffleServerFactory.java b/server/src/main/java/com/tencent/rss/server/ShuffleServerFactory.java index 2e4a127d..f3a5fefe 100644 --- a/server/src/main/java/com/tencent/rss/server/ShuffleServerFactory.java +++ b/server/src/main/java/com/tencent/rss/server/ShuffleServerFactory.java @@ -41,6 +41,14 @@ public ServerInterface getServer() { } } + public ShuffleServer getShuffleServer() { + return shuffleServer; + } + + public ShuffleServerConf getConf() { + return conf; + } + enum ServerType { GRPC } diff --git a/server/src/test/java/com/tencent/rss/server/MockedGrpcServer.java b/server/src/test/java/com/tencent/rss/server/MockedGrpcServer.java new file mode 100644 index 00000000..bc0fac94 --- /dev/null +++ b/server/src/test/java/com/tencent/rss/server/MockedGrpcServer.java @@ -0,0 +1,35 @@ +/* + * Tencent is pleased to support the open source community by making + * Firestorm-Spark remote shuffle server available. + * + * Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * https://opensource.org/licenses/Apache-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.tencent.rss.server; + +import com.tencent.rss.common.config.RssBaseConf; +import com.tencent.rss.common.metrics.GRPCMetrics; +import com.tencent.rss.common.rpc.GrpcServer; + +public class MockedGrpcServer extends GrpcServer { + MockedShuffleServerGrpcService service; + public MockedGrpcServer(RssBaseConf conf, MockedShuffleServerGrpcService service, + GRPCMetrics grpcMetrics) { + super(conf, service, grpcMetrics); + this.service = service; + } + public MockedShuffleServerGrpcService getService() { + return service; + } +} diff --git a/server/src/test/java/com/tencent/rss/server/MockedShuffleServer.java b/server/src/test/java/com/tencent/rss/server/MockedShuffleServer.java new file mode 100644 index 00000000..646f3c01 --- /dev/null +++ b/server/src/test/java/com/tencent/rss/server/MockedShuffleServer.java @@ -0,0 +1,30 @@ +/* + * Tencent is pleased to support the open source community by making + * Firestorm-Spark remote shuffle server available. + * + * Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * https://opensource.org/licenses/Apache-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.tencent.rss.server; + +public class MockedShuffleServer extends ShuffleServer { + public MockedShuffleServer(ShuffleServerConf shuffleServerConf) throws Exception { + super(shuffleServerConf); + } + + @Override + public void setServer() { + setServer(new MockedShuffleServerFactory(this).getServer()); + } +} diff --git a/server/src/test/java/com/tencent/rss/server/MockedShuffleServerFactory.java b/server/src/test/java/com/tencent/rss/server/MockedShuffleServerFactory.java new file mode 100644 index 00000000..a61cd9fd --- /dev/null +++ b/server/src/test/java/com/tencent/rss/server/MockedShuffleServerFactory.java @@ -0,0 +1,45 @@ +/* + * Tencent is pleased to support the open source community by making + * Firestorm-Spark remote shuffle server available. + * + * Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * https://opensource.org/licenses/Apache-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.tencent.rss.server; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.tencent.rss.common.rpc.ServerInterface; + +public class MockedShuffleServerFactory extends ShuffleServerFactory { + private static final Logger LOG = LoggerFactory.getLogger(MockedShuffleServerFactory.class); + public MockedShuffleServerFactory(MockedShuffleServer shuffleServer) { + super(shuffleServer); + } + + @Override + public ServerInterface getServer() { + ShuffleServerConf conf = getConf(); + ShuffleServer shuffleServer = getShuffleServer(); + String type = conf.getString(ShuffleServerConf.RPC_SERVER_TYPE); + if (type.equals(ServerType.GRPC.name())) { + return new MockedGrpcServer(conf, new MockedShuffleServerGrpcService(shuffleServer), + shuffleServer.getGrpcMetrics()); + } else { + throw new UnsupportedOperationException("Unsupported server type " + type); + } + } + +} diff --git a/server/src/test/java/com/tencent/rss/server/MockedShuffleServerGrpcService.java b/server/src/test/java/com/tencent/rss/server/MockedShuffleServerGrpcService.java new file mode 100644 index 00000000..dc51b444 --- /dev/null +++ b/server/src/test/java/com/tencent/rss/server/MockedShuffleServerGrpcService.java @@ -0,0 +1,77 @@ +/* + * Tencent is pleased to support the open source community by making + * Firestorm-Spark remote shuffle server available. + * + * Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * https://opensource.org/licenses/Apache-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package com.tencent.rss.server; + +import com.google.common.util.concurrent.Uninterruptibles; +import io.grpc.stub.StreamObserver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.tencent.rss.proto.RssProtos; + +import java.util.concurrent.TimeUnit; + +public class MockedShuffleServerGrpcService extends ShuffleServerGrpcService { + + private static final Logger LOG = LoggerFactory.getLogger(MockedShuffleServerGrpcService.class); + + private long mockedTimeout = -1L; + + public void enableMockedTimeout(long timeout) { + mockedTimeout = timeout; + } + + public void disableMockedTimeout() { + mockedTimeout = -1; + } + + public MockedShuffleServerGrpcService(ShuffleServer shuffleServer) { + super(shuffleServer); + } + + @Override + public void sendShuffleData(RssProtos.SendShuffleDataRequest request, + StreamObserver responseObserver) { + if (mockedTimeout > 0) { + LOG.info("Add a mocked timeout on sendShuffleData"); + Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS); + } + super.sendShuffleData(request, responseObserver); + } + + @Override + public void reportShuffleResult(RssProtos.ReportShuffleResultRequest request, + StreamObserver responseObserver) { + if (mockedTimeout > 0) { + LOG.info("Add a mocked timeout on reportShuffleResult"); + Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS); + } + super.reportShuffleResult(request, responseObserver); + } + + @Override + public void getShuffleResult(RssProtos.GetShuffleResultRequest request, + StreamObserver responseObserver) { + if (mockedTimeout > 0) { + LOG.info("Add a mocked timeout on getShuffleResult"); + Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS); + } + super.getShuffleResult(request, responseObserver); + } +}