事由
上周工作中遇到一个bug,现象是一个spark streaming的job会不定期地hang住,不退出也不继续运行。这个job经是用pyspark写的,以kafka为数据源,会在每个batch结束时将统计结果写入mysql。经过排查,我们在driver进程中发现有有若干线程都出于Sl状态(睡眠状态),进而使用gdb调试发现了一处死锁。
这是MySQLdb库旧版本中的一处bug,在此不再赘述,有兴趣的可以看这个issue。不过这倒是提起了我对另外一件事的兴趣,就是driver进程——严格的说应该是driver进程的python子进程——中的这些线程是从哪来的?当然,这些线程的存在很容易理解,我们开启了spark.streaming.concurrentJobs参数,有多个batch可以同时执行,每个线程对应一个batch。但翻遍pyspark的python代码,都没有找到有相关线程启动的地方,于是简单调研了一下pyspark到底是怎么工作的,做个记录。
本文概括
Py4J的线程模型
pyspark基本原理(driver端)
CPython中的deque的线程安全
涉及软件版本
spark: 2.1.0
py4j: 0.10.4
Py4J
spark是由scala语言编写的,pyspark并没有像豆瓣开源的dpark用python复刻了spark,而只是提供了一层可以与原生JVM通信的python API,Py4J就是python与JVM之间的这座桥梁。这个库分为Java和Python两部分,基本原理是:
Java部分,通过
py4j.GatewayServer
监听一个tcp socket(记做server_socket)Python部分,所有对JVM中对象的访问或者方法的调用,都是通过
py4j.JavaGateway
向上面这个socket完成的。另外,Python部分在创建
JavaGateway
对象时,可以选择同时创建一个CallbackServer
,它会在Python这册监听一个tcp socket(记做callback_socket),用来给Java回调Python代码提供一条渠道。Py4J提供了一套文本协议用来在tcp socket间传递命令。
pyspark driver工作流程
首先,一个spark job被提交后,如果被判定这是一个python的job,spark driver会找到相应的入口,即
org.apache.spark.deploy.PythonRunner
的main
函数,这个函数中会启动GatewayServer
// Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such val gatewayServer = new py4j.GatewayServer(null, 0) val thread = new Thread(new Runnable() { override def run(): Unit = Utils.logUncaughtExceptions { gatewayServer.start() } }) thread.setName("py4j-gateway-init") thread.setDaemon(true) thread.start()
然后,会创建一个Python子进程来运行我们提交上来的python入口文件,并把刚才
GatewayServer
监听的那个端口写入到子进程的环境变量中去(这样Python才知道要通过那个端口访问JVM)
// Launch Python process val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava) val env = builder.environment() env.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) // pass conf spark.pyspark.python to python process, the only way to pass info to // python process is through environment variable. sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _)) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
Python子进程这边,我们是通过pyspark提供的python API编写的这个程序,在创建
SparkContext
(python)时,会初始化_gateway
变量(JavaGateway
对象)和_jvm
变量(JVMView
对象)
@classmethod def _ensure_initialized(cls, instance=None, gateway=None, conf=None): """ Checks whether a SparkContext is initialized or not. Throws error if a SparkContext is already running. """ with SparkContext._lock: if not SparkContext._gateway: SparkContext._gateway = gateway or launch_gateway(conf) SparkContext._jvm = SparkContext._gateway.jvm if instance: if (SparkContext._active_spark_context and SparkContext._active_spark_context != instance): currentMaster = SparkContext._active_spark_context.master currentAppName = SparkContext._active_spark_context.appName callsite = SparkContext._active_spark_context._callsite # Raise error if there is already a running Spark context raise ValueError( "Cannot run multiple SparkContexts at once; " "existing SparkContext(app=%s, master=%s)" " created by %s at %s:%s " % (currentAppName, currentMaster, callsite.function, callsite.file, callsite.linenum)) else: SparkContext._active_spark_context = instance
其中launch_gateway
函数可见pyspark/java_gateway.py
。
上面初始化的这个
_jvm
对象值得一说,在pyspark中很多对JVM的调用其实都是通过它来进行的,比如很多python种对应的spark对象都有一个_jsc
变量,它是JVM中的SparkContext对象在Python中的影子
,它是这么初始化的
def _initialize_context(self, jconf): """ Initialize SparkContext in function to allow subclass specific initialization """ return self._jvm.JavaSparkContext(jconf)
这里_jvm
为什么能直接调用JavaSparkContext
这个JVM环境中的构造函数呢?我们看JVMView
中的__getattr__
方法:
def __getattr__(self, name): if name == UserHelpAutoCompletion.KEY: return UserHelpAutoCompletion() answer = self._gateway_client.send_command( proto.REFLECTION_COMMAND_NAME + proto.REFL_GET_UNKNOWN_SUB_COMMAND_NAME + name + "\n" + self._id + "\n" + proto.END_COMMAND_PART) if answer == proto.SUCCESS_PACKAGE: return JavaPackage(name, self._gateway_client, jvm_id=self._id) elif answer.startswith(proto.SUCCESS_CLASS): return JavaClass( answer[proto.CLASS_FQN_START:], self._gateway_client) else: raise Py4JError("{0} does not exist in the JVM".format(name))
self._gateway_client.send_command
其实就是向server_socket
发送访问对象请求的命令了,最后根据响应值生成不同类型的影子
对象,针对我们这里的JavaSparkContext
,就是一个JavaClass
对象。这个系列的类型还包括了JavaMember
,JavaPackage
等等,他们也通过__getattr__
来实现Java对象属性访问以及方法的调用。
我们刚才介绍Py4j时说过Python端在创建JavaGateway时,可以选择同时创建一个
CallbackClient
,默认情况下,一个普通的pyspark job是不会启动回调服务的,因为用不着,所有的交互都是Python --> JVM
这种模式的。那什么时候需要呢?streaming job就需要(具体流程我们稍后介绍),这就(终于!)引出了我们今天主要讨论的Py4J线程模型的问题。
Py4J线程模型
我们已经知道了Python与JVM双方向的通信分别是通过server_socket
和callack_socket
来完成的,这两个socket的处理模型都是多线程模型,即,每收到一个连接就启动一个线程来处理。我们只看Python --> JVM
这条通路的情况,另外一边是一样的
Server端(Java)
protected void processSocket(Socket socket) { try { this.lock.lock(); if(!this.isShutdown) { socket.setSoTimeout(this.readTimeout); Py4JServerConnection gatewayConnection = this.createConnection(this.gateway, socket); this.connections.add(gatewayConnection); this.fireConnectionStarted(gatewayConnection); } } catch (Exception var6) { this.fireConnectionError(var6); } finally { this.lock.unlock(); } }
继续看createConnection
:
protected Py4JServerConnection createConnection(Gateway gateway, Socket socket) throws IOException { GatewayConnection connection = new GatewayConnection(gateway, socket, this.customCommands, this.listeners); connection.startConnection(); return connection; }
其中connection.startConnection
其实就是创建了一个新线程,来负责处理这个连接。
Client端(Python)
我们来看GatewayClient
中的send_command
方法:
def send_command(self, command, retry=True, binary=False): """Sends a command to the JVM. This method is not intended to be called directly by Py4J users. It is usually called by :class:`JavaMember` instances. :param command: the `string` command to send to the JVM. The command must follow the Py4J protocol. :param retry: if `True`, the GatewayClient tries to resend a message if it fails. :param binary: if `True`, we won't wait for a Py4J-protocol response from the other end; we'll just return the raw connection to the caller. The caller becomes the owner of the connection, and is responsible for closing the connection (or returning it this `GatewayClient` pool using `_give_back_connection`). :rtype: the `string` answer received from the JVM (The answer follows the Py4J protocol). The guarded `GatewayConnection` is also returned if `binary` is `True`. """ connection = self._get_connection() try: response = connection.send_command(command) if binary: return response, self._create_connection_guard(connection) else: self._give_back_connection(connection) except Py4JNetworkError as pne: if connection: reset = False if isinstance(pne.cause, socket.timeout): reset = True connection.close(reset) if self._should_retry(retry, connection, pne): logging.info("Exception while sending command.", exc_info=True) response = self.send_command(command, binary=binary) else: logging.exception( "Exception while sending command.") response = proto.ERROR return response
这里这个self._get_connection
是这么实现的
def _get_connection(self): if not self.is_connected: raise Py4JNetworkError("Gateway is not connected.") try: connection = self.deque.pop() except IndexError: connection = self._create_connection() return connection
这里使用了一个deque
(也就是Python标准库中的collections.deque
)来维护一个连接池,如果有空闲的连接,就可以直接使用,如果没有,就新建一个连接。现在问题来了,如果deque不是线程安全的,那么这段代码在多线程环境就会有问题。那么deque是不是线程安全的呢?
deque的线程安全
当然是了,Py4J当然不会犯这样的低级错误,我们看标准库的文档:
Deques support thread-safe, memory efficient appends and pops from either side of the deque with approximately the same O(1) performance in either direction.
是线程安全的,不过措辞有点模糊,没有明确指出哪些方法是线程安全的,不过可以明确的是至少append的pop都是。之所以去查一下,是因为我也有点含糊,因为Python标准库还有另外一个Queue.Queue
,在多线程编程中经常使用,肯定是线程安全的,于是很容易误以为deque不是线程安全的,所以我们才要一个新的Queue。这个问题,推荐阅读stackoverflow上Jonathan的这个答案——他的回答不是被采纳的最高票,不过我认为他的回答比高票更有说服力
高票答案一直强调说
deque是线程安全的
这个事实是个意外,是CPython中存在GIL造成的,其他Python解释器就不一定遵守。关于这一点我是不认同的,deque在CPython中的实现确实依赖的GIL才变成了线程安全的,但deque的双端append的pop是线程安全的
这件事是白纸黑字写在Python文档中的,其他虚拟机的实现必须遵守,否则就不能称之为合格的Python实现。那为什么还要有一个内部显式用了锁来做线程同步的
Queue.Queue
呢?Jonathan给出的回答是Queue
的put
和get
可以是blocking的,而deque
不行,这样一来,当你需要在多个线程中进行通信时(比如最简单的一个Producer - Consumer模式的实现),Queue
往往是最佳选择。
关于deque是否是线程安全这个问题,我将调研的结果写在了这个知乎问题的答案下Python中的deque是线程安全的吗?,就不在赘述了,这篇文章已经太长了。
关于Py4J线程模型的问题,还可以参考官方文档中的解释。
pyspark streaming与CallbackServer
刚才提到,如果是streaming的job,GatewayServer在初始化时会同时创建一个CallbackServer,提供JVM --> Python
这条通路。
@classmethod def _ensure_initialized(cls): SparkContext._ensure_initialized() gw = SparkContext._gateway java_import(gw.jvm, "org.apache.spark.streaming.*") java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") # start callback server # getattr will fallback to JVM, so we cannot test by hasattr() if "_callback_server" not in gw.__dict__ or gw._callback_server is None: gw.callback_server_parameters.eager_load = True gw.callback_server_parameters.daemonize = True gw.callback_server_parameters.daemonize_connections = True gw.callback_server_parameters.port = 0 gw.start_callback_server(gw.callback_server_parameters) cbport = gw._callback_server.server_socket.getsockname()[1] gw._callback_server.port = cbport # gateway with real port gw._python_proxy_port = gw._callback_server.port # get the GatewayServer object in JVM by ID jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) # update the port of CallbackClient with real port jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing cls._transformerSerializer = TransformFunctionSerializer( SparkContext._active_spark_context, CloudPickleSerializer(), gw)
为什么需要这样呢?一个streaming job通常需要调用foreachRDD
,并提供一个函数,这个函数会在每个batch被回调:
def foreachRDD(self, func): """ Apply a function to each RDD in this DStream. """ if func.__code__.co_argcount == 1: old_func = func func = lambda t, rdd: old_func(rdd) jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) api = self._ssc._jvm.PythonDStream api.callForeachRDD(self._jdstream, jfunc)
这里,Python函数func
被封装成了一个TransformFunction
对象,在scala端spark也定义了同样接口一个trait
:
/** * Interface for Python callback function which is used to transform RDDs */private[python] trait PythonTransformFunction { def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] /** * Get the failure, if any, in the last call to `call`. * * @return the failure message if there was a failure, or `null` if there was no failure. */ def getLastFailure: String}
这样是Py4J提供的机制,这样就可以让JVM通过这个影子接口
回调Python中的对象了,下面就是scala中的callForeachRDD
函数,它把PythonTransformFunction
又封装了一层成为scala中的TransformFunction
, 但不管如何封装,最后都会调用PythonTransformFunction
接口中的call
方法完成对Python的回调。
/** * helper function for DStream.foreachRDD(), * cannot be `foreachRDD`, it will confusing py4j */ def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) { val func = new TransformFunction((pfunc)) jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time)) }
所以,终于要回答这个问题了,我们一开始看到的driver中的多个线程是怎么来的?
python调用
foreachRDD
提供一个TranformFunction
给scala端scala端调用自己的
foreachRDD
进行正常的spark streaming作业由于我们开启了
spark.streaming.concurrentJobs
,多个batch可以同时运行,这在scala端是通过线程池来进行的,每个batch都需要回调Python中的TranformFunction
,而按照我们之前介绍的Py4J线程模型,多个并发的回调会发现没有可用的socket连接而生成新的,而在CallbackServer(Python)这端,每个新连接都会创建一个新线程来处理。这样就出现了driver的Python进程中出现多个线程的现象。
作者:Garfieldog
链接:https://www.jianshu.com/p/013fe44422c9