事由
上周工作中遇到一个bug,现象是一个spark streaming的job会不定期地hang住,不退出也不继续运行。这个job经是用pyspark写的,以kafka为数据源,会在每个batch结束时将统计结果写入mysql。经过排查,我们在driver进程中发现有有若干线程都出于Sl状态(睡眠状态),进而使用gdb调试发现了一处死锁。
这是MySQLdb库旧版本中的一处bug,在此不再赘述,有兴趣的可以看这个issue。不过这倒是提起了我对另外一件事的兴趣,就是driver进程——严格的说应该是driver进程的python子进程——中的这些线程是从哪来的?当然,这些线程的存在很容易理解,我们开启了spark.streaming.concurrentJobs参数,有多个batch可以同时执行,每个线程对应一个batch。但翻遍pyspark的python代码,都没有找到有相关线程启动的地方,于是简单调研了一下pyspark到底是怎么工作的,做个记录。
本文概括
Py4J的线程模型
pyspark基本原理(driver端)
CPython中的deque的线程安全
涉及软件版本
spark: 2.1.0
py4j: 0.10.4
Py4J
spark是由scala语言编写的,pyspark并没有像豆瓣开源的dpark用python复刻了spark,而只是提供了一层可以与原生JVM通信的python API,Py4J就是python与JVM之间的这座桥梁。这个库分为Java和Python两部分,基本原理是:
Java部分,通过
py4j.GatewayServer监听一个tcp socket(记做server_socket)Python部分,所有对JVM中对象的访问或者方法的调用,都是通过
py4j.JavaGateway向上面这个socket完成的。另外,Python部分在创建
JavaGateway对象时,可以选择同时创建一个CallbackServer,它会在Python这册监听一个tcp socket(记做callback_socket),用来给Java回调Python代码提供一条渠道。Py4J提供了一套文本协议用来在tcp socket间传递命令。
pyspark driver工作流程
首先,一个spark job被提交后,如果被判定这是一个python的job,spark driver会找到相应的入口,即
org.apache.spark.deploy.PythonRunner的main函数,这个函数中会启动GatewayServer
// Launch a Py4J gateway server for the process to connect to; this will let it see our
// Java system properties and such
val gatewayServer = new py4j.GatewayServer(null, 0) val thread = new Thread(new Runnable() { override def run(): Unit = Utils.logUncaughtExceptions {
gatewayServer.start()
}
})
thread.setName("py4j-gateway-init")
thread.setDaemon(true)
thread.start()然后,会创建一个Python子进程来运行我们提交上来的python入口文件,并把刚才
GatewayServer监听的那个端口写入到子进程的环境变量中去(这样Python才知道要通过那个端口访问JVM)
// Launch Python process
val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava) val env = builder.environment()
env.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) // pass conf spark.pyspark.python to python process, the only way to pass info to
// python process is through environment variable.
sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronizePython子进程这边,我们是通过pyspark提供的python API编写的这个程序,在创建
SparkContext(python)时,会初始化_gateway变量(JavaGateway对象)和_jvm变量(JVMView对象)
@classmethod def _ensure_initialized(cls, instance=None, gateway=None, conf=None): """ Checks whether a SparkContext is initialized or not. Throws error if a SparkContext is already running. """ with SparkContext._lock: if not SparkContext._gateway: SparkContext._gateway = gateway or launch_gateway(conf) SparkContext._jvm = SparkContext._gateway.jvm if instance: if (SparkContext._active_spark_context and SparkContext._active_spark_context != instance): currentMaster = SparkContext._active_spark_context.master currentAppName = SparkContext._active_spark_context.appName callsite = SparkContext._active_spark_context._callsite # Raise error if there is already a running Spark context raise ValueError( "Cannot run multiple SparkContexts at once; " "existing SparkContext(app=%s, master=%s)" " created by %s at %s:%s " % (currentAppName, currentMaster, callsite.function, callsite.file, callsite.linenum)) else: SparkContext._active_spark_context = instance
其中launch_gateway函数可见pyspark/java_gateway.py。
上面初始化的这个
_jvm对象值得一说,在pyspark中很多对JVM的调用其实都是通过它来进行的,比如很多python种对应的spark对象都有一个_jsc变量,它是JVM中的SparkContext对象在Python中的影子,它是这么初始化的
def _initialize_context(self, jconf): """ Initialize SparkContext in function to allow subclass specific initialization """ return self._jvm.JavaSparkContext(jconf)
这里_jvm为什么能直接调用JavaSparkContext这个JVM环境中的构造函数呢?我们看JVMView中的__getattr__方法:
def __getattr__(self, name):
if name == UserHelpAutoCompletion.KEY: return UserHelpAutoCompletion()
answer = self._gateway_client.send_command(
proto.REFLECTION_COMMAND_NAME +
proto.REFL_GET_UNKNOWN_SUB_COMMAND_NAME + name + "\n" + self._id + "\n" + proto.END_COMMAND_PART) if answer == proto.SUCCESS_PACKAGE: return JavaPackage(name, self._gateway_client, jvm_id=self._id) elif answer.startswith(proto.SUCCESS_CLASS): return JavaClass(
answer[proto.CLASS_FQN_START:], self._gateway_client) else: raise Py4JError("{0} does not exist in the JVM".format(name))self._gateway_client.send_command其实就是向server_socket发送访问对象请求的命令了,最后根据响应值生成不同类型的影子对象,针对我们这里的JavaSparkContext,就是一个JavaClass对象。这个系列的类型还包括了JavaMember,JavaPackage等等,他们也通过__getattr__来实现Java对象属性访问以及方法的调用。
我们刚才介绍Py4j时说过Python端在创建JavaGateway时,可以选择同时创建一个
CallbackClient,默认情况下,一个普通的pyspark job是不会启动回调服务的,因为用不着,所有的交互都是Python --> JVM这种模式的。那什么时候需要呢?streaming job就需要(具体流程我们稍后介绍),这就(终于!)引出了我们今天主要讨论的Py4J线程模型的问题。
Py4J线程模型
我们已经知道了Python与JVM双方向的通信分别是通过server_socket和callack_socket来完成的,这两个socket的处理模型都是多线程模型,即,每收到一个连接就启动一个线程来处理。我们只看Python --> JVM这条通路的情况,另外一边是一样的
Server端(Java)
protected void processSocket(Socket socket) { try { this.lock.lock(); if(!this.isShutdown) {
socket.setSoTimeout(this.readTimeout);
Py4JServerConnection gatewayConnection = this.createConnection(this.gateway, socket); this.connections.add(gatewayConnection); this.fireConnectionStarted(gatewayConnection);
}
} catch (Exception var6) { this.fireConnectionError(var6);
} finally { this.lock.unlock();
}
}继续看createConnection:
protected Py4JServerConnection createConnection(Gateway gateway, Socket socket) throws IOException {
GatewayConnection connection = new GatewayConnection(gateway, socket, this.customCommands, this.listeners);
connection.startConnection(); return connection;
}其中connection.startConnection其实就是创建了一个新线程,来负责处理这个连接。
Client端(Python)
我们来看GatewayClient中的send_command方法:
def send_command(self, command, retry=True, binary=False):
"""Sends a command to the JVM. This method is not intended to be
called directly by Py4J users. It is usually called by
:class:`JavaMember` instances.
:param command: the `string` command to send to the JVM. The command
must follow the Py4J protocol.
:param retry: if `True`, the GatewayClient tries to resend a message
if it fails.
:param binary: if `True`, we won't wait for a Py4J-protocol response
from the other end; we'll just return the raw connection to the
caller. The caller becomes the owner of the connection, and is
responsible for closing the connection (or returning it this
`GatewayClient` pool using `_give_back_connection`).
:rtype: the `string` answer received from the JVM (The answer follows
the Py4J protocol). The guarded `GatewayConnection` is also returned
if `binary` is `True`.
"""
connection = self._get_connection() try:
response = connection.send_command(command) if binary: return response, self._create_connection_guard(connection) else:
self._give_back_connection(connection) except Py4JNetworkError as pne: if connection:
reset = False
if isinstance(pne.cause, socket.timeout):
reset = True
connection.close(reset) if self._should_retry(retry, connection, pne):
logging.info("Exception while sending command.", exc_info=True)
response = self.send_command(command, binary=binary) else:
logging.exception( "Exception while sending command.")
response = proto.ERROR return response这里这个self._get_connection是这么实现的
def _get_connection(self): if not self.is_connected:
raise Py4JNetworkError("Gateway is not connected.") try:
connection = self.deque.pop()
except IndexError:
connection = self._create_connection() return connection这里使用了一个deque(也就是Python标准库中的collections.deque)来维护一个连接池,如果有空闲的连接,就可以直接使用,如果没有,就新建一个连接。现在问题来了,如果deque不是线程安全的,那么这段代码在多线程环境就会有问题。那么deque是不是线程安全的呢?
deque的线程安全
当然是了,Py4J当然不会犯这样的低级错误,我们看标准库的文档:
Deques support thread-safe, memory efficient appends and pops from either side of the deque with approximately the same O(1) performance in either direction.
是线程安全的,不过措辞有点模糊,没有明确指出哪些方法是线程安全的,不过可以明确的是至少append的pop都是。之所以去查一下,是因为我也有点含糊,因为Python标准库还有另外一个Queue.Queue,在多线程编程中经常使用,肯定是线程安全的,于是很容易误以为deque不是线程安全的,所以我们才要一个新的Queue。这个问题,推荐阅读stackoverflow上Jonathan的这个答案——他的回答不是被采纳的最高票,不过我认为他的回答比高票更有说服力
高票答案一直强调说
deque是线程安全的这个事实是个意外,是CPython中存在GIL造成的,其他Python解释器就不一定遵守。关于这一点我是不认同的,deque在CPython中的实现确实依赖的GIL才变成了线程安全的,但deque的双端append的pop是线程安全的这件事是白纸黑字写在Python文档中的,其他虚拟机的实现必须遵守,否则就不能称之为合格的Python实现。那为什么还要有一个内部显式用了锁来做线程同步的
Queue.Queue呢?Jonathan给出的回答是Queue的put和get可以是blocking的,而deque不行,这样一来,当你需要在多个线程中进行通信时(比如最简单的一个Producer - Consumer模式的实现),Queue往往是最佳选择。
关于deque是否是线程安全这个问题,我将调研的结果写在了这个知乎问题的答案下Python中的deque是线程安全的吗?,就不在赘述了,这篇文章已经太长了。
关于Py4J线程模型的问题,还可以参考官方文档中的解释。
pyspark streaming与CallbackServer
刚才提到,如果是streaming的job,GatewayServer在初始化时会同时创建一个CallbackServer,提供JVM --> Python这条通路。
@classmethod
def _ensure_initialized(cls):
SparkContext._ensure_initialized()
gw = SparkContext._gateway
java_import(gw.jvm, "org.apache.spark.streaming.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") # start callback server
# getattr will fallback to JVM, so we cannot test by hasattr()
if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
gw.callback_server_parameters.eager_load = True
gw.callback_server_parameters.daemonize = True
gw.callback_server_parameters.daemonize_connections = True
gw.callback_server_parameters.port = 0
gw.start_callback_server(gw.callback_server_parameters)
cbport = gw._callback_server.server_socket.getsockname()[1]
gw._callback_server.port = cbport # gateway with real port
gw._python_proxy_port = gw._callback_server.port # get the GatewayServer object in JVM by ID
jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) # update the port of CallbackClient with real port
jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) # register serializer for TransformFunction
# it happens before creating SparkContext when loading from checkpointing
cls._transformerSerializer = TransformFunctionSerializer(
SparkContext._active_spark_context, CloudPickleSerializer(), gw)为什么需要这样呢?一个streaming job通常需要调用foreachRDD,并提供一个函数,这个函数会在每个batch被回调:
def foreachRDD(self, func): """ Apply a function to each RDD in this DStream. """ if func.__code__.co_argcount == 1: old_func = func func = lambda t, rdd: old_func(rdd) jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) api = self._ssc._jvm.PythonDStream api.callForeachRDD(self._jdstream, jfunc)
这里,Python函数func被封装成了一个TransformFunction对象,在scala端spark也定义了同样接口一个trait:
/**
* Interface for Python callback function which is used to transform RDDs
*/private[python] trait PythonTransformFunction { def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] /**
* Get the failure, if any, in the last call to `call`.
*
* @return the failure message if there was a failure, or `null` if there was no failure.
*/
def getLastFailure: String}这样是Py4J提供的机制,这样就可以让JVM通过这个影子接口回调Python中的对象了,下面就是scala中的callForeachRDD函数,它把PythonTransformFunction又封装了一层成为scala中的TransformFunction, 但不管如何封装,最后都会调用PythonTransformFunction接口中的call方法完成对Python的回调。
/**
* helper function for DStream.foreachRDD(),
* cannot be `foreachRDD`, it will confusing py4j
*/
def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) {
val func = new TransformFunction((pfunc))
jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
}所以,终于要回答这个问题了,我们一开始看到的driver中的多个线程是怎么来的?
python调用
foreachRDD提供一个TranformFunction给scala端scala端调用自己的
foreachRDD进行正常的spark streaming作业由于我们开启了
spark.streaming.concurrentJobs,多个batch可以同时运行,这在scala端是通过线程池来进行的,每个batch都需要回调Python中的TranformFunction,而按照我们之前介绍的Py4J线程模型,多个并发的回调会发现没有可用的socket连接而生成新的,而在CallbackServer(Python)这端,每个新连接都会创建一个新线程来处理。这样就出现了driver的Python进程中出现多个线程的现象。
作者:Garfieldog
链接:https://www.jianshu.com/p/013fe44422c9
随时随地看视频