Inside Spark Technology: detailed explanation of Shuffle

Posted by rbrown on Wed, 08 Sep 2021 04:50:23 +0200

Next, we look at the implementation in more detail.

Shuffle is undoubtedly a key focus of performance tuning. This article analyzes the implementation of Spark Shuffle in depth, from the perspective of its source code.

The upper boundary of each Stage either reads data from external storage or reads the output of the previous Stage. At the lower boundary, the Stage either writes its output to the local file system for a child Stage to read, or, in the case of a ResultTask, outputs the final results.

Start with org.apache.spark.rdd.ShuffledRDD. Because ShuffledRDD is the beginning of a Stage, it must obtain the output of the previous Stage before it can perform the next operation. So how is this data fetched? By following the implementation of ShuffledRDD, we can trace this path. First, let's look at how compute is implemented:

  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }
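To make the partition range in getReader(dep.shuffleHandle, split.index, split.index + 1, context) concrete, here is a toy model in plain Scala, not Spark's API: each map task writes one bucket per reduce partition, and a reader for [startPartition, endPartition) concatenates the matching buckets from every map output. ToyShuffleOutput, ToyShuffleReader and computePartition are hypothetical names; the real reader fetches these buckets over the network, as the rest of this article describes.

```scala
// Hypothetical toy model (not Spark's API): each map task's output is split into
// one bucket per reduce partition.
type Bucket = Seq[(String, Int)]

// mapOutputs(mapId)(reduceId) holds the records map task `mapId` wrote for `reduceId`.
final case class ToyShuffleOutput(mapOutputs: IndexedSeq[IndexedSeq[Bucket]])

final class ToyShuffleReader(out: ToyShuffleOutput, startPartition: Int, endPartition: Int) {
  // Lazily concatenate, for every map output, the buckets in [startPartition, endPartition).
  def read(): Iterator[(String, Int)] =
    out.mapOutputs.iterator.flatMap { buckets =>
      (startPartition until endPartition).iterator.flatMap(r => buckets(r))
    }
}

// Mirrors compute(split, ...): one reduce partition reads the range [index, index + 1).
def computePartition(out: ToyShuffleOutput, index: Int): Iterator[(String, Int)] =
  new ToyShuffleReader(out, index, index + 1).read()
```

With two map outputs, for example, reduce partition 0 sees exactly the records that both map tasks wrote to bucket 0, in map-task order.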

ShuffledRDD.compute obtains a ShuffleReader from the ShuffleManager and then reads the data for computation. Here is how the ShuffleManager is selected:

    // Let the user specify short names for shuffle managers
    val shortShuffleMgrNames = Map(
      "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
      "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
    val shuffleMgrName = conf.get("spark.shuffle.manager", "hash")
    val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
    val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
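The lookup above boils down to a small pure function. This sketch replaces SparkConf with a plain Map (an assumption for illustration) and only resolves the class name, skipping the reflective instantiateClass step: short names are matched case-insensitively, and any other value is passed through unchanged as a fully qualified class name. resolveShuffleManagerClass is a hypothetical name.

```scala
// Short names map to built-in implementations, as in the snippet above.
val shortShuffleMgrNames = Map(
  "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
  "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")

// `conf` stands in for SparkConf (a plain Map here; an assumption for this sketch).
def resolveShuffleManagerClass(conf: Map[String, String]): String = {
  val name = conf.getOrElse("spark.shuffle.manager", "hash") // "hash" is the default shown above
  // Case-insensitive match on short names; otherwise keep the original class name as given.
  shortShuffleMgrNames.getOrElse(name.toLowerCase, name)
}
```

In a real application the same key would be set on the configuration, for example new SparkConf().set("spark.shuffle.manager", "sort").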

There are two kinds of ShuffleManager: hash and sort. Hash is the default, which means the Shuffle does not sort its data. Readers familiar with MapReduce know that MapReduce always sorts: the data that reaches the Reduce side is already sorted. This design also helps MapReduce cope with very large volumes of data. Before Spark 1.1, only hash-based Shuffle was supported; sort-based Shuffle was added in 1.1 as an experimental feature.
Because the data at the Reduce side is ordered, the Reduce side can process it as soon as it arrives instead of waiting for all of the data; this will be explained with the source code. As the name implies, sort-based Shuffle sorts the data, which is arguably most useful for the sortByKey transformation.
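The point about ordered data can be illustrated without Spark. The sketch below (plain Scala, not Spark's or Hadoop's code) aggregates key-sorted input by emitting each key's total as soon as the key changes, so results flow out while input is still arriving; the hash-map version, shown for comparison, cannot emit anything until it has consumed every record. sumSortedRuns and sumWithHashMap are hypothetical names.

```scala
// Streaming aggregation over key-sorted input: emit (key, sum) as soon as the key changes.
def sumSortedRuns[K](sorted: Iterator[(K, Int)]): Iterator[(K, Int)] = new Iterator[(K, Int)] {
  private val in = sorted.buffered // lets us peek at the next key without consuming it
  def hasNext: Boolean = in.hasNext
  def next(): (K, Int) = {
    val (key, first) = in.next()
    var sum = first
    // Consume the rest of this key's run; stop at the first different key.
    while (in.hasNext && in.head._1 == key) sum += in.next()._2
    (key, sum)
  }
}

// Hash aggregation for comparison: the whole input must be read before any result is known.
def sumWithHashMap[K](records: Iterator[(K, Int)]): Map[K, Int] = {
  val acc = scala.collection.mutable.HashMap.empty[K, Int]
  records.foreach { case (k, v) => acc(k) = acc.getOrElse(k, 0) + v }
  acc.toMap
}
```

The trade-off is that the sort has to happen somewhere upstream; the hash-based approach skips that cost but needs memory proportional to the number of distinct keys.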

ShuffledRDD obtains the results of the previous Stage through org.apache.spark.shuffle.hash.HashShuffleReader, which in turn fetches them through org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher$#fetch. The fetch method forwards the request by calling getMultiple:

  def getMultiple(
      blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
      serializer: Serializer,
      readMetrics: ShuffleReadMetrics): BlockFetcherIterator = {
    val iter = new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer,
      readMetrics)
    // initialize() splits local/remote blocks and sends the first batch of fetch requests
    iter.initialize()
    iter
  }
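The blocksByAddress argument pairs each location with the (blockId, size) pairs stored there. The sketch below shows one way such a list could be assembled for a single reduce partition, assuming each map task reports its output location and one size per reduce partition. Location and MapOutput are hypothetical stand-ins for BlockManagerId and Spark's map output status, and block IDs are plain strings in the "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId format. Zero-size blocks are kept here; filtering them out is left to splitLocalRemoteBlocks.

```scala
// Hypothetical stand-ins for BlockManagerId and a map task's output status.
final case class Location(executorId: String, host: String, port: Int)
final case class MapOutput(location: Location, sizes: IndexedSeq[Long]) // sizes(reduceId)

// Block names follow the "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId format.
def shuffleBlockName(shuffleId: Int, mapId: Int, reduceId: Int): String =
  s"shuffle_${shuffleId}_${mapId}_${reduceId}"

// For one reduce partition, group (blockName, size) pairs by the location that holds them,
// giving the same shape as blocksByAddress: Seq[(location, Seq[(blockId, size)])].
def blocksByAddress(shuffleId: Int, reduceId: Int,
                    statuses: IndexedSeq[MapOutput]): Seq[(Location, Seq[(String, Long)])] =
  statuses.zipWithIndex                                 // mapId is the index of the map task
    .groupBy { case (status, _) => status.location }
    .toSeq
    .map { case (loc, entries) =>
      loc -> entries.map { case (status, mapId) =>
        (shuffleBlockName(shuffleId, mapId, reduceId), status.sizes(reduceId))
      }.toSeq
    }
```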

The actual work is done in BasicBlockFetcherIterator#initialize:

    override def initialize() {
      // Split local and remote blocks.
      // splitLocalRemoteBlocks() returns the list of remote fetch requests and puts the IDs of
      // local blocks into localBlocksToFetch; those are read locally by getLocalBlocks() below.
      val remoteRequests = splitLocalRemoteBlocks()
      // Add the remote requests into our queue in a random order
      fetchRequests ++= Utils.randomize(remoteRequests)
      // Send out initial requests for blocks, up to our maxBytesInFlight.
      // This keeps in-flight data under spark.reducer.maxMbInFlight (default 48 MB).
      while (!fetchRequests.isEmpty &&
        (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
        sendRequest(fetchRequests.dequeue())
      }
      val numFetches = remoteRequests.size - fetchRequests.size
      logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
      // Get Local Blocks
      startTime = System.currentTimeMillis
      getLocalBlocks() // Read the local blocks
      logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
    }
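The while loop in initialize() is a simple form of flow control. Here it is isolated in plain Scala: Request is a hypothetical type carrying only a size, the random ordering from Utils.randomize is omitted so the result is deterministic, and we assume (as the loop condition implies) that sending a request adds its size to bytesInFlight.

```scala
import scala.collection.mutable

// Hypothetical request type: only the size matters for throttling.
final case class Request(id: Int, size: Long)

// Mirrors the loop in initialize(): keep dequeuing while nothing is in flight yet, or while the
// next request still fits under maxBytesInFlight. Returns the requests "sent" up front.
def initialBatch(queue: mutable.Queue[Request], maxBytesInFlight: Long): Seq[Request] = {
  var bytesInFlight = 0L
  val sent = mutable.ArrayBuffer.empty[Request]
  while (queue.nonEmpty &&
    (bytesInFlight == 0 || bytesInFlight + queue.front.size <= maxBytesInFlight)) {
    val req = queue.dequeue()
    bytesInFlight += req.size // assumption: sending a request counts its bytes as in flight
    sent += req
  }
  sent.toSeq
}
```

With ten-byte requests and a 48-byte limit, four requests go out up front and the rest wait in the queue. A single request larger than the limit is still sent when nothing is in flight, which is what the bytesInFlight == 0 clause guarantees.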

The policy for splitting blocks into local reads and remote fetch requests is implemented in #splitLocalRemoteBlocks:

    protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
      // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
      // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
      // nodes, rather than blocking on reading output from one node.
      // To fetch data quickly, requests to up to 5 nodes can be in flight at the same time;
      // each request asks for at most spark.reducer.maxMbInFlight (default 48 MB) / 5.
      // There are several reasons for this:
      // 1. It avoids hogging the bandwidth of the target machine. With today's mainstream
      //    gigabit NICs, bandwidth is still scarce; if one connection took the full 48 MB,
      //    network IO could become the bottleneck.
      // 2. Requests run in parallel, which greatly reduces the total fetch time: it is roughly
      //    the time of the longest request instead of the sum of all request times.
      // spark.reducer.maxMbInFlight also limits how much memory the fetched data can occupy.
      val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
      logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
      // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
      // at most maxBytesInFlight in order to limit the amount of data in flight.
      val remoteRequests = new ArrayBuffer[FetchRequest]
      var totalBlocks = 0
      for ((address, blockInfos) <- blocksByAddress) { // address is a BlockManagerId (executor ID, host, port)
        totalBlocks += blockInfos.size
        if (address == blockManagerId) { // The data is local: read it directly on this node
          // Filter out zero-sized blocks
          localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
          _numBlocksToFetch += localBlocksToFetch.size
        } else {
          val iterator = blockInfos.iterator
          var curRequestSize = 0L
          var curBlocks = new ArrayBuffer[(BlockId, Long)]
          while (iterator.hasNext) {
            // blockId format: "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
            val (blockId, size) = iterator.next()
            // Skip empty blocks
            if (size > 0) {
              curBlocks += ((blockId, size))
              remoteBlocksToFetch += blockId
              _numBlocksToFetch += 1
              curRequestSize += size
            } else if (size < 0) {
              throw new BlockException(blockId, "Negative block size " + size)
            }
            if (curRequestSize >= targetRequestSize) { // Keep a single request from growing too large
              // Add this FetchRequest
              remoteRequests += new FetchRequest(address, curBlocks)
              curBlocks = new ArrayBuffer[(BlockId, Long)]
              logDebug(s"Creating fetch request of $curRequestSize at $address")
              curRequestSize = 0
            }
          }
          // Add in the final request: the remaining blocks go into one last request
          if (!curBlocks.isEmpty) {
            remoteRequests += new FetchRequest(address, curBlocks)
          }
        }
      }
      logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " +
        totalBlocks + " blocks")
      remoteRequests
    }
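With the default spark.reducer.maxMbInFlight of 48 MB, targetRequestSize is 48 × 1024 × 1024 / 5 = 10066329 bytes (about 9.6 MB). The batching rule for one remote address can be reproduced in plain Scala: splitIntoRequests is a hypothetical helper, blocks are (name, size) pairs, and IllegalArgumentException stands in for Spark's BlockException.

```scala
import scala.collection.mutable.ArrayBuffer

// Per-address batching as in splitLocalRemoteBlocks (plain Scala, no Spark types):
// a batch is closed as soon as its total size reaches targetRequestSize.
def splitIntoRequests(blocks: Seq[(String, Long)], maxBytesInFlight: Long): Seq[Seq[(String, Long)]] = {
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) // room for ~5 parallel fetches
  val requests = ArrayBuffer.empty[Seq[(String, Long)]]
  var cur = ArrayBuffer.empty[(String, Long)]
  var curSize = 0L
  for ((id, size) <- blocks) {
    if (size > 0) { cur += ((id, size)); curSize += size } // empty blocks are skipped
    else if (size < 0) throw new IllegalArgumentException(s"Negative block size $size for $id")
    if (curSize >= targetRequestSize) { // close the current request
      requests += cur.toSeq
      cur = ArrayBuffer.empty[(String, Long)]
      curSize = 0L
    }
  }
  if (cur.nonEmpty) requests += cur.toSeq // leftover blocks form the final request
  requests.toSeq
}
```

For example, with maxBytesInFlight = 50 the target is 10, so blocks of sizes 4, 0, 7, 3, 12 and 5 become the requests [4, 7], [3, 12] and [5]. The empty block is skipped, and because a request is closed only once it reaches the target, a request can exceed the target by up to one block.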
