手记

Spark外部数据源

1-jdbc外部数据源学习
2-自己实现简单的外部数据源

JDBC外部数据源学习

入口:读取mysql的简单操作:
val jdbcDF = spark.read.jdbc(url,tableName,prop)

1-首先会设置format,然后进行load操作
format("jdbc").load()
2-DataFrameReader 中 load操作
//获取BaseRelation
sparkSession.baseRelationToDataFrame(
      DataSource.apply(
        sparkSession,
        paths = paths,
        userSpecifiedSchema = userSpecifiedSchema,
        className = source,
        options = extraOptions.toMap).resolveRelation())

 //获取provider
 lazy val providingClass: Class[_] = DataSource.lookupDataSource(className)
 获取到JdbcRelationProvider,然后通过resolveRelation()方法创建JDBCRelation
 def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = {
    val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
      // TODO: Throw when too much is given.
      case (dataSource: SchemaRelationProvider, Some(schema)) =>
        dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema)
      case (dataSource: RelationProvider, None) =>
        dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
      case (_: SchemaRelationProvider, None) =>
        throw new AnalysisException(s"A schema needs to be specified when using $className.")
      case (dataSource: RelationProvider, Some(schema)) =>
        val baseRelation =
          dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
       /**省略具体代码*/

JDBCRelation提供具体的查询或者写入方法,可以查看它继承的类PrunedFilteredScan,带有过滤的查询功能

JDBCRelation(
    parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
  extends BaseRelation
  with PrunedFilteredScan
  with InsertableRelation {
   /**省略具体代码*/

  }

3-JDBCRelation主要的方法:
1-获取表的schema,后续用于查询具体字段使用等
override val schema: StructType = JDBCRDD.resolveTable(jdbcOptions)
可以看到具体实现就是通过jdbc的方式进行的
val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
2-buildScan(requiredColumns: Array[String], filters: Array[Filter])
查询表中的信息,并且可以指定列或者过滤功能。
JDBCRDD.scanTable(
      sparkSession.sparkContext,
      schema,
      requiredColumns,
      filters,
      parts,
      jdbcOptions).asInstanceOf[RDD[Row]]

注意scanTable方法返回的并不是迭代器,而是一个RDD(上面通过asInstanceOf强制转换成了RDD[Row]),其真实类型是:
RDD[InternalRow]

Build and return JDBCRDD from the given information.
我们继续看JDBCRDD会发现具体实现的方法compute:
val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause"
    stmt = conn.prepareStatement(sqlText,
        ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
    stmt.setFetchSize(options.fetchSize)
    rs = stmt.executeQuery()
    val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)

可以看到也是通过jdbc的方式获取到ResultSet,这里并不会立即对结果集进行遍历,而是会将结果集传递给一个迭代器NextIterator(如下)。
这样做的好处我猜应该是:等到真正需要对结果做操作的时候(比如入库),才会通过getNext()方法逐条遍历具体的数据;否则如果现在就把获取到的数据全部保存起来再产生DataFrame,会占用比较大的内存,有可能内存就不够用了。

/**
 * Wraps a JDBC `ResultSet` in a lazy iterator of `InternalRow`s.
 *
 * Nothing is read from the result set when this method is called; rows are pulled one at a
 * time, each time the returned iterator's `getNext()` is invoked.
 *
 * @param resultSet    the JDBC result set to read from
 * @param schema       schema used to build the per-column getters and the reusable row buffer
 * @param inputMetrics metrics object whose records-read counter is incremented once per row
 * @return an iterator that yields one `InternalRow` per result-set row
 */
private[spark] def resultSetToSparkInternalRows(
      resultSet: ResultSet,
      schema: StructType,
      inputMetrics: InputMetrics): Iterator[InternalRow] = {
    new NextIterator[InternalRow] {
      private[this] val rs = resultSet
      // One getter per schema field, built once up front.
      private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
      // Single row buffer that is overwritten and returned for every row.
      // NOTE(review): callers that retain rows presumably need to copy them - confirm.
      private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))

      // Closes the result set; a failure while closing is logged as a warning, not rethrown.
      override protected def close(): Unit = {
        try {
          rs.close()
        } catch {
          case e: Exception => logWarning("Exception closing resultset", e)
        }
      }

      // Advances the result set by one row and copies its columns into mutableRow.
      override protected def getNext(): InternalRow = {
        if (rs.next()) {
          inputMetrics.incRecordsRead(1)
          var i = 0
          while (i < getters.length) {
            getters(i).apply(rs, mutableRow, i)
            // wasNull is checked right after the getter call for column i, and a SQL NULL
            // overrides whatever value the getter wrote.
            if (rs.wasNull) mutableRow.setNullAt(i)
            i = i + 1
          }
          mutableRow
        } else {
          // Result set exhausted: flag completion and return a placeholder null.
          finished = true
          null.asInstanceOf[InternalRow]
        }
      }
    }
  }


