Skip to content

Commit

Permalink
Fix the source task generation (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
elakito authored Oct 6, 2021
1 parent f3d2805 commit 0b3bb93
Show file tree
Hide file tree
Showing 7 changed files with 380 additions and 126 deletions.
58 changes: 19 additions & 39 deletions src/main/scala/com/sap/kafka/connect/source/GenericSourceTask.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ import com.sap.kafka.connect.source.querier.{BulkTableQuerier, IncrColTableQueri
import com.sap.kafka.utils.ExecuteWithExceptions
import org.apache.kafka.common.config.ConfigException
import org.apache.kafka.common.utils.{SystemTime, Time}
import org.apache.kafka.connect.errors.ConnectException
import org.apache.kafka.connect.source.{SourceRecord, SourceTask}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._

import scala.collection.mutable

abstract class GenericSourceTask extends SourceTask {
protected var configRawProperties: Option[util.Map[String, String]] = None
protected var config: BaseConfig = _
private val tableQueue = new mutable.Queue[TableQuerier]()
protected var time: Time = new SystemTime()
Expand All @@ -28,6 +29,7 @@ abstract class GenericSourceTask extends SourceTask {

override def start(props: util.Map[String, String]): Unit = {
log.info("Read records from HANA")
configRawProperties = Some(props)

ExecuteWithExceptions[Unit, ConfigException, HANAConfigMissingException] (
new HANAConfigMissingException("Couldn't start HANASourceTask due to configuration error")) { () =>
Expand All @@ -40,54 +42,31 @@ abstract class GenericSourceTask extends SourceTask {

val topics = config.topics

var tables: List[(String, String)] = Nil
if (topics.forall(topic => config.topicProperties(topic).keySet.contains("table.name"))) {
tables = topics.map(topic =>
(config.topicProperties(topic)("table.name"), topic))
}

var query: List[(String, String)] = Nil
if (topics.forall(topic => config.topicProperties(topic).keySet.contains("query"))) {
query = topics.map(topic =>
(config.topicProperties(topic)("query"), topic))
}

if (tables.isEmpty && query.isEmpty) {
throw new ConnectException("Invalid configuration: each HANASourceTask must have" +
" one table assigned to it")
}

val queryMode = config.queryMode

val tableOrQueryInfos = queryMode match {
case BaseConfigConstants.QUERY_MODE_TABLE =>
getTables(tables)
case BaseConfigConstants.QUERY_MODE_SQL =>
getQueries(query)
}
val tableOrQueryInfos = getTableOrQueryInfos()

val mode = config.mode
var offsets: util.Map[util.Map[String, String], util.Map[String, Object]] = null
var incrementingCols: List[String] = List()

if (mode.equals(BaseConfigConstants.MODE_INCREMENTING)) {
val partitions =
new util.ArrayList[util.Map[String, String]](tables.length)
new util.ArrayList[util.Map[String, String]](tableOrQueryInfos.length)

queryMode match {
case BaseConfigConstants.QUERY_MODE_TABLE =>
tableOrQueryInfos.foreach(tableInfo => {
val partition = new util.HashMap[String, String]()
partition.put(SourceConnectorConstants.TABLE_NAME_KEY, tableInfo._3)
partition.put(SourceConnectorConstants.TABLE_NAME_KEY, s"${tableInfo._1}${tableInfo._2}")
partitions.add(partition)
incrementingCols :+= config.topicProperties(tableInfo._4)("incrementing.column.name")
incrementingCols :+= config.topicProperties(tableInfo._3)("incrementing.column.name")
})
case BaseConfigConstants.QUERY_MODE_SQL =>
tableOrQueryInfos.foreach(queryInfo => {
val partition = new util.HashMap[String, String]()
partition.put(SourceConnectorConstants.QUERY_NAME_KEY, queryInfo._1)
partitions.add(partition)
incrementingCols :+= config.topicProperties(queryInfo._4)("incrementing.column.name")
incrementingCols :+= config.topicProperties(queryInfo._3)("incrementing.column.name")
})

}
Expand All @@ -99,7 +78,7 @@ abstract class GenericSourceTask extends SourceTask {
val partition = new util.HashMap[String, String]()
queryMode match {
case BaseConfigConstants.QUERY_MODE_TABLE =>
partition.put(SourceConnectorConstants.TABLE_NAME_KEY, tableOrQueryInfo._3)
partition.put(SourceConnectorConstants.TABLE_NAME_KEY, s"${tableOrQueryInfo._1}${tableOrQueryInfo._2}")
case BaseConfigConstants.QUERY_MODE_SQL =>
partition.put(SourceConnectorConstants.QUERY_NAME_KEY, tableOrQueryInfo._1)
case _ =>
Expand All @@ -108,13 +87,11 @@ abstract class GenericSourceTask extends SourceTask {

val offset = if (offsets == null) null else offsets.get(partition)

val topic = tableOrQueryInfo._4

if (mode.equals(BaseConfigConstants.MODE_BULK)) {
tableQueue += new BulkTableQuerier(queryMode, tableOrQueryInfo._1, tableOrQueryInfo._2, topic,
tableQueue += new BulkTableQuerier(queryMode, tableOrQueryInfo._1, tableOrQueryInfo._2, tableOrQueryInfo._3,
config, Some(jdbcClient))
} else if (mode.equals(BaseConfigConstants.MODE_INCREMENTING)) {
tableQueue += new IncrColTableQuerier(queryMode, tableOrQueryInfo._1, tableOrQueryInfo._2, topic,
tableQueue += new IncrColTableQuerier(queryMode, tableOrQueryInfo._1, tableOrQueryInfo._2, tableOrQueryInfo._3,
incrementingCols(count),
if (offset == null) null else offset.asScala.toMap,
config, Some(jdbcClient))
Expand Down Expand Up @@ -182,11 +159,14 @@ abstract class GenericSourceTask extends SourceTask {
null
}

protected def getTables(tables: List[Tuple2[String, String]])
: List[Tuple4[String, Int, String, String]]

protected def getQueries(query: List[(String, String)])
: List[Tuple4[String, Int, String, String]]
def getTableOrQueryInfos(): List[Tuple3[String, Int, String]] = {
val props = configRawProperties.get
props.asScala.filter(p => p._1.startsWith("_tqinfos.") && p._1.endsWith(".name")).map(
t => Tuple3(
t._2,
props.get(t._1.replace("name", "partition")).toInt,
props.get(t._1.replace("name", "topic")))).toList
}

protected def createJdbcClient(): HANAJdbcClient
}
Original file line number Diff line number Diff line change
@@ -1,25 +1,60 @@
package com.sap.kafka.connect.source.hana

import com.sap.kafka.client.hana.{HANAConfigInvalidInputException, HANAConfigMissingException, HANAJdbcClient}
import com.sap.kafka.connect.config.BaseConfigConstants
import com.sap.kafka.connect.config.hana.{HANAConfig, HANAParameters}
import com.sap.kafka.utils.ExecuteWithExceptions

import java.util
import org.apache.kafka.common.config.ConfigDef
import org.apache.kafka.common.config.{ConfigDef, ConfigException}
import org.apache.kafka.connect.connector.Task
import org.apache.kafka.connect.source.SourceConnector
import org.apache.kafka.connect.errors.ConnectException
import org.apache.kafka.connect.source.{SourceConnector, SourceConnectorContext}

import scala.collection.JavaConverters._

class HANASourceConnector extends SourceConnector {
private var configProperties: Option[util.Map[String, String]] = None

private var configRawProperties: Option[util.Map[String, String]] = None
private var hanaClient: HANAJdbcClient = _
private var tableOrQueryInfos: List[Tuple3[String, Int, String]] = _
private var configProperties: HANAConfig = _
override def context(): SourceConnectorContext = super.context()
override def version(): String = getClass.getPackage.getImplementationVersion

override def start(properties: util.Map[String, String]): Unit = {
configProperties = Some(properties)
configRawProperties = Some(properties)
configProperties = HANAParameters.getConfig(properties)
hanaClient = new HANAJdbcClient(configProperties)

val topics = configProperties.topics
var tables: List[(String, String)] = Nil
if (topics.forall(topic => configProperties.topicProperties(topic).keySet.contains("table.name"))) {
tables = topics.map(topic =>
(configProperties.topicProperties(topic)("table.name"), topic))
}
var query: List[(String, String)] = Nil
if (topics.forall(topic => configProperties.topicProperties(topic).keySet.contains("query"))) {
query = topics.map(topic =>
(configProperties.topicProperties(topic)("query"), topic))
}

if (tables.isEmpty && query.isEmpty) {
throw new ConnectException("Invalid configuration: each HANAConnector must have one table or query associated")
}

tableOrQueryInfos = configProperties.queryMode match {
case BaseConfigConstants.QUERY_MODE_TABLE =>
getTables(hanaClient, tables)
case BaseConfigConstants.QUERY_MODE_SQL =>
getQueries(query)
}
}

override def taskClass(): Class[_ <: Task] = classOf[HANASourceTask]

override def taskConfigs(maxTasks: Int): util.List[util.Map[String, String]] = {
(1 to maxTasks).map(c => configProperties.get).toList.asJava
val tableOrQueryGroups = createTableOrQueryGroups(tableOrQueryInfos, maxTasks)
createTaskConfigs(tableOrQueryGroups, configRawProperties.get).asJava
}

override def stop(): Unit = {
Expand All @@ -29,4 +64,92 @@ class HANASourceConnector extends SourceConnector {
override def config(): ConfigDef = {
new ConfigDef
}

private def getTables(hanaClient: HANAJdbcClient, tables: List[Tuple2[String, String]]) : List[Tuple3[String, Int, String]] = {
val connection = hanaClient.getConnection

// contains fullTableName, partitionNum, topicName
var tableInfos: List[Tuple3[String, Int, String]] = List()
val noOfTables = tables.size
var tablecount = 1

var stmtToFetchPartitions = s"SELECT SCHEMA_NAME, TABLE_NAME, PARTITION FROM SYS.M_CS_PARTITIONS WHERE "
tables.foreach(table => {
if (!(configProperties.topicProperties(table._2)("table.type") == BaseConfigConstants.COLLECTION_TABLE_TYPE)) {
table._1 match {
case BaseConfigConstants.TABLE_NAME_FORMAT(schema, tablename) =>
stmtToFetchPartitions += s"(SCHEMA_NAME = '$schema' AND TABLE_NAME = '$tablename')"

if (tablecount < noOfTables) {
stmtToFetchPartitions += " OR "
}
tablecount = tablecount + 1
case _ =>
throw new HANAConfigInvalidInputException("The table name is invalid. Does not follow naming conventions")
}
}
})

if (tablecount > 1) {
val stmt = connection.createStatement()
val partitionRs = stmt.executeQuery(stmtToFetchPartitions)

while (partitionRs.next()) {
val tableName = "\"" + partitionRs.getString(1) + "\".\"" + partitionRs.getString(2) + "\""
tableInfos :+= Tuple3(tableName, partitionRs.getInt(3),
tables.filter(table => table._1 == tableName).map(table => table._2).head.toString)
}
}

// fill tableInfo for tables whose entry is not in M_CS_PARTITIONS
val tablesInInfo = tableInfos.map(tableInfo => tableInfo._1)
val tablesToBeAdded = tables.filterNot(table => tablesInInfo.contains(table._1))

tablesToBeAdded.foreach(tableToBeAdded => {
if (configProperties.topicProperties(tableToBeAdded._2)("table.type") == BaseConfigConstants.COLLECTION_TABLE_TYPE) {
tableInfos :+= Tuple3(getTableName(tableToBeAdded._1)._2, 0, tableToBeAdded._2)
} else {
tableInfos :+= Tuple3(tableToBeAdded._1, 0, tableToBeAdded._2)
}
})

tableInfos
}

private def getQueries(queryTuple: List[(String, String)]): List[(String, Int, String)] =
queryTuple.map(query => (query._1, 0, query._2))

private def createTableOrQueryGroups(tableOrQueryInfos: List[Tuple3[String, Int, String]], count: Int)
: List[List[Tuple3[String, Int, String]]] = {
val groupSize = count match {
case c if c > tableOrQueryInfos.size => 1
case _ => ((tableOrQueryInfos.size + count - 1) / count)
}
tableOrQueryInfos.grouped(groupSize).toList
}

private def createTaskConfigs(tableOrQueryGroups: List[List[Tuple3[String, Int, String]]], config: java.util.Map[String, String])
: List[java.util.Map[String, String]] = {
tableOrQueryGroups.map(g => {
var gconfig = new java.util.HashMap[String,String](config)
for ((t, i) <- g.zipWithIndex) {
gconfig.put(s"_tqinfos.$i.name", t._1)
gconfig.put(s"_tqinfos.$i.partition", t._2.toString)
gconfig.put(s"_tqinfos.$i.topic", t._3)
}
gconfig
})
}

private def getTableName(tableName: String): (Option[String], String) = {
tableName match {
case BaseConfigConstants.TABLE_NAME_FORMAT(schema, table) =>
(Some(schema), table)
case BaseConfigConstants.COLLECTION_NAME_FORMAT(table) =>
(None, table)
case _ =>
throw new HANAConfigInvalidInputException(s"The table name mentioned in `{topic}.table.name` is invalid." +
s" Does not follow naming conventions")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,79 +16,11 @@ class HANASourceTask extends GenericSourceTask {

override def version(): String = getClass.getPackage.getImplementationVersion


override def createJdbcClient(): HANAJdbcClient = {
config match {
case hanaConfig: HANAConfig => new HANAJdbcClient(hanaConfig)
case _ => throw new RuntimeException("Cannot create HANA Jdbc Client")
}
}

override def getTables(tables: List[Tuple2[String, String]])
: List[Tuple4[String, Int, String, String]] = {
val connection = jdbcClient.getConnection

// contains fullTableName, partitionNum, fullTableName + partitionNum, topicName
var tableInfos: List[Tuple4[String, Int, String, String]] = List()
val noOfTables = tables.size
var tablecount = 1

var stmtToFetchPartitions = s"SELECT SCHEMA_NAME, TABLE_NAME, PARTITION FROM SYS.M_CS_PARTITIONS WHERE "
tables.foreach(table => {
if (!(config.topicProperties(table._2)("table.type") == BaseConfigConstants.COLLECTION_TABLE_TYPE)) {
table._1 match {
case BaseConfigConstants.TABLE_NAME_FORMAT(schema, tablename) =>
stmtToFetchPartitions += s"(SCHEMA_NAME = '$schema' AND TABLE_NAME = '$tablename')"

if (tablecount < noOfTables) {
stmtToFetchPartitions += " OR "
}
tablecount = tablecount + 1
case _ =>
throw new HANAConfigInvalidInputException("The table name is invalid. Does not follow naming conventions")
}
}
})

if (tablecount > 1) {
val stmt = connection.createStatement()
val partitionRs = stmt.executeQuery(stmtToFetchPartitions)

while (partitionRs.next()) {
val tableName = "\"" + partitionRs.getString(1) + "\".\"" + partitionRs.getString(2) + "\""
tableInfos :+= Tuple4(tableName, partitionRs.getInt(3), tableName + partitionRs.getInt(3),
tables.filter(table => table._1 == tableName)
.map(table => table._2).head.toString)
}
}

// fill tableInfo for tables whose entry is not in M_CS_PARTITIONS
val tablesInInfo = tableInfos.map(tableInfo => tableInfo._1)
val tablesToBeAdded = tables.filterNot(table => tablesInInfo.contains(table._1))

tablesToBeAdded.foreach(tableToBeAdded => {
if (config.topicProperties(tableToBeAdded._2)("table.type") == BaseConfigConstants.COLLECTION_TABLE_TYPE) {
tableInfos :+= Tuple4(getTableName(tableToBeAdded._1)._2, 0, getTableName(tableToBeAdded._1)._2 + "0", tableToBeAdded._2)
} else {
tableInfos :+= Tuple4(tableToBeAdded._1, 0, tableToBeAdded._1 + "0", tableToBeAdded._2)
}
})

tableInfos
}

override protected def getQueries(queryTuple: List[(String, String)]): List[(String, Int, String, String)] =
queryTuple.map(query => (query._1, 0, null, query._2))

private def getTableName(tableName: String): (Option[String], String) = {
tableName match {
case BaseConfigConstants.TABLE_NAME_FORMAT(schema, table) =>
(Some(schema), table)
case BaseConfigConstants.COLLECTION_NAME_FORMAT(table) =>
(None, table)
case _ =>
throw new HANAConfigInvalidInputException(s"The table name mentioned in `{topic}.table.name` is invalid." +
s" Does not follow naming conventions")
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.sap.kafka.connect.source

import com.sap.kafka.client.MetaSchema
import com.sap.kafka.connect.source.hana.HANASourceConnector
import org.apache.kafka.connect.data.Schema.Type
import org.apache.kafka.connect.data.{Field, Schema, Struct}
import org.apache.kafka.connect.source.SourceRecord
Expand All @@ -11,11 +12,14 @@ class HANASourceTaskConversionTest extends HANASourceTaskTestBase {

override def beforeAll(): Unit = {
super.beforeAll()
task.start(singleTableConfig())
connector = new HANASourceConnector
connector.start(singleTableConfig())
task.start(connector.taskConfigs(1).get(0))
}

override def afterAll(): Unit = {
task.stop()
connector.stop()
super.afterAll()
}

Expand Down
Loading

0 comments on commit 0b3bb93

Please sign in to comment.