1-Learning the JDBC external data source
2-Implementing a simple external data source yourself
Learning the JDBC external data source
Entry point: a simple read from MySQL:
```scala
val jdbcDF = spark.read.jdbc(url, tableName, prop)
```
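1-First, jdbc(url, table, prop) copies the URL, the table name and the connection properties into the reader's options. It then sets the format and calls load():
```scala
format("jdbc").load()
```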
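2-The load operation in DataFrameReader
```scala
// DataFrameReader.load(): build a DataSource and resolve it into a relation
sparkSession.baseRelationToDataFrame(
  DataSource.apply(
    sparkSession,
    paths = paths,
    userSpecifiedSchema = userSpecifiedSchema,
    className = source,
    options = extraOptions.toMap).resolveRelation())

// In DataSource: the provider class is looked up from the format name ("jdbc")
lazy val providingClass: Class[_] = DataSource.lookupDataSource(className)
```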
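How does a format string such as "jdbc" or "com.fayayo.excel.spark.sql.datasource" turn into a provider class? In the Spark 2.x sources, lookupDataSource works roughly as follows:

1. It checks names registered through DataSourceRegister.shortName(). This is how "jdbc" maps to JdbcRelationProvider.
2. It tries to load the name itself as a class.
3. It tries the name with ".DefaultSource" appended.

The sketch below models only that order in plain Scala. Strings stand in for classes and a Set stands in for the classpath. The registry contents, the function and the error text are hypothetical, and Spark's backward-compatibility aliases are left out.

```scala
// Simplified model of the order in which Spark 2.x's lookupDataSource tries names.
// Plain strings stand in for classes; a Set stands in for the classpath.
val registeredShortNames: Map[String, String] = Map(
  // registered through DataSourceRegister.shortName()
  "jdbc" -> "org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider"
)

val loadableClasses: Set[String] = Set(
  "org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider",
  "com.fayayo.excel.spark.sql.datasource.DefaultSource"
)

def lookupDataSource(provider: String): Either[String, String] =
  registeredShortNames.collectFirst {
    case (shortName, cls) if shortName.equalsIgnoreCase(provider) => cls // 1. short name
  } match {
    case Some(cls) => Right(cls)
    case None =>
      // 2. the name itself as a class, 3. the name + ".DefaultSource"
      Seq(provider, s"$provider.DefaultSource").find(loadableClasses.contains) match {
        case Some(cls) => Right(cls)
        case None      => Left(s"no data source found for '$provider'")
      }
  }

println(lookupDataSource("jdbc"))
println(lookupDataSource("com.fayayo.excel.spark.sql.datasource"))
```

The third step is why the custom source later in this article can be loaded by its package name alone, as long as the class in that package is called DefaultSource.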
This resolves "jdbc" to JdbcRelationProvider. resolveRelation() then uses that provider to create a JDBCRelation:
```scala
def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = {
  val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
    // the provider accepts a schema and the user supplied one
    case (dataSource: SchemaRelationProvider, Some(schema)) =>
      dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema)
    // the provider determines its own schema (JdbcRelationProvider takes this branch)
    case (dataSource: RelationProvider, None) =>
      dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
    case (_: SchemaRelationProvider, None) =>
      throw new AnalysisException(s"A schema needs to be specified when using $className.")
    case (dataSource: RelationProvider, Some(schema)) =>
      val baseRelation =
        dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
      // ... (rest of this case and the cases for file-based formats omitted)
  }
  // ...
}
```
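The match above decides which createRelation overload gets called. The toy model below replays it in plain Scala. It uses stand-in traits: ToyRelationProvider, ToySchemaRelationProvider and ToyRelation are hypothetical, not Spark types. For the last case it assumes what Spark does in the part omitted above: it accepts a user-specified schema only if that schema equals the one the relation reports. The model also shows why a provider that implements both traits, like the DefaultSource later in this article, works both with and without .schema(...).

```scala
// Stand-ins for Spark's relation and provider traits (hypothetical, not Spark types).
final case class ToyRelation(schema: Seq[String])

trait ToyRelationProvider {
  def createRelation(options: Map[String, String]): ToyRelation
}
trait ToySchemaRelationProvider {
  def createRelation(options: Map[String, String], schema: Seq[String]): ToyRelation
}

// Same case order as the match in DataSource.resolveRelation.
def resolveRelation(provider: AnyRef, userSchema: Option[Seq[String]],
                    options: Map[String, String]): ToyRelation =
  (provider, userSchema) match {
    case (p: ToySchemaRelationProvider, Some(s)) => p.createRelation(options, s)
    case (p: ToyRelationProvider, None)          => p.createRelation(options)
    case (_: ToySchemaRelationProvider, None) =>
      throw new IllegalArgumentException("A schema needs to be specified")
    case (p: ToyRelationProvider, Some(s)) =>
      val r = p.createRelation(options)
      // a user schema is only tolerated if it equals the one the relation reports
      if (r.schema != s) throw new IllegalArgumentException("user-specified schema not allowed")
      r
    case _ => throw new IllegalArgumentException("not a relation provider")
  }

// Like the article's DefaultSource: implements both traits.
object ExcelLikeProvider extends ToyRelationProvider with ToySchemaRelationProvider {
  def createRelation(options: Map[String, String]): ToyRelation = ToyRelation(Seq("a", "b", "c"))
  def createRelation(options: Map[String, String], schema: Seq[String]): ToyRelation = ToyRelation(schema)
}

// Like JdbcRelationProvider: determines its own schema.
object JdbcLikeProvider extends ToyRelationProvider {
  def createRelation(options: Map[String, String]): ToyRelation = ToyRelation(Seq("id", "name"))
}

println(resolveRelation(ExcelLikeProvider, None, Map.empty))
println(resolveRelation(ExcelLikeProvider, Some(Seq("x", "y")), Map.empty))
```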
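JDBCRelation provides the actual read and write methods. Note the traits it mixes in: PrunedFilteredScan gives it scans with column pruning and filtering, and InsertableRelation lets it accept inserts:
```scala
private[sql] case class JDBCRelation(
    parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
  extends BaseRelation
  with PrunedFilteredScan
  with InsertableRelation {
  // ...
}
```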
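To make the PrunedFilteredScan contract concrete: Spark hands buildScan the names of the columns it needs and the filters it managed to translate, and the relation can use both to read less. The toy in-memory relation below mirrors that contract in plain Scala. Filter, EqualTo, GreaterThan, IsNotNull and InMemoryRelation are hypothetical stand-ins, not org.apache.spark.sql.sources classes.

```scala
// Hypothetical mirrors of a few org.apache.spark.sql.sources.Filter cases.
sealed trait Filter
final case class EqualTo(attribute: String, value: Any) extends Filter
final case class GreaterThan(attribute: String, value: Int) extends Filter
final case class IsNotNull(attribute: String) extends Filter

// Toy relation over in-memory rows (column name -> value).
final class InMemoryRelation(rows: Seq[Map[String, Any]]) {
  private def matches(row: Map[String, Any], f: Filter): Boolean = f match {
    case EqualTo(a, v)     => row.get(a).contains(v)
    case GreaterThan(a, v) => row.get(a).exists {
      case i: Int => i > v
      case _      => false
    }
    case IsNotNull(a)      => row.get(a).exists(_ != null)
  }

  // Same shape as PrunedFilteredScan.buildScan(requiredColumns, filters):
  // keep rows that pass every filter, and only the requested columns, in that order.
  def buildScan(requiredColumns: Array[String], filters: Array[Filter]): Seq[Seq[Any]] =
    rows
      .filter(r => filters.forall(f => matches(r, f)))
      .map(r => requiredColumns.toSeq.map(c => r.getOrElse(c, null)))
}

val rel = new InMemoryRelation(Seq[Map[String, Any]](
  Map("a" -> "x", "b" -> 1),
  Map("a" -> "y", "b" -> 5),
  Map("a" -> null, "b" -> 7)
))

// Roughly what "SELECT a FROM t WHERE b > 2 AND a IS NOT NULL" would hand to buildScan.
val result = rel.buildScan(Array("a"), Array(GreaterThan("b", 2), IsNotNull("a")))
println(result)
```

In real Spark, BaseRelation.unhandledFilters returns all filters by default, so Spark re-evaluates them after the scan. A relation that ignores some filters therefore still returns correct results; it just reads more data.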
3-Main methods of JDBCRelation:
(1)-Get the table's schema, which is used later, for example when querying specific columns:
```scala
override val schema: StructType = JDBCRDD.resolveTable(jdbcOptions)
```
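Under the hood this is plain JDBC. resolveTable runs the dialect's schema query and reads the column metadata from its result:
```scala
val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
```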
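The default schema query of a JDBC dialect is SELECT * FROM <table> WHERE 1=0. It returns no rows, but the ResultSetMetaData still describes every column. Each column's java.sql.Types code is then mapped to a Spark type. The sketch below models only that mapping step, for a handful of types, with type names as plain strings. ColumnMeta and the helper functions are hypothetical. Spark's real mapping (JdbcUtils.getCatalystType) covers many more types and also considers precision, scale and signedness.

```scala
import java.sql.Types

// Hypothetical column description, as ResultSetMetaData would report it.
final case class ColumnMeta(name: String, sqlType: Int, nullable: Boolean)

// Default JdbcDialect schema query: matches no rows, but the metadata is still returned.
def schemaQuery(table: String): String = s"SELECT * FROM $table WHERE 1=0"

// A small subset of the java.sql.Types -> Spark type mapping, with type names as strings.
def toSparkType(sqlType: Int): String = sqlType match {
  case Types.INTEGER              => "IntegerType"
  case Types.BIGINT               => "LongType"
  case Types.VARCHAR | Types.CHAR => "StringType"
  case Types.DOUBLE               => "DoubleType"
  case Types.TIMESTAMP            => "TimestampType"
  case other => throw new IllegalArgumentException(s"unsupported java.sql.Types code: $other")
}

// (name, Spark type, nullable) triples, roughly one StructField per column.
def resolveSchema(columns: Seq[ColumnMeta]): Seq[(String, String, Boolean)] =
  columns.map(c => (c.name, toSparkType(c.sqlType), c.nullable))

println(schemaQuery("test.abc"))
println(resolveSchema(Seq(
  ColumnMeta("id", Types.INTEGER, nullable = false),
  ColumnMeta("name", Types.VARCHAR, nullable = true)
)))
```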
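(2)-buildScan(requiredColumns: Array[String], filters: Array[Filter])
Reads the table's rows. Spark passes in the columns the query needs and the filters it could push down:
```scala
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
  // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
  JDBCRDD.scanTable(
    sparkSession.sparkContext,
    schema,
    requiredColumns,
    filters,
    parts,
    jdbcOptions).asInstanceOf[RDD[Row]]
}
```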
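Note that scanTable does not return the rows themselves. It builds and returns a JDBCRDD, which is an RDD[InternalRow]; its doc comment reads "Build and return JDBCRDD from the given information." buildScan then casts that RDD to RDD[Row].
Digging further into JDBCRDD, the actual work happens in its compute method: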
```scala
val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause"
stmt = conn.prepareStatement(sqlText,
    ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
rs = stmt.executeQuery()
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
```
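Here, too, the ResultSet is obtained through plain JDBC. The result set is not traversed at this point. Instead it is handed to an iterator, a NextIterator, shown below. My guess is that the benefit is laziness: rows are only read, one at a time through getNext(), when something actually consumes them, for example when they are written to a database. The alternative would be to read everything up front and hold it before the DataFrame is produced, which could take a lot of memory and might not fit at all. setFetchSize additionally tells the JDBC driver how many rows to fetch per round trip, though how closely a driver honors that hint varies.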
```scala
private[spark] def resultSetToSparkInternalRows(
    resultSet: ResultSet,
    schema: StructType,
    inputMetrics: InputMetrics): Iterator[InternalRow] = {
  new NextIterator[InternalRow] {
    private[this] val rs = resultSet
    private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
    private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))

    override protected def close(): Unit = {
      try {
        rs.close()
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
    }

    override protected def getNext(): InternalRow = {
      if (rs.next()) {
        inputMetrics.incRecordsRead(1)
        var i = 0
        while (i < getters.length) {
          getters(i).apply(rs, mutableRow, i)
          if (rs.wasNull) mutableRow.setNullAt(i)
          i = i + 1
        }
        mutableRow
      } else {
        finished = true
        null.asInstanceOf[InternalRow]
      }
    }
  }
}
```
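Spark's NextIterator is private[spark], so code outside Spark cannot extend it. That is presumably why the Excel example later in this article uses its own ExcelNextIterator. The sketch below is a minimal re-implementation in the same spirit:

- subclasses implement getNext() and close();
- getNext() sets finished when the input is exhausted;
- close() runs exactly once, as soon as the end is reached.

LazyNextIterator and the FakeSource demo are hypothetical names. The demo shows that rows are only pulled when the consumer asks for them.

```scala
// Minimal NextIterator-style base class (modeled on, not copied from, Spark's).
abstract class LazyNextIterator[U] extends Iterator[U] {
  private var gotNext = false
  private var nextValue: U = null.asInstanceOf[U]
  private var closed = false
  protected var finished = false

  /** Produce the next element, or set `finished = true` and return anything. */
  protected def getNext(): U
  /** Release the underlying resource (ResultSet, InputStream, ...). */
  protected def close(): Unit

  def closeIfNeeded(): Unit =
    if (!closed) { closed = true; close() }

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) closeIfNeeded() // close as soon as the source is exhausted
      gotNext = true
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}

// Demo: a fake source that records how many rows were actually pulled.
final class FakeSource(data: Seq[Int]) {
  private var pos = 0
  var reads = 0
  var closeCount = 0
  def fetch(): Option[Int] =
    if (pos < data.length) { reads += 1; pos += 1; Some(data(pos - 1)) } else None
}

def rows(src: FakeSource): Iterator[Int] = new LazyNextIterator[Int] {
  override protected def getNext(): Int = src.fetch() match {
    case Some(v) => v
    case None    => finished = true; 0
  }
  override protected def close(): Unit = src.closeCount += 1
}

val src = new FakeSource(Seq(1, 2, 3, 4, 5))
val it = rows(src)                             // nothing has been pulled yet
val firstTwo = List(it.next(), it.next())
val readsAfterTwo = src.reads                  // only two rows pulled so far
val rest = it.toList                           // drains the source; close() fires once
println(s"firstTwo=$firstTwo readsAfterTwo=$readsAfterTwo rest=$rest closed=${src.closeCount}")
// prints: firstTwo=List(1, 2) readsAfterTwo=2 rest=List(3, 4, 5) closed=1
```

One difference from this sketch: Spark's resultSetToSparkInternalRows returns the same mutableRow object on every call. A consumer that wants to keep rows around therefore has to copy them.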
4-Usage
```scala
val jdbcDF = spark.read.jdbc(url, tableName, prop)
jdbcDF.createOrReplaceTempView("test")
spark.sql("select * from test").show()
```
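When the query spark.sql("select * from test") actually runs (here, show() triggers it), Spark calls the buildScan method above. It passes the columns the query needs and the filters it can push down.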
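For a query like select a, b from test where b > 1, JDBCRDD turns the required columns and pushed-down filters into the SELECT built in compute. The sketch below models that translation in plain Scala:

- column names are quoted (backticks, as the MySQL dialect does);
- an empty column list becomes SELECT 1;
- each compiled filter is wrapped in parentheses, and the results are joined with AND.

The filter case classes and helper functions are hypothetical simplifications of Spark's compileFilter, which handles many more filter types and takes quoting from the dialect.

```scala
// Hypothetical, simplified mirrors of Spark's pushed-down filters.
sealed trait Filter
final case class EqualTo(attribute: String, value: Any) extends Filter
final case class GreaterThan(attribute: String, value: Any) extends Filter
final case class IsNotNull(attribute: String) extends Filter
final case class Unsupported(attribute: String) extends Filter // stands in for anything not compilable

def quote(col: String): String = s"`$col`" // MySQL-style identifier quoting

def literal(v: Any): String = v match {
  case s: String => "'" + s.replace("'", "''") + "'" // escape embedded single quotes
  case other     => other.toString
}

// None means "cannot be translated to SQL"; Spark would evaluate that filter itself.
def compileFilter(f: Filter): Option[String] = f match {
  case EqualTo(a, v)     => Some(s"${quote(a)} = ${literal(v)}")
  case GreaterThan(a, v) => Some(s"${quote(a)} > ${literal(v)}")
  case IsNotNull(a)      => Some(s"${quote(a)} IS NOT NULL")
  case _                 => None
}

def buildSql(table: String, requiredColumns: Seq[String], filters: Seq[Filter]): String = {
  // No required columns (e.g. for count(*)) still needs a valid select list.
  val columnList = if (requiredColumns.isEmpty) "1" else requiredColumns.map(quote).mkString(",")
  val where = filters.flatMap(f => compileFilter(f).toList).map(p => s"($p)").mkString(" AND ")
  val whereClause = if (where.isEmpty) "" else s" WHERE $where"
  s"SELECT $columnList FROM $table$whereClause"
}

// select a, b from test where b > 1
println(buildSql("test", Seq("a", "b"), Seq(GreaterThan("b", 1))))
// prints: SELECT `a`,`b` FROM test WHERE (`b` > 1)
```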
Implementing a simple external data source yourself
Code
1-Implement our own provider by extending RelationProvider (and SchemaRelationProvider)
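```scala
// Must live in the package passed to format(...) in the test below.
package com.fayayo.excel.spark.sql.datasource

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

// Spark finds this class by appending ".DefaultSource" to the package name given to format(...).
class DefaultSource extends RelationProvider with SchemaRelationProvider {
  // No user schema: pass null so the relation falls back to its default schema.
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    createRelation(sqlContext, parameters, null)
  }

  // User schema supplied via .schema(...).
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String],
                              schema: StructType): BaseRelation = {
    parameters.get("path") match {
      case Some(p) => new ExcelRDDDatasourceRelation(sqlContext, p, schema)
      case _ => throw new IllegalArgumentException("Path is required for custom-datasource format!!")
    }
  }
}
```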
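parameters.get("path") finds the file because load(path) is shorthand for option("path", path).load(). In the Spark versions I have looked at, the options also reach createRelation wrapped in a case-insensitive map (Spark's CaseInsensitiveMap). The sketch below models that lookup. CaseInsensitiveOptions and requirePath are hypothetical, and the "/data/demo.xlsx" path is made up for the example.

```scala
// Hypothetical model of how options reach createRelation: keys are matched
// case-insensitively, and load(path) has already stored the path under "path".
final class CaseInsensitiveOptions(original: Map[String, String]) {
  private val lowered: Map[String, String] =
    original.map { case (k, v) => k.toLowerCase(java.util.Locale.ROOT) -> v }
  def get(key: String): Option[String] = lowered.get(key.toLowerCase(java.util.Locale.ROOT))
}

// What the DefaultSource above does with the options, as a plain function.
def requirePath(params: CaseInsensitiveOptions): String =
  params.get("path") match {
    case Some(p) => p
    case None    => throw new IllegalArgumentException("Path is required for custom-datasource format!!")
  }

// Roughly what .load("F:\\Excel\\demo03.xlsx") ends up passing in.
val fromLoad = new CaseInsensitiveOptions(Map("path" -> "F:\\Excel\\demo03.xlsx"))
// Same lookup when the key was set with different case, e.g. .option("PATH", ...).
val fromOption = new CaseInsensitiveOptions(Map("PATH" -> "/data/demo.xlsx"))
println(requirePath(fromLoad))
println(requirePath(fromOption))
```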
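2-Implement the concrete BaseRelation (with TableScan, which always returns every column of every row)
```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class ExcelRDDDatasourceRelation(override val sqlContext: SQLContext, path: String, excelSchema: StructType)
  extends BaseRelation with TableScan with Serializable {

  // Use the user-specified schema if there is one, else three string columns a, b, c.
  override def schema: StructType = {
    if (excelSchema != null) {
      excelSchema
    } else {
      StructType(
        StructField("a", StringType, false) ::
        StructField("b", StringType, true) ::
        StructField("c", StringType, true) :: Nil
      )
    }
  }

  override def buildScan(): RDD[Row] = {
    println("TableScan: buildScan called...")
    val rowSchema = schema // evaluated on the driver; only this value is shipped to executors
    // One (filePath, PortableDataStream) pair per file; each file becomes an iterator of rows.
    val rdd = sqlContext.sparkContext.binaryFiles(path)
    rdd.flatMap { x =>
      // Runs on executors, where no SparkContext is usable; compute() does not use it.
      new ExcelUtils().compute(null, rowSchema, x)
    }
  }
}
```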
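3-Implement the core reading logic, returning an iterator
```scala
import java.io.DataInputStream
import java.util

import com.monitorjbl.xlsx.StreamingReader // assumed: monitorjbl excel-streaming-reader
import org.apache.poi.ss.usermodel.Workbook
import org.apache.spark.SparkContext
import org.apache.spark.input.PortableDataStream
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

import scala.collection.JavaConverters._

class ExcelUtils {
  private var inputStream: DataInputStream = null
  private var workbook: Workbook = null

  // sc is not used here (and is not usable on executors).
  def compute(sc: SparkContext, schema: StructType,
              path: (String, PortableDataStream)): Iterator[Row] = {
    val (pathName, stream) = path
    println("pathName:" + pathName)
    inputStream = stream.open()
    // Stream the workbook, keeping at most 100 rows in memory at a time.
    workbook = StreamingReader.builder().rowCacheSize(100).bufferSize(4096).open(inputStream)
    val sheetRs = workbook.getSheetAt(0).iterator()
    resultSetToSparkRows(sheetRs, schema)
  }

  def resultSetToSparkRows(sheetRs: util.Iterator[org.apache.poi.ss.usermodel.Row],
                           schema: StructType): Iterator[Row] = {
    // ExcelNextIterator is the author's own class (see the repository); it follows the
    // same getNext()/close()/finished contract as Spark's NextIterator.
    new ExcelNextIterator[Row] {
      override protected def close(): Unit = {
        println("close stream")
        workbook.close()
        inputStream.close()
      }

      override protected def getNext(): Row = {
        if (sheetRs.hasNext) {
          val r = sheetRs.next()
          // Only cells that physically exist are visited; assumes every cell holds text.
          Row.fromSeq(r.asScala.map(_.getStringCellValue).toSeq)
        } else {
          finished = true
          null.asInstanceOf[Row]
        }
      }
    }
  }
}
```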
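One caveat with the loop in getNext(): iterating a POI row visits only the cells that exist. A row with an empty middle cell therefore yields fewer values than the schema has fields, and the later values shift left. A common fix is to place each value by its column index (Cell.getColumnIndex in POI) and pad the gaps with nulls. The sketch below shows just that alignment step in plain Scala. alignToSchema and its (index, value) input are hypothetical, not part of the author's code.

```scala
// Hypothetical helper: place sparse (columnIndex, value) pairs into a fixed-width
// row, padding missing cells with null and dropping cells beyond the schema width.
def alignToSchema(cells: Seq[(Int, String)], width: Int): Seq[String] = {
  val row = Array.fill[String](width)(null)
  cells.foreach { case (idx, value) =>
    if (idx >= 0 && idx < width) row(idx) = value
  }
  row.toSeq
}

// A sheet row where column 1 ("b") is empty, so only cells 0 and 2 exist.
val sparse = Seq(0 -> "a1", 2 -> "c1")
val aligned = alignToSchema(sparse, width = 3)
println(aligned.mkString("[", ", ", "]")) // prints: [a1, null, c1]
```

Inside getNext(), you would collect (c.getColumnIndex, c.getStringCellValue) pairs for the row's cells and pass alignToSchema(pairs, schema.length) to Row.fromSeq.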
4-Test
```scala
import java.util.Properties

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object ExcelApp extends App {
  println("Application started...")
  val conf = new SparkConf().setAppName("spark-custom-datasource")
  val spark = SparkSession.builder().config(conf).master("local").getOrCreate()

  // Resolved to com.fayayo.excel.spark.sql.datasource.DefaultSource.
  val df = spark.read.format("com.fayayo.excel.spark.sql.datasource").load("F:\\Excel\\demo03.xlsx")
  df.createOrReplaceTempView("test")
  spark.sql("select * from test").show()

  // Append the Excel rows to the MySQL table test.abc.
  val prop = new Properties()
  prop.put("user", "root")
  prop.put("password", "root123")
  df.write.mode("append").jdbc("jdbc:mysql://localhost:3306/test", "test.abc", prop)
  println("Application Ended...")
}
```
As you can see, the data has been written into the MySQL table.
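What df.write.mode("append").jdbc(...) does depends on the save mode and on whether the target table already exists. The plain-Scala sketch below summarizes that decision table as I understand it for Spark 2.x's JDBC provider:

- if the table is missing, it is created from the DataFrame's schema and then filled;
- Overwrite drops and recreates the table, or truncates it when the truncate option is set.

The Mode objects, writePlan and the action strings are hypothetical, not Spark APIs.

```scala
// Hypothetical model of the JDBC write decision (Spark 2.x JdbcRelationProvider).
sealed trait Mode
case object Append extends Mode
case object Overwrite extends Mode
case object ErrorIfExists extends Mode
case object Ignore extends Mode

def writePlan(mode: Mode, tableExists: Boolean, truncate: Boolean = false): Seq[String] =
  if (!tableExists) Seq("CREATE TABLE", "INSERT rows")
  else mode match {
    case Append => Seq("INSERT rows")
    case Overwrite =>
      // simplified: Spark also checks that the dialect's TRUNCATE does not cascade
      if (truncate) Seq("TRUNCATE TABLE", "INSERT rows")
      else Seq("DROP TABLE", "CREATE TABLE", "INSERT rows")
    case ErrorIfExists => throw new IllegalStateException("table already exists")
    case Ignore        => Seq.empty // leave the existing table untouched
  }

// The test app above: mode "append" against test.abc.
println(writePlan(Append, tableExists = true))
println(writePlan(Append, tableExists = false))
```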
Source code
GitHub: https://github.com/lizu18xz/spark-extend-dataSource