4-具体使用

    // 取得该表数据
    val jdbcDF = spark.read.jdbc(url,tableName,prop)

    jdbcDF.createOrReplaceTempView("test")
    spark.sql("select * from test").show()

当进行spark.sql("select * from test")查询操作的时候,其实就是调用了上面的buildScan方法。

自己实现简单的外部数据源

这里我们实现一个读取excel的外部数据源

代码实现

1-继承RelationProvider,实现我们自己的Provider
/**
 * Relation provider for the Excel data source.
 *
 * Supports both the schema-less and the user-supplied-schema entry points; both end up
 * building an [[ExcelRDDDatasourceRelation]] for the location given in the "path" option.
 */
class DefaultSource extends RelationProvider with SchemaRelationProvider {

  /** Schema-less entry point: forwards a null schema so the relation uses its built-in default. */
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation =
    createRelation(sqlContext, parameters, null)

  /**
   * Builds the relation for the location found under the "path" option.
   *
   * @throws IllegalArgumentException when no "path" option is present
   */
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String],
                              schema: StructType): BaseRelation = {
    // The default argument of getOrElse is by-name, so the throw only fires when "path" is absent.
    val location = parameters.getOrElse(
      "path",
      throw new IllegalArgumentException("Path is required for custom-datasource format!!"))

    new ExcelRDDDatasourceRelation(sqlContext, location, schema)
  }

}
2-实现具体的BaseRelation
/**
 * Relation over the Excel file(s) found at `path`, read in full through [[TableScan]].
 *
 * @param sqlContext  context used to reach the SparkContext that reads the files
 * @param path        location handed to `binaryFiles`
 * @param excelSchema user-supplied schema, or null to use the default three string columns
 */
class ExcelRDDDatasourceRelation(override val sqlContext : SQLContext, path : String, excelSchema : StructType)
  extends BaseRelation with TableScan with Serializable {

  /** The user-supplied schema when present, otherwise string columns "a", "b" and "c". */
  override def schema: StructType =
    Option(excelSchema).getOrElse(
      StructType(Seq(
        StructField("a", StringType, false),
        StructField("b", StringType, true),
        StructField("c", StringType, true)
      ))
    )

  /** Full scan: every binary file under `path` is parsed by a fresh [[ExcelUtils]]. */
  override def buildScan(): RDD[Row] = {
    println("TableScan: buildScan called...")

    val files = sqlContext.sparkContext.binaryFiles(path)

    files.flatMap { file =>
      val reader = new ExcelUtils()
      reader.compute(sqlContext.sparkContext, schema, file)
    }
  }

}

3-实现读取核心业务,返回迭代器
/**
 * Reads the first sheet of an Excel workbook and exposes its rows as Spark SQL rows.
 *
 * Rows are produced lazily: the workbook is opened in `compute`, but cells are only read
 * when the returned iterator is advanced.
 */
class ExcelUtils {

  // Stream opened by the most recent compute() call. Kept as a public var for backward
  // compatibility; iterators created by compute() release their own stream and do not
  // depend on this field.
  var inputStream: DataInputStream = null

