Source Code Analysis of makeRDD in Spark


A walkthrough of the makeRDD source code

// Returns a ParallelCollectionRDD
def makeRDD[T: ClassTag](
    seq: Seq[T],
    numSlices: Int = defaultParallelism): RDD[T] = withScope {
  parallelize(seq, numSlices)
}
// The partition count numSlices has a default value, so it is used whenever no explicit value is passed in.
// The code block { parallelize(seq, numSlices) } is passed to withScope as a by-name argument.

val rdd1: RDD[Int] = sparkContext.makeRDD(list)
val rdd2: RDD[Int] = sparkContext.parallelize(list)
// These two ways of creating an RDD are equivalent, because makeRDD simply calls parallelize.
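As a quick check, here is a minimal local-mode sketch (the SparkContext setup is added here purely for illustration) showing that the two calls produce the same data and the same number of partitions:

import org.apache.spark.{SparkConf, SparkContext}

object MakeRDDDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[*]").setAppName("makeRDD-demo")
    val sc   = new SparkContext(conf)

    val list = List(1, 2, 3, 4, 5)
    val rdd1 = sc.makeRDD(list)      // delegates to parallelize
    val rdd2 = sc.parallelize(list)

    // Same contents and the same default partition count (defaultParallelism).
    println(rdd1.collect().toList == rdd2.collect().toList)   // true
    println(rdd1.getNumPartitions == rdd2.getNumPartitions)   // true

    sc.stop()
  }
}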

makeRDD simply forwards the collection and the partition count to parallelize and wraps the call in withScope, returning the result. Let's look at the parallelize source next:

// 1. parallelize source
/**
 * Distribute a local Scala collection to form an RDD.
 */
def parallelize[T: ClassTag](
    seq: Seq[T],
    numSlices: Int = defaultParallelism): RDD[T] = withScope {
  assertNotStopped()
  new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}

Here seq is the collection passed in and numSlices is the number of partitions, which defaults to defaultParallelism. Tracing defaultParallelism in the IDE, we first find the overriding method:

override def defaultParallelism(): Int = backend.defaultParallelism()

Stepping in again leads to the abstract def defaultParallelism(): Int. Pressing Ctrl+H to list its implementations and opening LocalSchedulerBackend, we find:

override def defaultParallelism(): Int =
  scheduler.conf.getInt("spark.default.parallelism", totalCores)

"spark.default.parallelism" is the default parallelism read from the configuration; if it is not set, totalCores is used instead. So we can conclude: if a partition count is passed to makeRDD, that value becomes the final number of partitions; otherwise the default is totalCores. How large is totalCores? Take local mode as an example:

setMaster("local")    -> the default number of partitions is 1
setMaster("local[n]") -> the default number of partitions is n
setMaster("local[*]") -> the default number of partitions is the maximum number of cores in the
                         current environment (on an 8-core / 16-thread machine, local[*] gives 16)

When saveAsTextFile is called to write the RDD out, each partition produces one output file.
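A minimal sketch of that behavior, assuming a local[*] run (the default partition count printed depends on the machine's core count):

import org.apache.spark.{SparkConf, SparkContext}

object DefaultParallelismDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[*]").setAppName("default-parallelism-demo")
    val sc   = new SparkContext(conf)

    val list = List(1, 2, 3, 4, 5)

    // Explicit partition count: the passed value wins.
    println(sc.makeRDD(list, 3).getNumPartitions)   // 3

    // No partition count: falls back to spark.default.parallelism,
    // which in local[*] mode is the number of available cores (totalCores).
    println(sc.makeRDD(list).getNumPartitions)      // e.g. 16 on an 8-core / 16-thread machine
    println(sc.defaultParallelism)                  // same value

    // saveAsTextFile writes one part-XXXXX file per partition
    // (part-00000, part-00001, part-00002 for 3 partitions).
    // sc.makeRDD(list, 3).saveAsTextFile("output")  // commented out: writes to local disk

    sc.stop()
  }
}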
// 2. parallelize instantiates a ParallelCollectionRDD; the companion class source is as follows:
private[spark] class ParallelCollectionRDD[T: ClassTag](
    sc: SparkContext,
    @transient private val data: Seq[T],
    numSlices: Int,
    locationPrefs: Map[Int, Seq[String]])
    extends RDD[T](sc, Nil) {
  // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
  // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
  // instead.
  // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.

  override def getPartitions: Array[Partition] = {
    val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
    slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
  }

  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
  }

  override def getPreferredLocations(s: Partition): Seq[String] = {
    locationPrefs.getOrElse(s.index, Nil)
  }
}

// 3. The companion object ParallelCollectionRDD, which does the actual slicing:
private object ParallelCollectionRDD {
  /**
   * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
   * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
   * it efficient to run Spark over RDDs representing large sets of numbers. And if the collection
   * is an inclusive Range, we use inclusive range for the last slice.
   */
  def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
    // Validate the partition count: anything less than 1 throws an exception.
    if (numSlices < 1) {
      throw new IllegalArgumentException("Positive number of partitions required")
    }
    // Sequences need to be sliced at the same set of index positions for operations
    // like RDD.zip() to behave as expected
    def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
      (0 until numSlices).iterator.map { i =>
        val start = ((i * length) / numSlices).toInt
        val end = (((i + 1) * length) / numSlices).toInt
        (start, end)
      }
    }
    // Pattern match on the collection type: Range and NumericRange are handled specially.
    seq match {
      case r: Range =>
        positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
          // If the range is inclusive, use inclusive range for the last slice
          if (r.isInclusive && index == numSlices - 1) {
            new Range.Inclusive(r.start + start * r.step, r.end, r.step)
          } else {
            new Range(r.start + start * r.step, r.start + end * r.step, r.step)
          }
        }.toSeq.asInstanceOf[Seq[Seq[T]]]
      case nr: NumericRange[_] =>
        // For ranges of Long, Double, BigInteger, etc
        val slices = new ArrayBuffer[Seq[T]](numSlices)
        var r = nr
        for ((start, end) <- positions(nr.length, numSlices)) {
          val sliceSize = end - start
          slices += r.take(sliceSize).asInstanceOf[Seq[T]]
          r = r.drop(sliceSize)
        }
        slices
      case _ =>
        // Any collection that is not a Range takes this branch.
        val array = seq.toArray // To prevent O(n^2) operations for List etc
        // The array length and the partition count are passed to positions, which returns an
        // iterator of (start, end) tuples, one per partition; each tuple is then used to slice
        // the array.
        positions(array.length, numSlices).map { case (start, end) =>
          // Array.slice keeps the elements at indices [from, until)
          array.slice(start, end).toSeq
        }.toSeq
    }
  }
}

// 4. The slice method on arrays looks like this:
override def slice(from: Int, until: Int): Array[T] = {
  val lo = math.max(from, 0)
  val hi = math.min(math.max(until, 0), repr.length)
  val size = math.max(hi - lo, 0)
  val result = java.lang.reflect.Array.newInstance(elementClass, size)
  if (size > 0) {
    Array.copy(repr, lo, result, 0, size)
  }
  result.asInstanceOf[Array[T]]
}

/*
How the slicing works, by example: with 3 partitions, List(1, 2, 3, 4, 5) is turned by positions
into the boundaries (0, 1), (1, 3), (3, 5). Mapping over these boundaries and calling slice for
each one:
  (0, 1) -> indices [0, 1) -> element 1      goes to the first partition
  (1, 3) -> indices [1, 3) -> elements 2, 3  go to the second partition
  (3, 5) -> indices [3, 5) -> elements 4, 5  go to the third partition
*/
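Here is a small standalone sketch (plain Scala, no Spark dependency; the object name is made up for the example) that reproduces the positions logic above and confirms the boundaries for List(1, 2, 3, 4, 5) with 3 slices:

object SliceDemo {
  // Same boundary formula as the positions helper in ParallelCollectionRDD.slice.
  def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] =
    (0 until numSlices).iterator.map { i =>
      val start = ((i * length) / numSlices).toInt
      val end   = (((i + 1) * length) / numSlices).toInt
      (start, end)
    }

  def main(args: Array[String]): Unit = {
    val data   = List(1, 2, 3, 4, 5)
    val bounds = positions(data.length, 3).toList
    println(bounds)                                   // List((0,1), (1,3), (3,5))

    val array  = data.toArray
    val slices = bounds.map { case (start, end) => array.slice(start, end).toList }
    println(slices)                                   // List(List(1), List(2, 3), List(4, 5))

    // In Spark, sc.makeRDD(data, 3).glom().collect() shows the same per-partition grouping.
  }
}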

Finally, let's look at the withScope source:

// 5. withScope source. There is not much to dwell on here: parallelize builds the RDD and the
//    surrounding block is evaluated inside withScope, which returns that RDD.
/**
 * Execute the given body such that all RDDs created in this body will have the same scope.
 * The name of the scope will be the first method name in the stack trace that is not the
 * same as this method's.
 *
 * Note: Return statements are NOT allowed in body.
 */
private[spark] def withScope[T](
    sc: SparkContext,
    allowNesting: Boolean = false)(body: => T): T = {
  val ourMethodName = "withScope"
  val callerMethodName = Thread.currentThread.getStackTrace()
    .dropWhile(_.getMethodName != ourMethodName)
    .find(_.getMethodName != ourMethodName)
    .map(_.getMethodName)
    .getOrElse {
      // Log a warning just in case, but this should almost certainly never happen
      logWarning("No valid method name for this RDD operation scope!")
      "N/A"
    }
  withScope[T](sc, callerMethodName, allowNesting, ignoreParent = false)(body)
}
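To illustrate the stack-trace trick above, here is a standalone sketch (plain Scala, not Spark code; the object and method names are invented for the example) that derives the caller's method name in the same way withScope derives the scope name:

object ScopeNameDemo {
  def currentCallerName(): String = {
    val ourMethodName = "currentCallerName"
    Thread.currentThread.getStackTrace
      .dropWhile(_.getMethodName != ourMethodName)  // skip the frames above this method
      .find(_.getMethodName != ourMethodName)       // first frame that is not this method
      .map(_.getMethodName)
      .getOrElse("N/A")
  }

  def makeSomething(): String = currentCallerName() // the caller frame is "makeSomething"

  def main(args: Array[String]): Unit = {
    println(makeSomething())                        // prints "makeSomething"
  }
}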

Summary:

makeRDD is a thin wrapper around parallelize, so the two are interchangeable. The number of partitions is the numSlices argument if one is passed, otherwise spark.default.parallelism, which in local mode falls back to totalCores. The actual partitioning is done by ParallelCollectionRDD.slice, which computes (start, end) boundaries with the positions helper and cuts the collection at those indices, encoding Range inputs as sub-Ranges to save memory.
