For the Shuffle Write side, see the separate Shuffle Write analysis.
This article covers the reduce side of shuffle. The first RDD of the stage downstream of a shuffle is a ShuffledRDD, and its compute method obtains an iterator over the data that the upstream stage's Shuffle Write spilled to disk files:
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
    .read()
    .asInstanceOf[Iterator[(K, C)]]
}
It obtains the shuffleManager (here a SortShuffleManager) from SparkEnv, asks the manager for a reader, and calls its read method to get the iterator.
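To put this path in context, here is a minimal, hypothetical job (not from the article; the object name and local master setting are invented) whose reduce tasks exercise exactly this code: reduceByKey introduces a ShuffleDependency, so the downstream stage starts with a ShuffledRDD and each reduce task's compute() goes through getReader().read().

// A minimal sketch of a job that exercises the shuffle read path.
import org.apache.spark.{SparkConf, SparkContext}

object ShuffleReadDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("shuffle-read-demo").setMaster("local[2]"))
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)), 2)
    // two reduce tasks, each calling BlockStoreShuffleReader.read() for its own partition
    val reduced = pairs.reduceByKey(_ + _, 2)
    reduced.collect().foreach(println)
    sc.stop()
  }
}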
override def getReader[K, C](
    handle: ShuffleHandle,
    startPartition: Int,
    endPartition: Int,
    context: TaskContext): ShuffleReader[K, C] = {
  new BlockStoreShuffleReader(
    handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}
getReader instantiates a BlockStoreShuffleReader, whose parameters include the range of partition ids this task needs to fetch. Let's look at its read method:
override def read(): Iterator[Product2[K, C]] = {
  val blockFetcherItr = new ShuffleBlockFetcherIterator(
    context,
    blockManager.shuffleClient,
    blockManager,
    // metadata describing where the shuffle data for these partitions is stored
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    // maximum size of fetched data allowed in flight from remote requests
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
    SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

  // wrap the streams for compression and encryption
  val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
    serializerManager.wrapStream(blockId, inputStream)
  }

  val serializerInstance = dep.serializer.newInstance()

  // create a key/value iterator for each stream
  val recordIter = wrappedStreams.flatMap { wrappedStream =>
    serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
  }

  // update task metrics for each record read
  val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
  // build the complete iterator
  val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
    recordIter.map { record =>
      readMetrics.incRecordsRead(1)
      record
    },
    context.taskMetrics().mergeShuffleReadMetrics())

  // An interruptible iterator must be used here in order to support task cancellation
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      // values were already combined once on the map side
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else {
      // aggregate on the reduce side only
      val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
      dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
    }
  } else {
    require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
  }

  // sort globally if a key ordering is required
  dep.keyOrdering match {
    case Some(keyOrd: Ordering[K]) =>
      val sorter =
        new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
      sorter.insertAll(aggregatedIter)
      context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
      context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
      context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
      CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
    case None =>
      aggregatedIter
  }
}
It first instantiates a ShuffleBlockFetcherIterator, one of whose arguments is:
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
This method fetches the metadata describing where the reduce-side data comes from. It returns a Seq[(BlockManagerId, Seq[(BlockId, Long)])], i.e. which blocks on which nodes hold the data and how large each block is. Let's see how getMapSizesByExecutorId is implemented:
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
    : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
  logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
  // fetch the metadata for this shuffle
  val statuses = getStatuses(shuffleId)
  // convert the format and extract the metadata for the requested partitions
  statuses.synchronized {
    return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
  }
}
It passes in the shuffleId to fetch all the metadata for that shuffle.
It then converts the format and extracts the metadata for the requested partitions.
Let's follow getStatuses:
private def getStatuses(shuffleId: Int): Array[MapStatus] = {
  // first try to read the statuses straight from mapStatuses
  val statuses = mapStatuses.get(shuffleId).orNull
  if (statuses == null) {
    logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
    val startTime = System.currentTimeMillis
    var fetchedStatuses: Array[MapStatus] = null
    ......
    if (fetchedStatuses == null) {
      // We won the race to fetch the statuses; do so
      logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
      // This try-finally prevents hangs due to timeouts:
      try {
        // fetch the metadata from the remote tracker
        val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
        // deserialize it
        fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
        logInfo("Got the output locations")
        // cache it in mapStatuses
        mapStatuses.put(shuffleId, fetchedStatuses)
      } finally {
        fetching.synchronized {
          fetching -= shuffleId
          fetching.notifyAll()
        }
      }
    }
    .....
  } else {
    return statuses
  }
}
If the statuses can be found in mapStatuses they are returned directly; otherwise a GetMapOutputStatuses message is sent to the MapOutputTrackerMaster to fetch the metadata.
Recall that each Executor corresponds to one CoarseGrainedExecutorBackend. Constructing the CoarseGrainedExecutorBackend creates a SparkEnv, and creating the SparkEnv creates a mapOutputTracker. So mapOutputTracker and Executor are in one-to-one correspondence: every Executor has its own mapOutputTracker to maintain this metadata.
The mapStatuses map here is where that mapOutputTracker stores its metadata. The metadata of every Shuffle Write completed on this Executor is kept in its mapStatuses, and metadata fetched remotely for Shuffle Writes completed on other Executors is also cached in the current mapStatuses.
On the Executor side the tracker is a MapOutputTrackerWorker, while on the Driver side it is a MapOutputTrackerMaster; both are created when SparkEnv is instantiated. The result of every shuffle task completed on an Executor is registered with the driver-side MapOutputTrackerMaster, so the master's mapStatuses holds all of the metadata. Therefore, when a task on an Executor needs the output locations of a shuffle, it first looks in its own mapStatuses, and only if nothing is found does it contact the MapOutputTrackerMaster.
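As a simplified sketch of that worker-side pattern (not Spark's actual MapOutputTrackerWorker; the class and helper below are invented), the lookup amounts to: check the executor-local cache first, otherwise ask the driver-side master over RPC and remember the answer so later tasks on the same executor skip the round trip.

import java.util.concurrent.ConcurrentHashMap

// askMaster stands in for the RPC call to the MapOutputTrackerMaster.
class WorkerSideStatuses[V <: AnyRef](askMaster: Int => V) {
  private val mapStatuses = new ConcurrentHashMap[Int, V]()

  def getStatuses(shuffleId: Int): V = {
    val local = mapStatuses.get(shuffleId)
    if (local != null) {
      local                               // hit: metadata already cached locally
    } else {
      val fetched = askMaster(shuffleId)  // miss: ask the driver-side master
      mapStatuses.put(shuffleId, fetched) // cache for later tasks on this executor
      fetched
    }
  }
}

// usage: new WorkerSideStatuses[String](id => s"statuses-of-shuffle-$id").getStatuses(0)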
Here is how the MapOutputTrackerMaster handles the message once it arrives:
case GetMapOutputStatuses(shuffleId: Int) =>
  val hostPort = context.senderAddress.hostPort
  logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
  val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))
It calls the tracker's post method:
def post(message: GetMapOutputMessage): Unit = {
  mapOutputRequests.offer(message)
}
The message is appended to mapOutputRequests, a linked blocking queue. When the MapOutputTrackerMaster is initialized, it starts a dedicated thread pool to process these requests:
private val threadpool: ThreadPoolExecutor = {
  val numThreads = conf.getInt("spark.shuffle.mapOutput.dispatcher.numThreads", 8)
  val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "map-output-dispatcher")
  for (i <- 0 until numThreads) {
    pool.execute(new MessageLoop)
  }
  pool
}
Let's see how the run method of the worker class MessageLoop is defined:
private class MessageLoop extends Runnable {
  override def run(): Unit = {
    try {
      while (true) {
        try {
          // take one GetMapOutputMessage from the queue
          val data = mapOutputRequests.take()
          if (data == PoisonPill) {
            // Put PoisonPill back so that other MessageLoops can see it.
            mapOutputRequests.offer(PoisonPill)
            return
          }
          val context = data.context
          val shuffleId = data.shuffleId
          val hostPort = context.senderAddress.hostPort
          logDebug("Handling request to send map output locations for shuffle " + shuffleId +
            " to " + hostPort)
          // look up the serialized metadata for this shuffleId
          val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
          // reply with the serialized statuses
          context.reply(mapOutputStatuses)
        } catch {
          case NonFatal(e) => logError(e.getMessage, e)
        }
      }
    } catch {
      case ie: InterruptedException => // exit
    }
  }
}
It looks up the serialized metadata for the given shuffleId and sends it back. Let's look at the implementation of getSerializedMapOutputStatuses:
def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
  var statuses: Array[MapStatus] = null
  var retBytes: Array[Byte] = null
  var epochGotten: Long = -1

  // look the serialized statuses up in the cache; if absent, fall back to mapStatuses
  def checkCachedStatuses(): Boolean = {
    epochLock.synchronized {
      if (epoch > cacheEpoch) {
        cachedSerializedStatuses.clear()
        clearCachedBroadcast()
        cacheEpoch = epoch
      }
      cachedSerializedStatuses.get(shuffleId) match {
        case Some(bytes) =>
          retBytes = bytes
          true
        case None =>
          logDebug("cached status not found for : " + shuffleId)
          statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus])
          epochGotten = epoch
          false
      }
    }
  }

  if (checkCachedStatuses()) return retBytes
  var shuffleIdLock = shuffleIdLocks.get(shuffleId)
  if (null == shuffleIdLock) {
    val newLock = new Object()
    // in general, this condition should be false - but good to be paranoid
    val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock)
    shuffleIdLock = if (null != prevLock) prevLock else newLock
  }
  // synchronize so we only serialize/broadcast it once since multiple threads call
  // in parallel
  shuffleIdLock.synchronized {
    if (checkCachedStatuses()) return retBytes

    // serialize the statuses
    val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager,
      isLocal, minSizeForBroadcast)
    logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
    // Add them into the table only if the epoch hasn't changed while we were working
    epochLock.synchronized {
      if (epoch == epochGotten) {
        cachedSerializedStatuses(shuffleId) = bytes
        if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast
      } else {
        logInfo("Epoch changed, not caching!")
        removeBroadcast(bcast)
      }
    }
    bytes
  }
}
The overall idea: first try the cache of serialized MapStatuses and return the bytes directly on a hit; on a miss, take the statuses from mapStatuses, serialize them, cache and return the bytes. The bytes are then sent back to the MapOutputTrackerWorker that made the request, which deserializes the metadata and adds it to the mapStatuses of its own Executor.
Back in getMapSizesByExecutorId: once getStatuses has returned all the metadata for this shuffleId, convertMapStatuses turns it into location information of the form Seq[(BlockManagerId, Seq[(BlockId, Long)])], which is used to read the specified partitions:
private def convertMapStatuses(
    shuffleId: Int,
    startPartition: Int,
    endPartition: Int,
    statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
  assert (statuses != null)
  // holds the metadata for the requested partitions, grouped by location
  val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
  for ((status, mapId) <- statuses.zipWithIndex) {
    if (status == null) {
      val errorMessage = s"Missing an output location for shuffle $shuffleId"
      logError(errorMessage)
      throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
    } else {
      for (part <- startPartition until endPartition) {
        splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
          ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
      }
    }
  }
  splitsByAddress.toSeq
}
The statuses: Array[MapStatus] parameter here is the metadata of all the shuffle write files of the upstream stage obtained above, ordered by the map-side partitionId. zipWithIndex pairs each element with its index in the array, and that index is exactly the map-side partitionId. A ShuffleBlockId is then built from the shuffleId, the map partitionId, and the reduce partitionId. (When the map side builds a ShuffleBlockId, the reduce partitionId is always 0, because one ShuffleMapTask produces a single block; the real reduce partitionId filled in here is needed later, when the index file is used to look up the offsets of the corresponding reduce partition.) The size of the corresponding data is also estimated via getSizeForBlock, because remote fetches later need to be bounded in size. Finally the location information is returned.
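As a small, self-contained illustration of the result (plain strings stand in for BlockManagerId and ShuffleBlockId; the executors and sizes are invented), suppose shuffleId 0 had three map tasks and this reduce task reads only reduce partition 2:

val shuffleId = 0
val mapLocations = Seq("exec-1", "exec-1", "exec-2")  // invented executor of each mapId 0..2
val blockSizes   = Seq(1024L, 0L, 2048L)              // invented size of reduce partition 2 in each map output

// zipWithIndex supplies the mapId; the block name follows the
// shuffle_<shuffleId>_<mapId>_<reduceId> convention of ShuffleBlockId
val splitsByAddress = mapLocations.zipWithIndex
  .map { case (executor, mapId) =>
    executor -> (s"shuffle_${shuffleId}_${mapId}_2", blockSizes(mapId))
  }
  .groupBy(_._1)
  .mapValues(_.map(_._2))

splitsByAddress.foreach(println)
// prints, in some order:
// (exec-1,List((shuffle_0_0_2,1024), (shuffle_0_1_2,0)))
// (exec-2,List((shuffle_0_2_2,2048)))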
At this point mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) is done: it has returned the location metadata, derived from the MapStatus entries, for the requested partitions.
When the ShuffleBlockFetcherIterator object is constructed, its initialize() method is invoked:
private[this] def initialize(): Unit = {
  // Add a task completion callback (called in both success case and failure case) to cleanup.
  context.addTaskCompletionListener(_ => cleanup())

  // split the blocks into local and remote blocks, returning the remote FetchRequests
  val remoteRequests = splitLocalRemoteBlocks()
  // add the remote requests to the fetchRequests queue in random order
  fetchRequests ++= Utils.randomize(remoteRequests)
  assert ((0 == reqsInFlight) == (0 == bytesInFlight),
    "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
    ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)

  // dequeue remote requests from fetchRequests and send them via sendRequest
  fetchUpToMaxBytes()

  val numFetches = remoteRequests.size - fetchRequests.size
  logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

  // fetch the local blocks
  fetchLocalBlocks()
  logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
Split the blocks into local blocks and remote blocks, and add the returned remote FetchRequests to the fetchRequests queue.
Dequeue remote requests from fetchRequests and send them with sendRequest to fetch the remote data.
Fetch the local blocks.
First, let's see how local blocks and remote blocks are distinguished:
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
  // divide the max in-flight size by 5 to increase parallelism: up to 5 requests in parallel
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
  logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)

  // the remote fetch requests to return
  val remoteRequests = new ArrayBuffer[FetchRequest]

  // Tracks total number of blocks (including zero sized blocks)
  var totalBlocks = 0
  for ((address, blockInfos) <- blocksByAddress) {
    totalBlocks += blockInfos.size
    // a block is local if it lives on the same executor as this task, remote otherwise
    if (address.executorId == blockManager.blockManagerId.executorId) {
      // filter out zero-sized blocks
      localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
      numBlocksToFetch += localBlocks.size
    } else {
      val iterator = blockInfos.iterator
      var curRequestSize = 0L
      var curBlocks = new ArrayBuffer[(BlockId, Long)]
      while (iterator.hasNext) {
        val (blockId, size) = iterator.next()
        // Skip empty blocks
        if (size > 0) {
          curBlocks += ((blockId, size))
          remoteBlocks += blockId
          numBlocksToFetch += 1
          curRequestSize += size
        } else if (size < 0) {
          throw new BlockException(blockId, "Negative block size " + size)
        }
        // once the accumulated size exceeds the limit, wrap the blocks into a FetchRequest
        if (curRequestSize >= targetRequestSize) {
          // Add this FetchRequest
          remoteRequests += new FetchRequest(address, curBlocks)
          curBlocks = new ArrayBuffer[(BlockId, Long)]
          logDebug(s"Creating fetch request of $curRequestSize at $address")
          curRequestSize = 0
        }
      }
      // wrap the remaining blocks into one more FetchRequest
      if (curBlocks.nonEmpty) {
        remoteRequests += new FetchRequest(address, curBlocks)
      }
    }
  }
  logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
  remoteRequests
}
To increase the parallelism of fetching data from remote nodes, the in-flight size limit is divided by 5 to get the per-request size limit, so that at any time up to 5 requests can be fetching from up to 5 different nodes (a worked sizing example follows these notes).
A block is considered local if the executor that holds it is the same executor the current task runs on.
For the blocks on each remote data node (Executor), whenever the accumulated size of the data to request exceeds the limit, a FetchRequest is built and added to remoteRequests; finally remoteRequests is returned. A FetchRequest is simply a wrapper around one batch of data to request: the address plus the blockIds and their sizes.
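To make the sizing concrete, here is a hedged, self-contained sketch (block ids and sizes invented, not Spark source) of the grouping described above: with the default spark.reducer.maxSizeInFlight of 48m, targetRequestSize is 48 MB / 5, roughly 9.6 MB, and blocks from one remote executor are packed into FetchRequests of about that size.

object FetchRequestSizing {
  def main(args: Array[String]): Unit = {
    val maxBytesInFlight = 48L * 1024 * 1024
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)   // about 9.6 MB
    println(s"targetRequestSize = $targetRequestSize bytes")

    // hypothetical non-empty blocks (id -> size) reported for one remote executor
    val blocks = Seq(("shuffle_0_0_2", 6L << 20), ("shuffle_0_1_2", 5L << 20),
                     ("shuffle_0_2_2", 4L << 20), ("shuffle_0_3_2", 8L << 20))

    var curSize = 0L
    var cur = Vector.empty[(String, Long)]
    var requests = Vector.empty[Seq[(String, Long)]]
    for ((id, size) <- blocks) {
      cur :+= ((id, size)); curSize += size
      if (curSize >= targetRequestSize) {   // close the current request once it is big enough
        requests :+= cur; cur = Vector.empty; curSize = 0L
      }
    }
    if (cur.nonEmpty) requests :+= cur      // the remainder forms one last request
    requests.zipWithIndex.foreach { case (r, i) => println(s"request $i: ${r.map(_._1)}") }
    // request 0: the first two blocks (11 MB), request 1: the last two blocks (12 MB)
  }
}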
After the local/remote split, the remote requests are added to the fetchRequests queue and fetchUpToMaxBytes() is called to start fetching remote data:
private def fetchUpToMaxBytes(): Unit = {
  // Send fetch requests up to maxBytesInFlight
  while (fetchRequests.nonEmpty &&
    (bytesInFlight == 0 ||
      (reqsInFlight + 1 <= maxReqsInFlight &&
        bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) {
    sendRequest(fetchRequests.dequeue())
  }
}
It dequeues FetchRequests from fetchRequests, as long as the in-flight limits allow, and calls sendRequest on each:
private[this] def sendRequest(req: FetchRequest) {
  logDebug("Sending request for %d blocks (%s) from %s".format(
    req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
  bytesInFlight += req.size
  reqsInFlight += 1

  // turn the blocks into a Map[blockId, size]
  val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
  val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
  val blockIds = req.blocks.map(_._1.toString)
  val address = req.address
  // fetch the blocks on the remote node through shuffleClient.fetchBlocks
  shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
    new BlockFetchingListener {
      // put successful results into the results queue
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
        // Only add the buffer to results queue if the iterator is not zombie,
        // i.e. cleanup() has not been called yet.
        ShuffleBlockFetcherIterator.this.synchronized {
          if (!isZombie) {
            // Increment the ref count because we need to pass this to a different thread.
            // This needs to be released after use.
            buf.retain()
            remainingBlocks -= blockId
            results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
              remainingBlocks.isEmpty))
            logDebug("remainingBlocks: " + remainingBlocks)
          }
        }
        logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
      }

      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
        results.put(new FailureFetchResult(BlockId(blockId), address, e))
      }
    }
  )
}
The data on the remote node is fetched through shuffleClient.fetchBlocks, which by default is implemented by NettyBlockTransferService. Whether a block succeeds or fails, a SuccessFetchResult or FailureFetchResult is constructed and put into results.
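The callback-into-a-queue shape is the key point here: fetches complete asynchronously on network threads and only hand their results to the task through the results queue. A simplified, self-contained sketch of that producer/consumer design (not Spark source; all names below are invented):

import java.util.concurrent.LinkedBlockingQueue

sealed trait FetchResult
case class SuccessResult(blockId: String, bytes: Array[Byte]) extends FetchResult
case class FailureResult(blockId: String, error: Throwable) extends FetchResult

object FetchQueueDemo {
  def main(args: Array[String]): Unit = {
    val results = new LinkedBlockingQueue[FetchResult]()

    // stands in for the network layer: callbacks run on another thread and feed the queue
    val networkThread = new Thread(new Runnable {
      override def run(): Unit = {
        results.put(SuccessResult("shuffle_0_0_2", Array[Byte](1, 2, 3)))
        results.put(FailureResult("shuffle_0_1_2", new RuntimeException("connection reset")))
      }
    })
    networkThread.start()

    // the task's iterator blocks on take() until the next result is available
    for (_ <- 1 to 2) {
      results.take() match {
        case SuccessResult(id, data) => println(s"got $id: ${data.length} bytes")
        case FailureResult(id, e)    => println(s"fetch of $id failed: ${e.getMessage}")
      }
    }
    networkThread.join()
  }
}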
After the remote fetches have been issued, fetchLocalBlocks() reads the local blocks:
private[this] def fetchLocalBlocks() {
  val iter = localBlocks.iterator
  while (iter.hasNext) {
    val blockId = iter.next()
    try {
      val buf = blockManager.getBlockData(blockId)
      shuffleMetrics.incLocalBlocksFetched(1)
      shuffleMetrics.incLocalBytesRead(buf.size)
      buf.retain()
      results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
    } catch {
      case e: Exception =>
        // If we see an exception, stop immediately.
        logError(s"Error occurred while fetching local blocks", e)
        results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
        return
    }
  }
}
It iterates over the blocks to fetch, reads each one directly from the blockManager, wraps the outcome in a SuccessFetchResult or FailureFetchResult, and puts it into results. Let's look at the implementation of blockManager.getBlockData(blockId):
override def getBlockData(blockId: BlockId): ManagedBuffer = {
  if (blockId.isShuffle) {
    shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
  } else {
    getLocalBytes(blockId) match {
      case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer)
      case None =>
        // If this block manager receives a request for a block that it doesn't have then it's
        // likely that the master has outdated block statuses for this block. Therefore, we send
        // an RPC so that this block is marked as being unavailable from this block manager.
        reportBlockStatus(blockId, BlockStatus.empty)
        throw new BlockNotFoundException(blockId.toString)
    }
  }
}
For shuffle blocks this delegates to the shuffle block resolver; now look at its getBlockData method:
override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
  // locate the index file from the shuffleId and mapId
  val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)

  val in = new DataInputStream(new FileInputStream(indexFile))
  try {
    // skip to the entry for this block's reduce partition
    ByteStreams.skipFully(in, blockId.reduceId * 8)
    // start offset of this partition's data
    val offset = in.readLong()
    // end offset of this partition's data
    val nextOffset = in.readLong()
    new FileSegmentManagedBuffer(
      transportConf,
      getDataFile(blockId.shuffleId, blockId.mapId),
      offset,
      nextOffset - offset)
  } finally {
    in.close()
  }
}
It locates the index file from the shuffleId and mapId and opens a stream over it, uses the block's reduceId (the one mentioned above when the partition metadata was built) to skip to the right entry, reads the start and end offsets one after the other, and then reads that byte range from the data file.
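To make the index-file layout concrete, here is a self-contained sketch (file name and sizes invented, not Spark source): the index file holds numPartitions + 1 longs, and the bytes of reduce partition r in the data file are the range [offsets(r), offsets(r + 1)), which is exactly what the skipFully/readLong/readLong sequence above extracts.

import java.io.{DataInputStream, DataOutputStream, File, FileInputStream, FileOutputStream}

object IndexFileDemo {
  def main(args: Array[String]): Unit = {
    val indexFile = File.createTempFile("shuffle_0_0_0", ".index")

    // pretend the map task wrote three reduce partitions of 100, 0 and 250 bytes:
    // the cumulative offsets are 0, 100, 100, 350
    val out = new DataOutputStream(new FileOutputStream(indexFile))
    Seq(0L, 100L, 100L, 350L).foreach(out.writeLong)
    out.close()

    val reduceId = 2
    val in = new DataInputStream(new FileInputStream(indexFile))
    try {
      in.skipBytes(reduceId * 8)      // each offset is one 8-byte long
      val offset = in.readLong()      // where partition 2 starts in the data file
      val nextOffset = in.readLong()  // where partition 2 ends
      println(s"partition $reduceId occupies bytes [$offset, $nextOffset) " +
        s"(${nextOffset - offset} bytes)")   // [100, 350), 250 bytes
    } finally {
      in.close()
      indexFile.delete()
    }
  }
}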
With all the fetch results now landing in results, we return to the read() method shown at the beginning; the rest of it consumes the fetched data.
ShuffleBlockFetcherIterator itself extends Iterator, so results can be iterated; its next() method returns each FetchResult as a (blockId, inputStream) pair:
case SuccessFetchResult(blockId, address, _, buf, _) =>
  try {
    (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
  } catch {
    case NonFatal(t) =>
      throwFetchFailedException(blockId, address, t)
  }
The second half of read() performs aggregation and sorting, very much like the Shuffle Write side, so only a rough description is given here.
When aggregation is required, combineCombinersByKey is used if map-side combine already happened, and combineValuesByKey otherwise; both end up calling ExternalAppendOnlyMap's insertAll(iter) method:
def combineCombinersByKey(
    iter: Iterator[_ <: Product2[K, C]],
    context: TaskContext): Iterator[(K, C)] = {
  val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
  combiners.insertAll(iter)
  updateMetrics(context, combiners)
  combiners.iterator
}
def combineValuesByKey(
    iter: Iterator[_ <: Product2[K, V]],
    context: TaskContext): Iterator[(K, C)] = {
  val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
  combiners.insertAll(iter)
  updateMetrics(context, combiners)
  combiners.iterator
}
def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
  if (currentMap == null) {
    throw new IllegalStateException(
      "Cannot insert new elements into a map after calling iterator")
  }
  // An update function for the map that we reuse across entries to avoid allocating
  // a new closure each time
  var curEntry: Product2[K, V] = null
  val update: (Boolean, C) => C = (hadVal, oldVal) => {
    if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
  }

  while (entries.hasNext) {
    curEntry = entries.next()
    val estimatedSize = currentMap.estimateSize()
    if (estimatedSize > _peakMemoryUsedBytes) {
      _peakMemoryUsedBytes = estimatedSize
    }
    if (maybeSpill(currentMap, estimatedSize)) {
      currentMap = new SizeTrackingAppendOnlyMap[K, C]
    }
    currentMap.changeValue(curEntry._1, update)
    addElementsRead()
  }
}
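To make createCombiner / mergeValue / mergeCombiners concrete, here is a hedged sketch (not Spark source; the object and functions are invented) of the three functions an Aggregator might carry for a per-key average: combineValuesByKey starts from raw values and needs all three, while combineCombinersByKey only needs mergeCombiners because the map side already produced (sum, count) combiners.

object AverageAggregatorSketch {
  type C = (Double, Long)   // the combiner type: (sum, count)

  val createCombiner: Double => C = v => (v, 1L)                                         // first value for a key
  val mergeValue: (C, Double) => C = { case ((sum, n), v) => (sum + v, n + 1L) }         // fold in one more value
  val mergeCombiners: (C, C) => C = { case ((s1, n1), (s2, n2)) => (s1 + s2, n1 + n2) }  // merge partial results

  def main(args: Array[String]): Unit = {
    // reduce-side only aggregation: fold raw values with createCombiner / mergeValue
    val values = Seq(1.0, 2.0, 6.0)
    val combined = values.tail.foldLeft(createCombiner(values.head))(mergeValue)
    println(s"sum/count = $combined, avg = ${combined._1 / combined._2}")   // (9.0,3), avg 3.0

    // with map-side combine, only partial (sum, count) pairs arrive and are merged
    val partials = Seq((3.0, 2L), (6.0, 1L))
    println(partials.reduce(mergeCombiners))                                // (9.0,3)
  }
}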
The iteration inside ultimately calls the ShuffleBlockFetcherIterator.next method described above to pull the data.
On every update/insert it also estimates the size of currentMap and checks whether it needs to spill to a disk file. If so, the map's contents are sorted by key with the defined keyComparator, the resulting iterator is written to a temporary disk file, and a fresh map is created to hold the new data.
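A much-simplified sketch of that size-triggered spill loop (not Spark's ExternalAppendOnlyMap; it spills on a made-up entry-count threshold instead of an estimated byte size):

import java.io.{File, PrintWriter}
import scala.collection.mutable

// A toy spillable map: combine values in memory, and once the map grows past a threshold,
// sort the entries by key, write them to a temporary file, and start a fresh map.
class ToySpillableMap(mergeValue: (Int, Int) => Int, spillThreshold: Int = 4) {
  private var current = mutable.Map.empty[String, Int]
  val spillFiles = mutable.ArrayBuffer.empty[File]

  def insert(key: String, value: Int): Unit = {
    current(key) = current.get(key).fold(value)(mergeValue(_, value))
    if (current.size >= spillThreshold) spill()
  }

  private def spill(): Unit = {
    val file = File.createTempFile("toy-spill", ".txt")
    val out = new PrintWriter(file)
    current.toSeq.sortBy(_._1).foreach { case (k, v) => out.println(s"$k,$v") }  // one sorted run
    out.close()
    spillFiles += file
    current = mutable.Map.empty[String, Int]   // fresh map for the following records
  }
}

// usage: val m = new ToySpillableMap(_ + _); Seq("a", "b", "c", "d", "e").foreach(m.insert(_, 1))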
After the combiners (the ExternalAppendOnlyMap) finish insertAll, its iterator is called to return an iterator over the complete partition data, both in memory and in spill files:
override def iterator: Iterator[(K, C)] = {
  if (currentMap == null) {
    throw new IllegalStateException(
      "ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
  }
  if (spilledMaps.isEmpty) {
    CompletionIterator[(K, C), Iterator[(K, C)]](
      destructiveIterator(currentMap.iterator), freeCurrentMap())
  } else {
    new ExternalIterator()
  }
}
Follow the instantiation of the ExternalIterator class:
// A queue that maintains a buffer for each stream we are currently merging
// This queue maintains the invariant that it only contains non-empty buffers
private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]

// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator(
  currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap())
private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)

inputStreams.foreach { it =>
  val kcPairs = new ArrayBuffer[(K, C)]
  readNextHashCode(it, kcPairs)
  if (kcPairs.length > 0) {
    mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
  }
}
The sorted data from currentMap is combined with the iterators over the spill files to form inputStreams. Each of these streams is read into a buffer of key/value pairs and enqueued into mergeHeap, which is then consumed in ExternalIterator's next() method.
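As a simplified illustration of the merge step (ignoring Spark's extra grouping of keys by hash code), here is a self-contained sketch of k-way merging several already-sorted streams with a priority queue:

import scala.collection.mutable

// Merge several sorted iterators into one sorted iterator using a priority queue,
// a simplified version of what ExternalIterator does over the in-memory map and spill files.
object KWayMergeSketch {
  def merge[K](streams: Seq[Iterator[K]])(implicit ord: Ordering[K]): Iterator[K] = {
    // order buffered iterators by their current head; reverse because PriorityQueue is a max-heap
    val heap = mutable.PriorityQueue.empty[BufferedIterator[K]](
      Ordering.by((it: BufferedIterator[K]) => it.head)(ord.reverse))
    streams.map(_.buffered).filter(_.hasNext).foreach(heap.enqueue(_))

    new Iterator[K] {
      override def hasNext: Boolean = heap.nonEmpty
      override def next(): K = {
        val smallest = heap.dequeue()                  // stream whose head is the smallest key
        val value = smallest.next()
        if (smallest.hasNext) heap.enqueue(smallest)   // re-insert if it still has elements
        value
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val merged = merge(Seq(Iterator(1, 4, 7), Iterator(2, 3, 9), Iterator(5, 6)))
    println(merged.toList)   // List(1, 2, 3, 4, 5, 6, 7, 9)
  }
}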
Finally, if the data needs a global sort, an ExternalSorter constructed with only the ordering parameter is used via its insertAll method, just as in Shuffle Write, so it is not covered in detail here.
In the end, an iterator over all the data of the specified partition is returned.
Author: BIGUFO
Link: https://www.jianshu.com/p/50278b0a0050