
Commit

fix ut and refactor code
wangshengjie123 committed Mar 16, 2024
1 parent 89254fc commit 599be24
Showing 17 changed files with 130 additions and 101 deletions.

@@ -141,33 +141,33 @@ protected Compressor initialValue() {

protected static class ReduceFileGroups {
public Map<Integer, Set<PartitionLocation>> partitionGroups;
public Set<PushFailedBatch> pushFailedBatchSet;
public Map<String, Set<PushFailedBatch>> pushFailedBatches;
public int[] mapAttempts;
public Set<Integer> partitionIds;

ReduceFileGroups(
Map<Integer, Set<PartitionLocation>> partitionGroups,
int[] mapAttempts,
Set<Integer> partitionIds,
Set<PushFailedBatch> pushFailedBatches) {
Map<String, Set<PushFailedBatch>> pushFailedBatches) {
this.partitionGroups = partitionGroups;
this.mapAttempts = mapAttempts;
this.partitionIds = partitionIds;
this.pushFailedBatchSet = pushFailedBatches;
this.pushFailedBatches = pushFailedBatches;
}

public ReduceFileGroups() {
this.partitionGroups = null;
this.mapAttempts = null;
this.partitionIds = null;
this.pushFailedBatchSet = null;
this.pushFailedBatches = null;
}

public void update(ReduceFileGroups fileGroups) {
partitionGroups = fileGroups.partitionGroups;
mapAttempts = fileGroups.mapAttempts;
partitionIds = fileGroups.partitionIds;
pushFailedBatchSet = fileGroups.pushFailedBatchSet;
pushFailedBatches = fileGroups.pushFailedBatches;
}
}
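
Editor's note: the refactor above replaces the flat pushFailedBatchSet with a map keyed by the partition location's unique id, so failed batches recorded against different splits of the same reduce partition no longer share one set. Below is a minimal, hedged sketch of the new shape; the PushFailedBatch record is a simplified stand-in whose fields are inferred from the constructor calls in this diff, and the "7-0" uniqueId string is only an example of a location id, not a guaranteed format.

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// Simplified stand-in for Celeborn's PushFailedBatch; fields are inferred from
// the constructor calls in this commit (mapId, attemptId, batchId, epoch).
record PushFailedBatch(int mapId, int attemptId, int batchId, int epoch) {}

class FailedBatchBookkeepingSketch {
  // partition location uniqueId -> batches that failed to push to that location
  private final Map<String, Set<PushFailedBatch>> pushFailedBatches = new HashMap<>();

  void addFailedBatch(String locationUniqueId, PushFailedBatch batch) {
    pushFailedBatches.computeIfAbsent(locationUniqueId, k -> new HashSet<>()).add(batch);
  }

  boolean isFailed(String locationUniqueId, PushFailedBatch batch) {
    Set<PushFailedBatch> batches = pushFailedBatches.get(locationUniqueId);
    return batches != null && batches.contains(batch);
  }

  public static void main(String[] args) {
    FailedBatchBookkeepingSketch state = new FailedBatchBookkeepingSketch();
    state.addFailedBatch("7-0", new PushFailedBatch(3, 0, 42, 0));
    System.out.println(state.isFailed("7-0", new PushFailedBatch(3, 0, 42, 0))); // true
    System.out.println(state.isFailed("7-1", new PushFailedBatch(3, 0, 42, 0))); // false
  }
}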

@@ -1041,8 +1041,8 @@ public void onSuccess(ByteBuffer response) {
nextBatchId);
if (dataPushFailureTrackingEnabled) {
pushState.addFailedBatch(
new PushFailedBatch(
mapId, attemptId, nextBatchId, partitionId, latest.getEpoch()));
latest.getUniqueId(),
new PushFailedBatch(mapId, attemptId, nextBatchId, latest.getEpoch()));
}
ReviveRequest reviveRequest =
new ReviveRequest(
@@ -1112,8 +1112,8 @@ public void onSuccess(ByteBuffer response) {
public void onFailure(Throwable e) {
if (dataPushFailureTrackingEnabled) {
pushState.addFailedBatch(
new PushFailedBatch(
mapId, attemptId, nextBatchId, partitionId, latest.getEpoch()));
latest.getUniqueId(),
new PushFailedBatch(mapId, attemptId, nextBatchId, latest.getEpoch()));
}
if (pushState.exception.get() != null) {
return;
@@ -1417,8 +1417,8 @@ public void onSuccess(ByteBuffer response) {
if (dataPushFailureTrackingEnabled) {
for (int i = 0; i < numBatches; i++) {
pushState.addFailedBatch(
new PushFailedBatch(
mapId, attemptId, batchIds[i], partitionIds[i], epochs[i]));
partitionUniqueIds[i],
new PushFailedBatch(mapId, attemptId, batchIds[i], epochs[i]));
}
}
ReviveRequest[] requests =
@@ -1481,7 +1481,8 @@ public void onFailure(Throwable e) {
if (dataPushFailureTrackingEnabled) {
for (int i = 0; i < numBatches; i++) {
pushState.addFailedBatch(
new PushFailedBatch(mapId, attemptId, batchIds[i], partitionIds[i], epochs[i]));
partitionUniqueIds[i],
new PushFailedBatch(mapId, attemptId, batchIds[i], epochs[i]));
}
}
if (pushState.exception.get() != null) {
@@ -1750,7 +1751,7 @@ public CelebornInputStream readPartition(
shuffleKey,
fileGroups.partitionGroups.get(partitionId).toArray(new PartitionLocation[0]),
fileGroups.mapAttempts,
fileGroups.pushFailedBatchSet,
fileGroups.pushFailedBatches,
attemptNumber,
startMapIndex,
endMapIndex,

@@ -55,7 +55,7 @@ public static CelebornInputStream create(
String shuffleKey,
PartitionLocation[] locations,
int[] attempts,
Set<PushFailedBatch> failedBatchSet,
Map<String, Set<PushFailedBatch>> failedBatchSet,
int attemptNumber,
int startMapIndex,
int endMapIndex,
@@ -73,8 +73,11 @@ public static CelebornInputStream create(
// if startMapIndex > endMapIndex, the partition is a skew partition.
// locations will be split into sub-partitions of startMapIndex size.
PartitionLocation[] filterLocations = locations;
if (conf.clientPushFailureTrackingEnabled() && startMapIndex > endMapIndex) {
filterLocations = getSkewPartitionLocations(locations, startMapIndex, endMapIndex);
boolean splitSkewPartitionWithoutMapRange =
conf.clientPushFailureTrackingEnabled() && startMapIndex > endMapIndex;
if (splitSkewPartitionWithoutMapRange) {
filterLocations = getSubSkewPartitionLocations(locations, startMapIndex, endMapIndex);
endMapIndex = Integer.MAX_VALUE;
}
return new CelebornInputStreamImpl(
conf,
@@ -92,11 +95,12 @@ public static CelebornInputStream create(
shuffleId,
partitionId,
exceptionMaker,
splitSkewPartitionWithoutMapRange,
metricsCallback);
}
}
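
Editor's note: in create() the (startMapIndex, endMapIndex) pair is overloaded. When failure tracking is enabled and startMapIndex > endMapIndex, the pair is interpreted as (subPartitionSize, subPartitionIndex) for a skew split, the locations are pre-filtered, and endMapIndex is then widened to Integer.MAX_VALUE so no map-range filtering is applied while reading. A small illustrative sketch of that gating follows; the helper class and method names are hypothetical, not part of CelebornInputStream.

// Illustrative only: mirrors the gating added to create().
final class SkewReadDecision {
  final boolean splitWithoutMapRange;
  final int effectiveEndMapIndex;

  SkewReadDecision(boolean splitWithoutMapRange, int effectiveEndMapIndex) {
    this.splitWithoutMapRange = splitWithoutMapRange;
    this.effectiveEndMapIndex = effectiveEndMapIndex;
  }

  static SkewReadDecision decide(
      boolean pushFailureTrackingEnabled, int startMapIndex, int endMapIndex) {
    // startMapIndex > endMapIndex signals that the caller passed
    // (subPartitionSize, subPartitionIndex) rather than a real map range.
    boolean split = pushFailureTrackingEnabled && startMapIndex > endMapIndex;
    // When splitting by location instead of by map range, widen endMapIndex so
    // no map-range filtering is applied to the selected locations.
    return new SkewReadDecision(split, split ? Integer.MAX_VALUE : endMapIndex);
  }
}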

public static PartitionLocation[] getSkewPartitionLocations(
public static PartitionLocation[] getSubSkewPartitionLocations(
PartitionLocation[] locations, int subPartitionSize, int subPartitionIndex) {
Set<PartitionLocation> sortSet =
new TreeSet<>(
@@ -112,18 +116,18 @@ public static PartitionLocation[] getSkewPartitionLocations(
sortSet.addAll(Arrays.asList(locations));
PartitionLocation[] orderedPartitionLocations = sortSet.toArray(new PartitionLocation[0]);

List<PartitionLocation> result = new LinkedList<>();

int step = locations.length / subPartitionSize;
List<PartitionLocation> result = new ArrayList<>(step + 1);

// if partition locations are [1,2,3,4,5,6,7,8,9,10], and the skew partition is split across 3 tasks:
// task 0: 1, 6, 7
// task 1: 2, 5, 8
// task 2: 3, 4, 9, 10
for (int i = 0; i < step + 1; i++) {
if (i % 2 == 0 && (i * 3 + subPartitionIndex) < locations.length) {
if (i % 2 == 0 && (i * subPartitionSize + subPartitionIndex) < locations.length) {
result.add(orderedPartitionLocations[i * subPartitionSize + subPartitionIndex]);
} else if (((i + 1) * subPartitionSize - subPartitionIndex - 1) < locations.length) {
} else if (i % 2 == 1
&& ((i + 1) * subPartitionSize - subPartitionIndex - 1) < locations.length) {
result.add(orderedPartitionLocations[(i + 1) * subPartitionSize - subPartitionIndex - 1]);
}
}
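
Editor's note: the loop above interleaves the size-ordered locations so that each sub-task receives entries from both ends of every block of subPartitionSize locations. Below is a runnable, self-contained illustration of the same index arithmetic; plain integer indices stand in for PartitionLocation, and the class and method names are illustrative only.

import java.util.ArrayList;
import java.util.List;

// Standalone illustration of the zig-zag assignment used by
// getSubSkewPartitionLocations, operating on indices instead of locations.
public class SubSkewPartitionAssignment {

  static List<Integer> pickIndices(int locationCount, int subPartitionSize, int subPartitionIndex) {
    int step = locationCount / subPartitionSize;
    List<Integer> result = new ArrayList<>(step + 1);
    for (int i = 0; i < step + 1; i++) {
      if (i % 2 == 0 && (i * subPartitionSize + subPartitionIndex) < locationCount) {
        // even rounds walk forward through each block of subPartitionSize locations
        result.add(i * subPartitionSize + subPartitionIndex);
      } else if (i % 2 == 1
          && ((i + 1) * subPartitionSize - subPartitionIndex - 1) < locationCount) {
        // odd rounds walk backward, so each task gets a mix of small and large locations
        result.add((i + 1) * subPartitionSize - subPartitionIndex - 1);
      }
    }
    return result;
  }

  public static void main(String[] args) {
    // 10 sorted locations split across 3 sub-tasks prints:
    // task 0: [0, 5, 6], task 1: [1, 4, 7], task 2: [2, 3, 8, 9]
    // i.e. locations 1,6,7 / 2,5,8 / 3,4,9,10 in the 1-based comment above.
    for (int task = 0; task < 3; task++) {
      System.out.println("task " + task + ": " + pickIndices(10, 3, task));
    }
  }
}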
@@ -176,7 +180,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {

private Map<Integer, Set<Integer>> batchesRead = new HashMap<>();

private final Set<PushFailedBatch> failedBatches;
private final Map<String, Set<PushFailedBatch>> failedBatches;

private byte[] compressedBuf;
private byte[] rawDataBuf;
@@ -216,15 +220,15 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
private ExceptionMaker exceptionMaker;
private boolean closed = false;

private final boolean pushShuffleFailureTrackingEnabled;
private final boolean splitSkewPartitionWithoutMapRange;

CelebornInputStreamImpl(
CelebornConf conf,
TransportClientFactory clientFactory,
String shuffleKey,
PartitionLocation[] locations,
int[] attempts,
Set<PushFailedBatch> failedBatchSet,
Map<String, Set<PushFailedBatch>> failedBatchSet,
int attemptNumber,
int startMapIndex,
int endMapIndex,
@@ -234,6 +238,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
int shuffleId,
int partitionId,
ExceptionMaker exceptionMaker,
boolean splitSkewPartitionWithoutMapRange,
MetricsCallback metricsCallback)
throws IOException {
this.conf = conf;
@@ -253,7 +258,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
!conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout();
this.failedBatches = failedBatchSet;
this.pushShuffleFailureTrackingEnabled = conf.clientPushFailureTrackingEnabled();
this.splitSkewPartitionWithoutMapRange = splitSkewPartitionWithoutMapRange;
this.fetchExcludedWorkers = fetchExcludedWorkers;

if (conf.clientPushReplicateEnabled()) {
@@ -299,7 +304,9 @@ private PartitionLocation nextReadableLocation() {
return null;
}
PartitionLocation currentLocation = locations[fileIndex];
while (skipLocation(startMapIndex, endMapIndex, currentLocation)) {
// if splitSkewPartitionWithoutMapRange is true, should not skip location
while (!splitSkewPartitionWithoutMapRange
&& skipLocation(startMapIndex, endMapIndex, currentLocation)) {
skipCount.increment();
fileIndex++;
if (fileIndex == locationCount) {
@@ -666,15 +673,13 @@ private boolean fillBuffer() throws IOException {

// de-duplicate
if (attemptId == attempts[mapId]) {
if (pushShuffleFailureTrackingEnabled) {
if (splitSkewPartitionWithoutMapRange) {
PushFailedBatch failedBatch =
new PushFailedBatch(
mapId,
attemptId,
batchId,
currentReader.getLocation().getId(),
currentReader.getLocation().getEpoch());
if (this.failedBatches.contains(failedBatch)) {
mapId, attemptId, batchId, currentReader.getLocation().getEpoch());
if (this.failedBatches
.get(currentReader.getLocation().getUniqueId())
.contains(failedBatch)) {
logger.warn("Skip duplicated batch: {}.", failedBatch);
continue;
}
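
Editor's note: on the read path, a batch from the winning attempt is dropped when the same (mapId, attemptId, batchId, epoch) was recorded as failed for the location currently being read, because that data will be served again by the location it was re-pushed to. A hedged sketch of the lookup follows; it reuses the simplified PushFailedBatch stand-in from the earlier sketch and flattens the reader state into explicit parameters.

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// Sketch only: mirrors the duplicate check in fillBuffer(); not the real reader.
public class DuplicateBatchSkipSketch {

  static boolean shouldSkipBatch(
      Map<String, Set<PushFailedBatch>> failedBatches,
      String currentLocationUniqueId,
      int mapId, int attemptId, int batchId, int epoch) {
    Set<PushFailedBatch> failed = failedBatches.get(currentLocationUniqueId);
    // The diff calls contains() on the result of get() directly; the null check
    // here just keeps the sketch safe for locations that recorded no failures.
    return failed != null
        && failed.contains(new PushFailedBatch(mapId, attemptId, batchId, epoch));
  }

  public static void main(String[] args) {
    Map<String, Set<PushFailedBatch>> failedBatches = new HashMap<>();
    failedBatches.computeIfAbsent("7-0", k -> new HashSet<>())
        .add(new PushFailedBatch(3, 0, 42, 0));
    System.out.println(shouldSkipBatch(failedBatches, "7-0", 3, 0, 42, 0)); // true
    System.out.println(shouldSkipBatch(failedBatches, "7-0", 3, 0, 43, 0)); // false
  }
}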

@@ -98,7 +98,6 @@ public DfsPartitionReader(
.setFileName(location.getFileName())
.setStartIndex(startMapIndex)
.setEndIndex(endMapIndex)
.setShuffleDataNeedSort(conf.clientPushFailureTrackingEnabled())
.build()
.toByteArray());
ByteBuffer response = client.sendRpcSync(openStream.toByteBuffer(), fetchTimeoutMs);

@@ -103,7 +103,6 @@ public LocalPartitionReader(
.setStartIndex(startMapIndex)
.setEndIndex(endMapIndex)
.setReadLocalShuffle(true)
.setShuffleDataNeedSort(conf.clientPushFailureTrackingEnabled())
.build()
.toByteArray());
ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(), fetchTimeoutMs);

@@ -124,7 +124,6 @@ public void onFailure(int chunkIndex, Throwable e) {
.setFileName(location.getFileName())
.setStartIndex(startMapIndex)
.setEndIndex(endMapIndex)
.setShuffleDataNeedSort(conf.clientPushFailureTrackingEnabled())
.build()
.toByteArray());
ByteBuffer response = client.sendRpcSync(openStreamMsg.toByteBuffer(), fetchTimeoutMs);

@@ -18,6 +18,7 @@
package org.apache.celeborn.client

import java.util
import java.util.Collections
import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, ScheduledFuture, TimeUnit}
import java.util.concurrent.atomic.{AtomicInteger, LongAdder}

@@ -210,7 +211,8 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
attemptId: Int,
numMappers: Int,
partitionId: Int = -1,
pushFailedBatches: util.Set[PushFailedBatch] = Sets.newHashSet()): (Boolean, Boolean) = {
pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap())
: (Boolean, Boolean) = {
getCommitHandler(shuffleId).finishMapperAttempt(
shuffleId,
mapId,

@@ -62,8 +62,9 @@ object LifecycleManager {
// shuffle id -> partition id -> partition locations
type ShuffleFileGroups =
ConcurrentHashMap[Int, ConcurrentHashMap[Integer, util.Set[PartitionLocation]]]
// shuffle id -> partition uniqueId -> PushFailedBatch set
type ShufflePushFailedBatches =
ConcurrentHashMap[Int, ConcurrentHashMap[Integer, util.Set[PushFailedBatch]]]
ConcurrentHashMap[Int, util.HashMap[String, util.Set[PushFailedBatch]]]
type ShuffleAllocatedWorkers =
ConcurrentHashMap[Int, ConcurrentHashMap[WorkerInfo, ShufflePartitionLocationInfo]]
type ShuffleFailedWorkers = ConcurrentHashMap[WorkerInfo, (StatusCode, Long)]
@@ -758,7 +759,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
mapId: Int,
attemptId: Int,
numMappers: Int,
pushFailedBatches: util.Set[PushFailedBatch]): Unit = {
pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]]): Unit = {

val (mapperAttemptFinishedSuccess, allMapperFinished) =
commitManager.finishMapperAttempt(

@@ -199,7 +199,7 @@ abstract class CommitHandler(
attemptId: Int,
numMappers: Int,
partitionId: Int,
pushFailedBatches: util.Set[PushFailedBatch],
pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]],
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean)

def registerShuffle(shuffleId: Int, numMappers: Int): Unit = {

@@ -180,7 +180,7 @@ class MapPartitionCommitHandler(
attemptId: Int,
numMappers: Int,
partitionId: Int,
pushFailedBatches: util.Set[PushFailedBatch],
pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]],
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = {
val inProcessingPartitionIds =
inProcessMapPartitionEndIds.computeIfAbsent(

@@ -26,6 +26,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable

import com.google.common.cache.{Cache, CacheBuilder}
import com.google.common.collect.Sets

import org.apache.celeborn.client.{ShuffleCommittedInfo, WorkerStatusTracker}
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
@@ -82,10 +83,19 @@ class ReducePartitionCommitHandler(
.maximumSize(rpcCacheSize)
.build().asInstanceOf[Cache[Int, ByteBuffer]]

val newMapFunc: function.Function[Int, ConcurrentHashMap[Integer, util.Set[PushFailedBatch]]] =
new util.function.Function[Int, ConcurrentHashMap[Integer, util.Set[PushFailedBatch]]]() {
override def apply(s: Int): ConcurrentHashMap[Integer, util.Set[PushFailedBatch]] = {
JavaUtils.newConcurrentHashMap[Integer, util.Set[PushFailedBatch]]()
private val newShuffleId2PushFailedBatchMapFunc
: function.Function[Int, util.HashMap[String, util.Set[PushFailedBatch]]] =
new util.function.Function[Int, util.HashMap[String, util.Set[PushFailedBatch]]]() {
override def apply(s: Int): util.HashMap[String, util.Set[PushFailedBatch]] = {
new util.HashMap[String, util.Set[PushFailedBatch]]()
}
}

private val uniqueId2PushFailedBatchMapFunc
: function.Function[String, util.Set[PushFailedBatch]] =
new util.function.Function[String, util.Set[PushFailedBatch]]() {
override def apply(s: String): util.Set[PushFailedBatch] = {
Sets.newHashSet[PushFailedBatch]()
}
}

@@ -243,7 +253,7 @@ class ReducePartitionCommitHandler(
attemptId: Int,
numMappers: Int,
partitionId: Int,
pushFailedBatches: util.Set[PushFailedBatch],
pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]],
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = {
shuffleMapperAttempts.synchronized {
if (getMapperAttempts(shuffleId) == null) {
@@ -257,8 +267,14 @@ class ReducePartitionCommitHandler(
if (null != pushFailedBatches && !pushFailedBatches.isEmpty) {
val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent(
shuffleId,
newMapFunc)
pushFailedBatchesMap.put(mapId, pushFailedBatches)
newShuffleId2PushFailedBatchMapFunc)
pushFailedBatches.forEach((k, v) => {
val partitionPushFailedBatches = pushFailedBatchesMap.computeIfAbsent(
k,
uniqueId2PushFailedBatchMapFunc)
partitionPushFailedBatches.addAll(v)
})
}
// Mapper with this attemptId finished, also check all other mapper finished or not.
(true, !attempts.exists(_ < 0))
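
Editor's note: on the driver side, the per-attempt map reported by each mapper is merged into a single per-shuffle view, so failures from different map tasks that hit the same location accumulate in one set, and that view is what the GetReducerFileGroup response later returns. Below is a Java rendering of the same merge logic; the real handler is Scala and performs the merge under a lock, the class and method names are illustrative, and the PushFailedBatch stand-in from the earlier sketch is assumed.

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Illustrative merge of one mapper attempt's failed batches into the
// shuffle-level view kept by the commit handler; not the real Celeborn API.
public class PushFailedBatchMergeSketch {
  // shuffleId -> (partition location uniqueId -> failed batches)
  private final ConcurrentHashMap<Integer, Map<String, Set<PushFailedBatch>>> shufflePushFailedBatches =
      new ConcurrentHashMap<>();

  void mergeAttempt(int shuffleId, Map<String, Set<PushFailedBatch>> attemptFailedBatches) {
    // The real handler runs this inside a synchronized block; locking is omitted here.
    Map<String, Set<PushFailedBatch>> perShuffle =
        shufflePushFailedBatches.computeIfAbsent(shuffleId, id -> new HashMap<>());
    attemptFailedBatches.forEach((uniqueId, batches) ->
        perShuffle.computeIfAbsent(uniqueId, k -> new HashSet<>()).addAll(batches));
  }

  Map<String, Set<PushFailedBatch>> failedBatchesFor(int shuffleId) {
    // Mirrors the getOrDefault(...) used when answering GetReducerFileGroup.
    return shufflePushFailedBatches.getOrDefault(shuffleId, new HashMap<>());
  }
}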
@@ -311,8 +327,7 @@ class ReducePartitionCommitHandler(
pushFailedBatches =
shufflePushFailedBatches.getOrDefault(
shuffleId,
JavaUtils.newConcurrentHashMap()).values().asScala.flatMap(x =>
x.asScala.toSet[PushFailedBatch]).toSet.asJava)
new util.HashMap[String, util.Set[PushFailedBatch]]()))
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
}
})