diff --git a/src/main/scala/com/wintersoldier/linkinJMS/writeToMQ.scala b/src/main/scala/com/wintersoldier/linkinJMS/writeToMQ.scala
index d33acb0..20e862f 100644
--- a/src/main/scala/com/wintersoldier/linkinJMS/writeToMQ.scala
+++ b/src/main/scala/com/wintersoldier/linkinJMS/writeToMQ.scala
@@ -5,28 +5,26 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
 import javax.jms.{Connection, MessageProducer, Session, Topic}
 
-class writeToMQ(implicit spark : SparkSession) extends Serializable
-{
-  import spark.implicits._
+class writeToMQ(implicit sparkSession: SparkSession, df: DataFrame) extends Serializable {
+
   private var clientId: String = "iAmWritingClient"
   private var topicName: String = "writing2thisTopic"
   private var username: String = "username"
   private var password: String = "password"
-  private var brokerURL : String = "tcp://localhost:61616"
+  private var brokerURL: String = "tcp://localhost:61616"
 
   private var connection: Connection = _
   private var session: Session = _
   private var topic: Topic = _
   private var producer: MessageProducer = _
   private var latestBatchID = -1L
 
-  def __init__(brokerURL:String, clientId: String, topicName: String, username: String, password: String): Unit = {
+  def __init__(brokerURL: String, clientId: String, topicName: String, username: String, password: String): Unit = {
     this.brokerURL = brokerURL
     this.clientId = clientId
     this.topicName = topicName
     this.username = username
     this.password = password
-    createConnections()
   }
 
   def createConnections(): Unit = {
@@ -38,6 +36,18 @@ class writeToMQ(implicit spark : SparkSession) extends Serializable
     println("connection successful")
   }
 
+  def directWrite(): Unit = {
+    df.writeStream
+      .option("checkpointLocation", "/home/wintersoldier/Desktop/checkpoint")
+      .foreachBatch((batch: DataFrame, batchID: Long) => {
+        println("The batch ID is: " + batchID)
+        createConnections()
+        batch.show()
+        writeOn(batch, batchID)
+      })
+      .start
+      .awaitTermination()
+  }
 
   def writeOn(batch: DataFrame, batchId: Long): Unit = {
     if (batchId >= this.latestBatchID) {
@@ -53,22 +63,7 @@ class writeToMQ(implicit spark : SparkSession) extends Serializable
     }
   }
 
-  def directWrite(df: DataFrame): Unit =
-  {
-    createConnections()
-    df.writeStream
-      .option("checkpointLocation", "/home/wintersoldier/Desktop/checkpoint")
-      .foreachBatch((batch: DataFrame, batchID: Long) => {
-        println("The batch ID is: " + batchID)
-        batch.show()
-        writeOn(batch, batchID)
-      })
-      .start
-      .awaitTermination()
-  }
-
-  def closeConnection(): Unit =
-  {
+  def closeConnection(): Unit = {
     this.producer.close()
     this.connection.close()
     this.session.close()
diff --git a/src/main/scala/com/wintersoldier/sampleApp.scala b/src/main/scala/com/wintersoldier/sampleApp.scala
index d2fc459..3e9d3ac 100644
--- a/src/main/scala/com/wintersoldier/sampleApp.scala
+++ b/src/main/scala/com/wintersoldier/sampleApp.scala
@@ -1,37 +1,21 @@
 package com.wintersoldier
-
-import com.wintersoldier.linkinJMS.writeToMQ
 import org.apache.spark.sql.{DataFrame, SparkSession}
-//import org.apache.activemq.ActiveMQConnectionFactory
-//import org.apache.spark.sql.DataFrame
-//import javax.jms.{Connection, MessageProducer, Session, Topic}
-
-object sampleApp {
+object sampleApp
+{
   implicit val spark: SparkSession = SparkSession
     .builder()
     .appName("sampleApp")
-    .master("local[*]")
+    .master("local[2]")
     .getOrCreate()
 
   import spark.implicits._
 
-  spark.sparkContext.setCheckpointDir("/home/wintersoldier/Desktop/checkpoint")
+//  spark.sparkContext.setCheckpointDir("/home/wintersoldier/Desktop/checkpoint")
   spark.sparkContext.setLogLevel("ERROR")
 
-  // Writing to topic related part
-//  val clientId: String = "iAmWritingClient"
-//  val topicName: String = "writing2thisTopic"
-//  val username: String = "username"
-//  val password: String = "password"
-//  val connection: Connection = new ActiveMQConnectionFactory("username", "password", "tcp://localhost:61616").createConnection()
-//  val session: Session = connection.createSession(false, Session.AUTO_ACKNOWLEDGE)
-//  val topic: Topic = session.createTopic("writing2thisTopic")
-//  val producer: MessageProducer = session.createProducer(topic)
-//  private var latestBatchID = -1L
-
   def main(array: Array[String]): Unit = {
 
     val brokerUrl: String = "tcp://localhost:61616"
@@ -40,67 +24,37 @@ object sampleApp {
     val password: String = "password"
     val connectionType: String = "activemq"
     val clientId: String = "coldplay"
-    val acknowledge: String = "true"
+    val acknowledge: String = "false"
     val readInterval: String = "2000"
     val queueName : String = "sampleQ"
 
-    val df = spark
+    implicit val df: DataFrame = spark
       .readStream
       .format("com.wintersoldier.linkinJMS")
      .option("connection", connectionType)
       .option("brokerUrl", brokerUrl)
-      .option("topic", topicName)
+//      .option("topic", topicName)
+      .option("queue", queueName)
       .option("username", username)
       .option("password", password)
       .option("acknowledge", acknowledge)
       .option("clientId", clientId)
       .option("readInterval", readInterval)
-//      .option("queue", queueName)
       .load()
 
-    val ob = new writeToMQ()
-//    ob.directWrite(df)
-    //solution https://medium.com/swlh/spark-serialization-errors-e0eebcf0f6e6
 
     df.writeStream
-      .format("console")
-      .outputMode("append")
+//      .option("checkpointLocation", "/home/wintersoldier/Desktop/checkpoint")
+      .foreachBatch((batch: DataFrame, batchID: Long) => {
+        println("The batch ID is: " + batchID)
+        batch.show()
+      })
      .start
       .awaitTermination()
 
-//    df.writeStream
-//      //      .outputMode("append")
-//      //      .format("console")
-//      .option("checkpointLocation", "/home/wintersoldier/Desktop/checkpoint")
-//      .foreachBatch((batch: DataFrame, batchID: Long) => {
-//        println("The batch ID is: " + batchID)
-//        batch.show()
-////        writeOn(batch, batchID)
-//      })
-//      .start
-//      .awaitTermination()
-
-    // Closing the writing part
-//    producer.close()
-//    connection.close()
-//    session.close()
 
     spark.close()
   }
 
-//  def writeOn(batch: DataFrame, batchId: Long): Unit = {
-//    if (batchId >= latestBatchID) {
-//      batch.foreachPartition(rowIter => {
-//        rowIter.foreach(
-//          record => {
-//            val msg = this.session.createTextMessage(record.toString())
-//            producer.send(msg)
-//          })
-//      }
-//      )
-//    }
-//  }
-
 }
diff --git a/src/main/scala/org/apache/spark/sql/jms/JmsStreamingSource.scala b/src/main/scala/org/apache/spark/sql/jms/JmsStreamingSource.scala
index fce2479..0fcf364 100644
--- a/src/main/scala/org/apache/spark/sql/jms/JmsStreamingSource.scala
+++ b/src/main/scala/org/apache/spark/sql/jms/JmsStreamingSource.scala
@@ -6,7 +6,7 @@ import org.apache.spark.sql.execution.streaming.{Offset, Source}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.util.LongAccumulator
-
+import java.util.Random
 import javax.jms._
 
 import scala.collection.mutable.ListBuffer
@@ -17,38 +17,64 @@ class JmsStreamingSource(sqlContext: SQLContext,
                          failOnDataLoss: Boolean
                         ) extends Source {
 
-  lazy val RECEIVER_TIMEOUT: Long = parameters.getOrElse("readInterval", "1000").toLong
-  val clientName: String = parameters.getOrElse("clientId", "client000")
-  val topicName: String = parameters.getOrElse("topic", "")
-  val queueName: String = parameters.getOrElse("queue", "")
+  lazy val RECEIVER_INTERVAL: Long = parameters.getOrElse("readInterval", "1000").toLong
+  val clientName: String = parameters.getOrElse("clientId", "")
+  val srcType: String = parameters.getOrElse("mqSrcType", "")
+  val srcName: String = parameters.getOrElse("mqSrcName", "")
+
+  val connection: Connection = DefaultSource.connectionFactory(parameters).createConnection()
+  if (clientName != "")
+    connection.setClientID(clientName)
+  else
+    println("[WARN] No 'clientId' passed; this will result in a non-durable connection in the case of a Topic")
 
-  val connection: Connection = DefaultSource.connectionFactory(parameters).createConnection
-  connection.setClientID(clientName)
+  // if transacted is set to true then it does not matter which ACK mode you pass to createSession;
+  // you should just use session.commit()
+  // to recover, call session.recover() / session.rollback()
+  val session: Session = connection.createSession(true, Session.CLIENT_ACKNOWLEDGE)
 
-  val session: Session = connection.createSession(false, Session.CLIENT_ACKNOWLEDGE)
-  val typeOfSub: Int = getTheSub // 1-> Topic  0-> Queue
+  // storing the session into a variable for later use
+//  JmsSessionManager.setSession(sess = session)
+
+  val typeOfSub: Int = getTheSub // 1-> Topic  0-> Queue  2-> Not specified
 
   if (typeOfSub == 2) {
+    println("<><><><><><> [ERROR] type 'queue'/'topic' not passed <><><><><><>")
     throw new IllegalArgumentException
   }
-  private val subscriberT: Option[TopicSubscriber] = if (typeOfSub == 1) Some(session.createDurableSubscriber(session.createTopic(topicName), clientName)) else None
-  private val subscriberQ: Option[MessageConsumer] = if (typeOfSub == 0) Some(session.createConsumer(session.createQueue(queueName))) else None
+
+  private val subscriberT: Option[TopicSubscriber] = if (typeOfSub == 1) Some(getSubscriberT) else None
+  private val subscriberQ: Option[MessageConsumer] = if (typeOfSub == 0) Some(getConsumerQ) else None
+
+  private def getConsumerQ: MessageConsumer = {
+    session.createConsumer(session.createQueue(srcName))
+  }
 
   var counter: LongAccumulator = sqlContext.sparkContext.longAccumulator("counter")
 
+  def getSubscriberT: TopicSubscriber = {
+    if (clientName == "") {
+      val random: Random = new Random()
+      val rand1 = random.nextInt(10000)
+      val rand2 = random.nextInt(100000)
+      session.createDurableSubscriber(session.createTopic(srcName), s"default_client$rand1$rand2")
+    } else
+      session.createDurableSubscriber(session.createTopic(srcName), clientName)
+  }
+
   def getTheSub: Int = {
-    if (topicName.trim != "") {
+    if (srcType == "topic") {
       1
     }
-    else if (topicName.trim == "" && queueName.trim != "") {
+    else if (srcType == "queue") {
       0
     }
     else {
-      println("<><><><><><>ERROR: Neither 'queue' name nor 'topic' name passed<><><><><><>")
       2
     }
   }
 
-
   connection.start()
 
   override def getOffset: Option[Offset] = {
@@ -64,25 +90,27 @@ class JmsStreamingSource(sqlContext: SQLContext,
 
     def getTextMsg: TextMessage = {
       if (typeOfSub == 1)
-        subscriberT.get.receive(RECEIVER_TIMEOUT).asInstanceOf[TextMessage]
+        subscriberT.get.receive(RECEIVER_INTERVAL).asInstanceOf[TextMessage]
      else
-        subscriberQ.get.receive(RECEIVER_TIMEOUT).asInstanceOf[TextMessage]
+        subscriberQ.get.receive(RECEIVER_INTERVAL).asInstanceOf[TextMessage]
     }
 
     val textMsg: TextMessage = getTextMsg
 
-    // the below code is to test the acknowledgement of individual messages
-    /*    if(textMsg!=null && textMsg.getText == "testingFail")
-    {
-      val iota : Int = 3/0
-    }*/
+    // if(textMsg!=null && textMsg.getText == "testingFail")
+    // {
+    //   val iota : Int = 3/0
+    // }
 
     // I am using this line to acknowledge individual textMessages
+    // TODO: shift this
     if (parameters.getOrElse("acknowledge", "false").toBoolean && textMsg != null) {
-      textMsg.acknowledge()
+      // textMsg.acknowledge() // use this with CLIENT_ACKNOWLEDGE
+      session.commit() // use this with a transacted session
     }
+
     textMsg match {
       case null => break = false
       case _ => messageList += JmsMessage(textMsg)
@@ -94,8 +122,9 @@
         fromString(message.correlationId),
         fromString(message.jmsType),
         fromString(message.messageId),
-        fromString(message.queue)
-      ))
+        fromString(message.queue))
+      )
+
     val rdd = sqlContext.sparkContext.parallelize(internalRDD)
     sqlContext.internalCreateDataFrame(rdd, schema = schema, isStreaming = true)
@@ -107,6 +136,7 @@
 
   override def stop(): Unit = {
     session.close()
+    JmsSessionManager.setSession(null)
     connection.close()
   }
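
Review note: `JmsStreamingSource` now opens its session with `createSession(true, ...)`, i.e. transacted, so the acknowledge-mode argument is ignored and delivery is settled with `session.commit()` / `session.rollback()` instead of `textMsg.acknowledge()`. A minimal standalone sketch of that pattern (not code from this repo), assuming a local ActiveMQ broker at `tcp://localhost:61616`; the queue name and credentials are placeholders:

```scala
import org.apache.activemq.ActiveMQConnectionFactory

import javax.jms.{Connection, Session, TextMessage}

object TransactedConsumeDemo extends App {
  val connection: Connection =
    new ActiveMQConnectionFactory("username", "password", "tcp://localhost:61616").createConnection()
  connection.start()

  // transacted = true: the second argument is then ignored by the provider
  val session: Session = connection.createSession(true, Session.SESSION_TRANSACTED)
  val consumer = session.createConsumer(session.createQueue("sampleQ"))

  val textMsg = consumer.receive(2000).asInstanceOf[TextMessage]
  try {
    if (textMsg != null) println("got: " + textMsg.getText)
    session.commit()   // settles every message received since the last commit
  } catch {
    case _: Exception => session.rollback() // uncommitted messages are redelivered
  } finally {
    consumer.close()
    session.close()
    connection.close()
  }
}
```

Note that commit settles everything received on the session since the last commit, not a single message, which is why the per-message `textMsg.acknowledge()` call is commented out in the new code.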
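Second note: the commented-out `writeOn` in `sampleApp` captured the driver-side `session`/`producer` inside `foreachPartition`, which is the Spark serialization trap the removed Medium link discusses. A sketch of the executor-side alternative, again with placeholder broker URL, credentials, and topic name, assuming the ActiveMQ client is on the classpath:

```scala
import org.apache.activemq.ActiveMQConnectionFactory
import org.apache.spark.sql.{DataFrame, Row}

import javax.jms.Session

object PartitionedJmsWriter {
  // Builds the JMS objects inside each partition, so the closure only
  // captures serializable strings, never a Connection/Session/Producer.
  def writeBatch(batch: DataFrame): Unit = {
    batch.foreachPartition { rows: Iterator[Row] =>
      val connection =
        new ActiveMQConnectionFactory("username", "password", "tcp://localhost:61616").createConnection()
      val session = connection.createSession(false, Session.AUTO_ACKNOWLEDGE)
      val producer = session.createProducer(session.createTopic("writing2thisTopic"))
      try rows.foreach(row => producer.send(session.createTextMessage(row.toString)))
      finally {
        producer.close()
        session.close()
        connection.close()
      }
    }
  }
}
```

Opening a connection per partition is heavier than a shared driver-side producer, but it works on a cluster; pooling one connection per executor JVM would be the next refinement.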