Those "Hacks" I Made to Spark
source link: http://www.jianshu.com/p/6ad4ea093e96?amp%3Butm_medium=referral
Preface
Over the past two years of building StreamingPro, I have inevitably had to make heavy enhancements to Spark. As I have complained before, Spark creates objects with new all over the place, which leaves most of its internal implementations with no way to be swapped out.
For example, SparkEnv has a field named closureSerializer, which is dedicated to serializing and deserializing tasks, including function closures. Let's look at how it is implemented internally:
val serializer = instantiateClassFromConf[Serializer](
  "spark.serializer", "org.apache.spark.serializer.JavaSerializer")
logDebug(s"Using serializer: ${serializer.getClass}")

val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey)

val closureSerializer = new JavaSerializer(conf)

val envInstance = new SparkEnv(
  .....
  closureSerializer,
  ....
Here a JavaSerializer is created with new directly and cannot be configured. Without changing the source code, there is no way to replace this implementation. Likewise, if I wanted to replace the Executor implementation, that is essentially impossible as well.
This year, two major pieces of work involved "hacking" Spark, by which I mean enhancing it without modifying its source code: we keep the stock release package and add new code on top of it.
Support for a Second RPC Layer
As we know, in Spark the only way to touch an Executor is through a Task. With the existing API you cannot directly operate on all Executors, or a chosen subset of them. For example, there is currently no way to ask every Executor to load a resource file. To operate on Executors directly, we need to build a new communication layer. So how do we do that?
First, set up a Backend on the Driver side. This part is fairly simple:
class PSDriverBackend(sc: SparkContext) extends Logging {

  val conf = sc.conf
  var psDriverRpcEndpointRef: RpcEndpointRef = null

  def createRpcEnv = {
    val isDriver = sc.env.executorId == SparkContext.DRIVER_IDENTIFIER
    val bindAddress = sc.conf.get(DRIVER_BIND_ADDRESS)
    val advertiseAddress = sc.conf.get(DRIVER_HOST_ADDRESS)
    var port = sc.conf.getOption("spark.ps.driver.port").getOrElse("7777").toInt
    val ioEncryptionKey = if (sc.conf.get(IO_ENCRYPTION_ENABLED)) {
      Some(CryptoStreamUtils.createKey(sc.conf))
    } else {
      None
    }
    logInfo(s"setup ps driver rpc env: ${bindAddress}:${port} clientMode=${!isDriver}")
    var createSuccess = false
    var count = 0
    val env = new AtomicReference[RpcEnv]()
    // Retry on the next port if the current one is already taken.
    while (!createSuccess && count < 10) {
      try {
        env.set(RpcEnv.create("PSDriverEndpoint", bindAddress, port, sc.conf,
          sc.env.securityManager, clientMode = !isDriver))
        createSuccess = true
      } catch {
        case e: Exception =>
          logInfo("fail to create rpcenv", e)
          count += 1
          port += 1
      }
    }
    if (env.get() == null) {
      logError(s"fail to create rpcenv finally with attempt ${count} ")
    }
    env.get()
  }

  def start() = {
    val env = createRpcEnv
    val pSDriverBackend = new PSDriverEndpoint(sc, env)
    psDriverRpcEndpointRef = env.setupEndpoint("ps-driver-endpoint", pSDriverBackend)
  }
}
With this, you can think of it as starting an RPC server on the Driver side. Running it is also very simple; just call it from the main program:
// parameter server should be enabled by default
if (!params.containsKey("streaming.ps.enable") || params.get("streaming.ps.enable").toString.toBoolean) {
  logger.info("ps enabled...")
  if (ss.sparkContext.isLocal) {
    localSchedulerBackend = new LocalPSSchedulerBackend(ss.sparkContext)
    localSchedulerBackend.start()
  } else {
    logger.info("start PSDriverBackend")
    psDriverBackend = new PSDriverBackend(ss.sparkContext)
    psDriverBackend.start()
  }
}
Here we need to cover both local mode and cluster mode.
The Driver now runs an RPC server, but how do we start one on the Executor side? At first glance there seems to be no hook on the Executor where we could start an RPC server. There actually is one, it is just quite tricky: Spark allows user-defined metrics and invokes the lifecycle methods of user-implemented metrics classes, so all we need to do is write a metrics Sink, start our RPC server inside it, and thereby trick Spark into bootstrapping it for us. The implementation looks like this:
class PSServiceSink(val property: Properties, val registry: MetricRegistry,
                    securityMgr: SecurityManager) extends Sink with Logging {

  def env = SparkEnv.get

  var psDriverUrl: String = null
  var psExecutorId: String = null
  var hostname: String = null
  var cores: Int = 0
  var appId: String = null
  val psDriverPort = 7777
  var psDriverHost: String = null
  var workerUrl: Option[String] = None
  val userClassPath = new mutable.ListBuffer[URL]()

  def parseArgs = {
    //val runtimeMxBean = ManagementFactory.getRuntimeMXBean();
    //var argv = runtimeMxBean.getInputArguments.toList
    var argv = System.getProperty("sun.java.command").split("\\s+").toList
    .....
    psDriverHost = host
    psDriverUrl = "spark://ps-driver-endpoint@" + psDriverHost + ":" + psDriverPort
  }

  parseArgs

  def createRpcEnv = {
    val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER
    val bindAddress = hostname
    val advertiseAddress = ""
    val port = env.conf.getOption("spark.ps.executor.port").getOrElse("0").toInt
    val ioEncryptionKey = if (env.conf.get(IO_ENCRYPTION_ENABLED)) {
      Some(CryptoStreamUtils.createKey(env.conf))
    } else {
      None
    }
    //logInfo(s"setup ps driver rpc env: ${bindAddress}:${port} clientMode=${!isDriver}")
    RpcEnv.create("PSExecutorBackend", bindAddress, port, env.conf,
      env.securityManager, clientMode = !isDriver)
  }

  override def start(): Unit = {
    new Thread(new Runnable {
      override def run(): Unit = {
        logInfo(s"delay PSExecutorBackend 3s")
        Thread.sleep(3000)
        logInfo(s"start PSExecutor;env:${env}")
        if (env.executorId != SparkContext.DRIVER_IDENTIFIER) {
          val rpcEnv = createRpcEnv
          val pSExecutorBackend = new PSExecutorBackend(env, rpcEnv, psDriverUrl,
            psExecutorId, hostname, cores)
          PSExecutorBackend.executorBackend = Some(pSExecutorBackend)
          rpcEnv.setupEndpoint("ps-executor-endpoint", pSExecutorBackend)
        }
      }
    }).start()
  }
  ...
}
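For Spark's MetricsSystem to instantiate this Sink on every executor, it still has to be registered with the metrics configuration. A minimal sketch of one way to do that is to pass spark.metrics.conf.* properties on the SparkConf; note that the fully qualified class name below is my assumption for illustration, not necessarily where the real class lives:

import org.apache.spark.SparkConf

// Register the custom Sink for all metrics instances ("*" covers both driver and
// executors; the driver-side instance is harmless because start() checks the
// executorId). Spark instantiates the class reflectively with the
// (Properties, MetricRegistry, SecurityManager) constructor shown above.
val conf = new SparkConf()
  .set("spark.metrics.conf.*.sink.pservice.class",
       "org.apache.spark.ps.cluster.PSServiceSink")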
At this point we can successfully start an RPC server on the Executor and connect it to the RPC server on the Driver. Now, without touching the Spark source code, you are free to write whatever communication code you like, which gives you much finer control over the Executors.
For example, you can implement something like the following in PSExecutorBackend:
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
  case Message.TensorFlowModelClean(modelPath) => {
    logInfo("clean tensorflow model")
    TFModelLoader.close(modelPath)
    context.reply(true)
  }
  case Message.CopyModelToLocal(modelPath, destPath) => {
    logInfo(s"copying model: ${modelPath} -> ${destPath}")
    HDFSOperator.copyToLocalFile(destPath, modelPath, true)
    context.reply(true)
  }
}
Then, inside Spark, you can invoke it like this:
val psDriverBackend = runtime.asInstanceOf[SparkRuntime].psDriverBackend
psDriverBackend.psDriverRpcEndpointRef.send(Message.TensorFlowModelClean("/tmp/ok"))
Pretty cool, right?
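The Message values used above come from StreamingPro itself. As a rough idea of their shape, here is a hedged sketch with the field names inferred from the pattern matches above; the real definitions may differ:

// A sketch of what the messages might look like (assumed, not the actual
// StreamingPro code): plain serializable case classes that the RpcEnv can ship
// between the Driver and Executor endpoints.
object Message {
  case class TensorFlowModelClean(modelPath: String)
  case class CopyModelToLocal(modelPath: String, destPath: String)
}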
Changing How Closures Are Serialized
Spark's task-scheduling overhead is very large. For one complex job, the business-logic code itself executes in roughly 3 to 7 ms, yet the end-to-end overhead of running it through Spark is around 1.3 s.
After some detailed digging, it turns out that when SparkContext transforms an RDD it runs a clean operation on the function, and by default that clean checks whether the function is serializable (it simply serializes it once; if no exception is thrown, it is deemed serializable). That serialization is quite expensive: the closure serializer is the hard-coded JavaSerializer used for functions and tasks and cannot be changed, and a single serialization pass takes about 200 ms. Optimizing this in local mode shaves roughly 600 ms off each request.
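To observe this cost on your own closures, here is a small self-contained sketch (my own illustration, not the author's original measurement) that times one pass of closure serialization with the same JavaSerializer that Spark hard-codes for closures:

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

object ClosureSerializationCost {
  def main(args: Array[String]): Unit = {
    // Capture a reasonably large local array so the closure is not trivial to serialize.
    val captured = Array.fill(1 << 20)(scala.util.Random.nextInt())
    val func = (i: Int) => captured(i % captured.length)

    // The same serializer type Spark uses for closures.
    val ser = new JavaSerializer(new SparkConf()).newInstance()
    val start = System.nanoTime()
    ser.serialize(func)
    println(s"one closure serialization took ${(System.nanoTime() - start) / 1e6} ms")
  }
}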
To be clear, this change targets local mode only. So how is it done?
Let's first see how Spark calls the serializer. In SparkContext, the clean function looks like this:
private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
  ClosureCleaner.clean(f, checkSerializable)
  f
}
It delegates to ClosureCleaner.clean, which invokes the serializer like this:
try {
  if (SparkEnv.get != null) {
    SparkEnv.get.closureSerializer.newInstance().serialize(func)
  }
} catch {
  case ex: Exception => throw new SparkException("Task not serializable", ex)
}
SparkEnv is created when SparkContext is initialized, and it holds the closureSerializer, which is constructed via new JavaSerializer. Since that serialization is too slow, and since in local mode we do not actually need to serialize closures at all, we want to find a way to replace the closureSerializer implementation. As complained earlier, it is hard-coded in the Spark source with no customization hook exposed, so once again we have to hack it.
First, we create a subclass of SparkEnv:
class WowSparkEnv(
  ....) extends SparkEnv(
Then we implement a custom Serializer:
class LocalNonOpSerializerInstance(javaD: SerializerInstance) extends SerializerInstance {

  private def isClosure(cls: Class[_]): Boolean = {
    cls.getName.contains("$anonfun$")
  }

  override def serialize[T: ClassTag](t: T): ByteBuffer = {
    if (isClosure(t.getClass)) {
      // Instead of serializing the closure, park it in an in-memory registry
      // and hand back its UUID as the "serialized" bytes.
      val uuid = UUID.randomUUID().toString
      LocalNonOpSerializerInstance.maps.put(uuid, t.asInstanceOf[AnyRef])
      ByteBuffer.wrap(uuid.getBytes())
    } else {
      javaD.serialize(t)
    }
  }

  override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
    val s = StandardCharsets.UTF_8.decode(bytes).toString()
    if (LocalNonOpSerializerInstance.maps.containsKey(s)) {
      LocalNonOpSerializerInstance.maps.remove(s).asInstanceOf[T]
    } else {
      bytes.flip()
      javaD.deserialize(bytes)
    }
  }

  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
    val s = StandardCharsets.UTF_8.decode(bytes).toString()
    if (LocalNonOpSerializerInstance.maps.containsKey(s)) {
      LocalNonOpSerializerInstance.maps.remove(s).asInstanceOf[T]
    } else {
      bytes.flip()
      javaD.deserialize(bytes, loader)
    }
  }

  override def serializeStream(s: OutputStream): SerializationStream = {
    javaD.serializeStream(s)
  }

  override def deserializeStream(s: InputStream): DeserializationStream = {
    javaD.deserializeStream(s)
  }
}
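The LocalNonOpSerializerInstance.maps registry referenced above lives in a companion object that is not shown here. A minimal sketch of it, assuming a plain concurrent map keyed by UUID (the real StreamingPro code may differ), could be:

import java.util.concurrent.ConcurrentHashMap

// Closures parked by serialize() are looked up and removed again by deserialize().
object LocalNonOpSerializerInstance {
  val maps = new ConcurrentHashMap[String, AnyRef]()
}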
Next, we wrap it in a LocalNonOpSerializer:
class LocalNonOpSerializer(conf: SparkConf) extends Serializer with Externalizable {

  val javaS = new JavaSerializer(conf)

  override def newInstance(): SerializerInstance = {
    new LocalNonOpSerializerInstance(javaS.newInstance())
  }

  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
    javaS.writeExternal(out)
  }

  override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
    javaS.readExternal(in)
  }
}
Now everything is in place except the last step: how do we get Spark to actually run this code? The approach is rather magical; we implement an "enhance" method:
def enhanceSparkEnvForAPIService(session: SparkSession) = {
  val env = SparkEnv.get
  // Create a new WowSparkEnv and replace its Serializer with our LocalNonOpSerializer.
  val wowEnv = new WowSparkEnv(
    .....
    new LocalNonOpSerializer(env.conf): Serializer,
    ....)
  // Swap the instance held by the SparkEnv object for our WowSparkEnv.
  SparkEnv.set(wowEnv)
  // Many components grabbed the previously created SparkEnv while SparkContext
  // was starting up, so a few more adjustments are needed.
  // First, stop the scheduler inside the already-running LocalSchedulerBackend.
  val localScheduler = session.sparkContext.schedulerBackend.asInstanceOf[LocalSchedulerBackend]
  val scheduler = ReflectHelper.field(localScheduler, "scheduler")
  val totalCores = localScheduler.totalCores
  localScheduler.stop()
  // Create a new LocalSchedulerBackend.
  val wowLocalSchedulerBackend = new WowLocalSchedulerBackend(
    session.sparkContext.getConf,
    scheduler.asInstanceOf[TaskSchedulerImpl],
    totalCores)
  wowLocalSchedulerBackend.start()
  // Replace SparkContext's _schedulerBackend with our implementation.
  ReflectHelper.field(session.sparkContext, "_schedulerBackend", wowLocalSchedulerBackend)
}
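Wiring it in is then just a matter of calling the enhancer early enough. An illustrative call site (not StreamingPro's actual bootstrap code) might look like:

import org.apache.spark.sql.SparkSession

// Run the enhancer right after the local-mode SparkSession is created and before
// any job is submitted, so later closure-clean checks go through LocalNonOpSerializer.
val session = SparkSession.builder()
  .master("local[*]")
  .appName("api-service")
  .getOrCreate()
enhanceSparkEnvForAPIService(session)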
Done.
There Is Actually a Lot More
For example, by default Spark kills a Python worker once it has been idle for one minute. In StreamingPro, however, these Python workers all have to load models, so starting them is very expensive and killing and restarting them is unbearable. With the same kind of hack, the idle timeout was made configurable. If you are interested, take a look at the relevant StreamingPro code.