  /**
   * Opens the given binary file as a streaming workbook and returns an iterator over the
   * rows of its first sheet.
   *
   * @param sc     not used by this method; kept for interface compatibility
   * @param schema forwarded unchanged to `resultSetToSparkRows`
   * @param path   pair of (file name, portable stream) as produced by `binaryFiles`
   * @return iterator with one Spark row per sheet row
   */
  def compute(sc: SparkContext, schema: StructType,
              path: (String, PortableDataStream)): Iterator[org.apache.spark.sql.Row] = {

    val (pathName, portableStream) = path
    println("pathName:" + pathName)

    val stream = portableStream.open()
    inputStream = stream

    try {
      // Parse the Excel file in streaming mode.
      val workbook = StreamingReader.builder.rowCacheSize(100).bufferSize(4096).open(stream)
      val sheetRs = workbook.getSheetAt(0).iterator()

      // Release the workbook first, then the underlying stream, even if the former fails.
      resultSetToSparkRows(sheetRs, schema, () => {
        try workbook.close() finally stream.close()
      })
    } catch {
      // Opening the workbook or sheet failed: do not leak the stream that was just opened.
      case scala.util.control.NonFatal(e) =>
        stream.close()
        throw e
    }
  }

  /**
   * Adapts a POI row iterator into an iterator of Spark rows.
   *
   * Each cell is read through `getStringCellValue`, so every produced value is a String.
   *
   * @param sheetRs iterator over the sheet's rows
   * @param schema  not used by this method; kept for interface compatibility
   * @param onClose invoked once when iteration ends; by default closes `inputStream` if it
   *                has been set
   * @return iterator with one Spark row per sheet row
   */
  def resultSetToSparkRows(sheetRs: util.Iterator[org.apache.poi.ss.usermodel.Row],
                           schema: StructType,
                           onClose: () => Unit = () => Option(inputStream).foreach(_.close())
                          ): Iterator[org.apache.spark.sql.Row] = {
    // Explicit converters replace the deprecated implicit JavaConversions.
    import scala.collection.JavaConverters._

    new ExcelNextIterator[Row] {

      // Runs when iteration has finished.
      override protected def close(): Unit = {
        println("close stream")
        onClose()
      }

      // Fetches the next row, or flags completion when the sheet is exhausted.
      override protected def getNext(): org.apache.spark.sql.Row = {

        if (sheetRs.hasNext) {
          // Collect the string value of every cell in the row, in iteration order.
          val cellValues = sheetRs.next().asScala.map(_.getStringCellValue).toList

          Row.fromSeq(cellValues)
        } else {
          // End-of-data marker.
          finished = true
          null.asInstanceOf[Row]
        }

      }

    }

  }

}

4-测试
/**
 * Demo entry point: loads an Excel file through the custom data source, shows it via a
 * temporary view, and appends it to a MySQL table over JDBC.
 *
 * Uses an explicit `main` instead of the `App` trait to avoid its initialization-order
 * pitfalls, and always stops the SparkSession when the job finishes or fails.
 */
object ExcelApp {

  def main(args: Array[String]): Unit = {
    println("Application started...")

    val conf = new SparkConf().setAppName("spark-custom-datasource")
    val spark = SparkSession.builder().config(conf).master("local").getOrCreate()

    try {
      val df = spark.sqlContext.read.format("com.fayayo.excel.spark.sql.datasource").load("F:\\Excel\\demo03.xlsx")

      df.createOrReplaceTempView("test")
      spark.sql("select * from test").show()

      // Connection properties for the JDBC sink.
      // NOTE(review): credentials are hard-coded for the demo; externalize them for real use.
      val prop = new Properties()
      prop.setProperty("user", "root")
      prop.setProperty("password", "root123")
      df.write.mode("append").jdbc("jdbc:mysql://localhost:3306/test", "test.abc", prop)
    } finally {
      // Release Spark resources even when the job fails.
      spark.stop()
    }

    println("Application Ended...")
  }
}
可以看到数据写入了mysql表中。

源码

github代码地址:https://github.com/lizu18xz/spark-extend-dataSource
0人推荐
随时随地看视频
慕课网APP