diff --git a/src/main/scala/com/sap/kafka/connect/source/GenericSourceTask.scala b/src/main/scala/com/sap/kafka/connect/source/GenericSourceTask.scala
index affc937..5619e88 100644
--- a/src/main/scala/com/sap/kafka/connect/source/GenericSourceTask.scala
+++ b/src/main/scala/com/sap/kafka/connect/source/GenericSourceTask.scala
@@ -9,14 +9,15 @@ import com.sap.kafka.connect.source.querier.{BulkTableQuerier, IncrColTableQueri
 import com.sap.kafka.utils.ExecuteWithExceptions
 import org.apache.kafka.common.config.ConfigException
 import org.apache.kafka.common.utils.{SystemTime, Time}
-import org.apache.kafka.connect.errors.ConnectException
 import org.apache.kafka.connect.source.{SourceRecord, SourceTask}
 import org.slf4j.LoggerFactory
 
 import scala.collection.JavaConverters._
+
 import scala.collection.mutable
 
 abstract class GenericSourceTask extends SourceTask {
+  protected var configRawProperties: Option[util.Map[String, String]] = None
   protected var config: BaseConfig = _
   private val tableQueue = new mutable.Queue[TableQuerier]()
   protected var time: Time = new SystemTime()
@@ -28,6 +29,7 @@ abstract class GenericSourceTask extends SourceTask {
 
   override def start(props: util.Map[String, String]): Unit = {
     log.info("Read records from HANA")
+    configRawProperties = Some(props)
     ExecuteWithExceptions[Unit, ConfigException, HANAConfigMissingException] (
       new HANAConfigMissingException("Couldn't start HANASourceTask due to configuration error")) {
       () =>
@@ -40,31 +42,8 @@ abstract class GenericSourceTask extends SourceTask {
 
       val topics = config.topics
 
-      var tables: List[(String, String)] = Nil
-      if (topics.forall(topic => config.topicProperties(topic).keySet.contains("table.name"))) {
-        tables = topics.map(topic =>
-          (config.topicProperties(topic)("table.name"), topic))
-      }
-
-      var query: List[(String, String)] = Nil
-      if (topics.forall(topic => config.topicProperties(topic).keySet.contains("query"))) {
-        query = topics.map(topic =>
-          (config.topicProperties(topic)("query"), topic))
-      }
-
-      if (tables.isEmpty && query.isEmpty) {
-        throw new ConnectException("Invalid configuration: each HANASourceTask must have" +
-          " one table assigned to it")
-      }
-
      val queryMode = config.queryMode
-
-      val tableOrQueryInfos = queryMode match {
-        case BaseConfigConstants.QUERY_MODE_TABLE =>
-          getTables(tables)
-        case BaseConfigConstants.QUERY_MODE_SQL =>
-          getQueries(query)
-      }
+      val tableOrQueryInfos = getTableOrQueryInfos()
 
      val mode = config.mode
      var offsets: util.Map[util.Map[String, String], util.Map[String, Object]] = null
@@ -72,22 +51,22 @@ abstract class GenericSourceTask extends SourceTask {
 
      if (mode.equals(BaseConfigConstants.MODE_INCREMENTING)) {
        val partitions =
-          new util.ArrayList[util.Map[String, String]](tables.length)
+          new util.ArrayList[util.Map[String, String]](tableOrQueryInfos.length)
 
        queryMode match {
          case BaseConfigConstants.QUERY_MODE_TABLE =>
            tableOrQueryInfos.foreach(tableInfo => {
              val partition = new util.HashMap[String, String]()
-              partition.put(SourceConnectorConstants.TABLE_NAME_KEY, tableInfo._3)
+              partition.put(SourceConnectorConstants.TABLE_NAME_KEY, s"${tableInfo._1}${tableInfo._2}")
              partitions.add(partition)
-              incrementingCols :+= config.topicProperties(tableInfo._4)("incrementing.column.name")
+              incrementingCols :+= config.topicProperties(tableInfo._3)("incrementing.column.name")
            })
          case BaseConfigConstants.QUERY_MODE_SQL =>
            tableOrQueryInfos.foreach(queryInfo => {
              val partition = new util.HashMap[String, String]()
              partition.put(SourceConnectorConstants.QUERY_NAME_KEY, queryInfo._1)
              partitions.add(partition)
-              incrementingCols :+= config.topicProperties(queryInfo._4)("incrementing.column.name")
+              incrementingCols :+= config.topicProperties(queryInfo._3)("incrementing.column.name")
            })
        }
@@ -99,7 +78,7 @@ abstract class GenericSourceTask extends SourceTask {
        val partition = new util.HashMap[String, String]()
        queryMode match {
          case BaseConfigConstants.QUERY_MODE_TABLE =>
-            partition.put(SourceConnectorConstants.TABLE_NAME_KEY, tableOrQueryInfo._3)
+            partition.put(SourceConnectorConstants.TABLE_NAME_KEY, s"${tableOrQueryInfo._1}${tableOrQueryInfo._2}")
          case BaseConfigConstants.QUERY_MODE_SQL =>
            partition.put(SourceConnectorConstants.QUERY_NAME_KEY, tableOrQueryInfo._1)
          case _ =>
@@ -108,13 +87,11 @@ abstract class GenericSourceTask extends SourceTask {
 
        val offset = if (offsets == null) null else offsets.get(partition)
 
-        val topic = tableOrQueryInfo._4
-
        if (mode.equals(BaseConfigConstants.MODE_BULK)) {
-          tableQueue += new BulkTableQuerier(queryMode, tableOrQueryInfo._1, tableOrQueryInfo._2, topic,
+          tableQueue += new BulkTableQuerier(queryMode, tableOrQueryInfo._1, tableOrQueryInfo._2, tableOrQueryInfo._3,
            config, Some(jdbcClient))
        } else if (mode.equals(BaseConfigConstants.MODE_INCREMENTING)) {
-          tableQueue += new IncrColTableQuerier(queryMode, tableOrQueryInfo._1, tableOrQueryInfo._2, topic,
+          tableQueue += new IncrColTableQuerier(queryMode, tableOrQueryInfo._1, tableOrQueryInfo._2, tableOrQueryInfo._3,
            incrementingCols(count),
            if (offset == null) null else offset.asScala.toMap,
            config, Some(jdbcClient))
@@ -182,11 +159,14 @@ abstract class GenericSourceTask extends SourceTask {
     null
   }
 
-  protected def getTables(tables: List[Tuple2[String, String]])
-  : List[Tuple4[String, Int, String, String]]
-
-  protected def getQueries(query: List[(String, String)])
-  : List[Tuple4[String, Int, String, String]]
+  def getTableOrQueryInfos(): List[Tuple3[String, Int, String]] = {
+    val props = configRawProperties.get
+    props.asScala.filter(p => p._1.startsWith("_tqinfos.") && p._1.endsWith(".name")).map(
+      t => Tuple3(
+        t._2,
+        props.get(t._1.replace("name", "partition")).toInt,
+        props.get(t._1.replace("name", "topic")))).toList
+  }
 
   protected def createJdbcClient(): HANAJdbcClient
 }
diff --git a/src/main/scala/com/sap/kafka/connect/source/hana/HANASourceConnector.scala b/src/main/scala/com/sap/kafka/connect/source/hana/HANASourceConnector.scala
index 826d243..804a100 100644
--- a/src/main/scala/com/sap/kafka/connect/source/hana/HANASourceConnector.scala
+++ b/src/main/scala/com/sap/kafka/connect/source/hana/HANASourceConnector.scala
@@ -1,25 +1,60 @@
 package com.sap.kafka.connect.source.hana
 
+import com.sap.kafka.client.hana.{HANAConfigInvalidInputException, HANAConfigMissingException, HANAJdbcClient}
+import com.sap.kafka.connect.config.BaseConfigConstants
+import com.sap.kafka.connect.config.hana.{HANAConfig, HANAParameters}
+import com.sap.kafka.utils.ExecuteWithExceptions
+
 import java.util
 
-import org.apache.kafka.common.config.ConfigDef
+import org.apache.kafka.common.config.{ConfigDef, ConfigException}
 import org.apache.kafka.connect.connector.Task
-import org.apache.kafka.connect.source.SourceConnector
+import org.apache.kafka.connect.errors.ConnectException
+import org.apache.kafka.connect.source.{SourceConnector, SourceConnectorContext}
 
 import scala.collection.JavaConverters._
 
 class HANASourceConnector extends SourceConnector {
-  private var configProperties: Option[util.Map[String, String]] = None
-
+  private var configRawProperties: Option[util.Map[String, String]] = None
+  private var hanaClient: HANAJdbcClient = _
+  private var tableOrQueryInfos: List[Tuple3[String, Int, String]] = _
+  private var configProperties: HANAConfig = _
+  override def context(): SourceConnectorContext = super.context()
 
   override def version(): String = getClass.getPackage.getImplementationVersion
 
   override def start(properties: util.Map[String, String]): Unit = {
-    configProperties = Some(properties)
+    configRawProperties = Some(properties)
+    configProperties = HANAParameters.getConfig(properties)
+    hanaClient = new HANAJdbcClient(configProperties)
+
+    val topics = configProperties.topics
+    var tables: List[(String, String)] = Nil
+    if (topics.forall(topic => configProperties.topicProperties(topic).keySet.contains("table.name"))) {
+      tables = topics.map(topic =>
+        (configProperties.topicProperties(topic)("table.name"), topic))
+    }
+    var query: List[(String, String)] = Nil
+    if (topics.forall(topic => configProperties.topicProperties(topic).keySet.contains("query"))) {
+      query = topics.map(topic =>
+        (configProperties.topicProperties(topic)("query"), topic))
+    }
+
+    if (tables.isEmpty && query.isEmpty) {
+      throw new ConnectException("Invalid configuration: each HANAConnector must have one table or query associated")
+    }
+
+    tableOrQueryInfos = configProperties.queryMode match {
+      case BaseConfigConstants.QUERY_MODE_TABLE =>
+        getTables(hanaClient, tables)
+      case BaseConfigConstants.QUERY_MODE_SQL =>
+        getQueries(query)
+    }
   }
 
   override def taskClass(): Class[_ <: Task] = classOf[HANASourceTask]
 
   override def taskConfigs(maxTasks: Int): util.List[util.Map[String, String]] = {
-    (1 to maxTasks).map(c => configProperties.get).toList.asJava
+    val tableOrQueryGroups = createTableOrQueryGroups(tableOrQueryInfos, maxTasks)
+    createTaskConfigs(tableOrQueryGroups, configRawProperties.get).asJava
   }
 
   override def stop(): Unit = {
@@ -29,4 +64,92 @@ class HANASourceConnector extends SourceConnector {
   override def config(): ConfigDef = {
     new ConfigDef
   }
+
+  private def getTables(hanaClient: HANAJdbcClient, tables: List[Tuple2[String, String]]) : List[Tuple3[String, Int, String]] = {
+    val connection = hanaClient.getConnection
+
+    // contains fullTableName, partitionNum, topicName
+    var tableInfos: List[Tuple3[String, Int, String]] = List()
+    val noOfTables = tables.size
+    var tablecount = 1
+
+    var stmtToFetchPartitions = s"SELECT SCHEMA_NAME, TABLE_NAME, PARTITION FROM SYS.M_CS_PARTITIONS WHERE "
+    tables.foreach(table => {
+      if (!(configProperties.topicProperties(table._2)("table.type") == BaseConfigConstants.COLLECTION_TABLE_TYPE)) {
+        table._1 match {
+          case BaseConfigConstants.TABLE_NAME_FORMAT(schema, tablename) =>
+            stmtToFetchPartitions += s"(SCHEMA_NAME = '$schema' AND TABLE_NAME = '$tablename')"
+
+            if (tablecount < noOfTables) {
+              stmtToFetchPartitions += " OR "
+            }
+            tablecount = tablecount + 1
+          case _ =>
+            throw new HANAConfigInvalidInputException("The table name is invalid. Does not follow naming conventions")
+        }
+      }
+    })
+
+    if (tablecount > 1) {
+      val stmt = connection.createStatement()
+      val partitionRs = stmt.executeQuery(stmtToFetchPartitions)
+
+      while (partitionRs.next()) {
+        val tableName = "\"" + partitionRs.getString(1) + "\".\"" + partitionRs.getString(2) + "\""
+        tableInfos :+= Tuple3(tableName, partitionRs.getInt(3),
+          tables.filter(table => table._1 == tableName).map(table => table._2).head.toString)
+      }
+    }
+
+    // fill tableInfo for tables whose entry is not in M_CS_PARTITIONS
+    val tablesInInfo = tableInfos.map(tableInfo => tableInfo._1)
+    val tablesToBeAdded = tables.filterNot(table => tablesInInfo.contains(table._1))
+
+    tablesToBeAdded.foreach(tableToBeAdded => {
+      if (configProperties.topicProperties(tableToBeAdded._2)("table.type") == BaseConfigConstants.COLLECTION_TABLE_TYPE) {
+        tableInfos :+= Tuple3(getTableName(tableToBeAdded._1)._2, 0, tableToBeAdded._2)
+      } else {
+        tableInfos :+= Tuple3(tableToBeAdded._1, 0, tableToBeAdded._2)
+      }
+    })
+
+    tableInfos
+  }
+
+  private def getQueries(queryTuple: List[(String, String)]): List[(String, Int, String)] =
+    queryTuple.map(query => (query._1, 0, query._2))
+
+  private def createTableOrQueryGroups(tableOrQueryInfos: List[Tuple3[String, Int, String]], count: Int)
+  : List[List[Tuple3[String, Int, String]]] = {
+    val groupSize = count match {
+      case c if c > tableOrQueryInfos.size => 1
+      case _ => ((tableOrQueryInfos.size + count - 1) / count)
+    }
+    tableOrQueryInfos.grouped(groupSize).toList
+  }
+
+  private def createTaskConfigs(tableOrQueryGroups: List[List[Tuple3[String, Int, String]]], config: java.util.Map[String, String])
+  : List[java.util.Map[String, String]] = {
+    tableOrQueryGroups.map(g => {
+      var gconfig = new java.util.HashMap[String,String](config)
+      for ((t, i) <- g.zipWithIndex) {
+        gconfig.put(s"_tqinfos.$i.name", t._1)
+        gconfig.put(s"_tqinfos.$i.partition", t._2.toString)
+        gconfig.put(s"_tqinfos.$i.topic", t._3)
+      }
+      gconfig
+    })
+  }
+
+  private def getTableName(tableName: String): (Option[String], String) = {
+    tableName match {
+      case BaseConfigConstants.TABLE_NAME_FORMAT(schema, table) =>
+        (Some(schema), table)
+      case BaseConfigConstants.COLLECTION_NAME_FORMAT(table) =>
+        (None, table)
+      case _ =>
+        throw new HANAConfigInvalidInputException(s"The table name mentioned in `{topic}.table.name` is invalid." +
+ + s" Does not follow naming conventions") + } + } } \ No newline at end of file diff --git a/src/main/scala/com/sap/kafka/connect/source/hana/HANASourceTask.scala b/src/main/scala/com/sap/kafka/connect/source/hana/HANASourceTask.scala index 6cb6d0b..0fbbe39 100644 --- a/src/main/scala/com/sap/kafka/connect/source/hana/HANASourceTask.scala +++ b/src/main/scala/com/sap/kafka/connect/source/hana/HANASourceTask.scala @@ -16,7 +16,6 @@ class HANASourceTask extends GenericSourceTask { override def version(): String = getClass.getPackage.getImplementationVersion - override def createJdbcClient(): HANAJdbcClient = { config match { case hanaConfig: HANAConfig => new HANAJdbcClient(hanaConfig) @@ -24,71 +23,4 @@ class HANASourceTask extends GenericSourceTask { } } - override def getTables(tables: List[Tuple2[String, String]]) - : List[Tuple4[String, Int, String, String]] = { - val connection = jdbcClient.getConnection - - // contains fullTableName, partitionNum, fullTableName + partitionNum, topicName - var tableInfos: List[Tuple4[String, Int, String, String]] = List() - val noOfTables = tables.size - var tablecount = 1 - - var stmtToFetchPartitions = s"SELECT SCHEMA_NAME, TABLE_NAME, PARTITION FROM SYS.M_CS_PARTITIONS WHERE " - tables.foreach(table => { - if (!(config.topicProperties(table._2)("table.type") == BaseConfigConstants.COLLECTION_TABLE_TYPE)) { - table._1 match { - case BaseConfigConstants.TABLE_NAME_FORMAT(schema, tablename) => - stmtToFetchPartitions += s"(SCHEMA_NAME = '$schema' AND TABLE_NAME = '$tablename')" - - if (tablecount < noOfTables) { - stmtToFetchPartitions += " OR " - } - tablecount = tablecount + 1 - case _ => - throw new HANAConfigInvalidInputException("The table name is invalid. Does not follow naming conventions") - } - } - }) - - if (tablecount > 1) { - val stmt = connection.createStatement() - val partitionRs = stmt.executeQuery(stmtToFetchPartitions) - - while (partitionRs.next()) { - val tableName = "\"" + partitionRs.getString(1) + "\".\"" + partitionRs.getString(2) + "\"" - tableInfos :+= Tuple4(tableName, partitionRs.getInt(3), tableName + partitionRs.getInt(3), - tables.filter(table => table._1 == tableName) - .map(table => table._2).head.toString) - } - } - - // fill tableInfo for tables whose entry is not in M_CS_PARTITIONS - val tablesInInfo = tableInfos.map(tableInfo => tableInfo._1) - val tablesToBeAdded = tables.filterNot(table => tablesInInfo.contains(table._1)) - - tablesToBeAdded.foreach(tableToBeAdded => { - if (config.topicProperties(tableToBeAdded._2)("table.type") == BaseConfigConstants.COLLECTION_TABLE_TYPE) { - tableInfos :+= Tuple4(getTableName(tableToBeAdded._1)._2, 0, getTableName(tableToBeAdded._1)._2 + "0", tableToBeAdded._2) - } else { - tableInfos :+= Tuple4(tableToBeAdded._1, 0, tableToBeAdded._1 + "0", tableToBeAdded._2) - } - }) - - tableInfos - } - - override protected def getQueries(queryTuple: List[(String, String)]): List[(String, Int, String, String)] = - queryTuple.map(query => (query._1, 0, null, query._2)) - - private def getTableName(tableName: String): (Option[String], String) = { - tableName match { - case BaseConfigConstants.TABLE_NAME_FORMAT(schema, table) => - (Some(schema), table) - case BaseConfigConstants.COLLECTION_NAME_FORMAT(table) => - (None, table) - case _ => - throw new HANAConfigInvalidInputException(s"The table name mentioned in `{topic}.table.name` is invalid." 
+ - s" Does not follow naming conventions") - } - } } diff --git a/src/test/scala/com/sap/kafka/connect/source/HANASourceTaskConversionTest.scala b/src/test/scala/com/sap/kafka/connect/source/HANASourceTaskConversionTest.scala index e51f618..1781f1c 100644 --- a/src/test/scala/com/sap/kafka/connect/source/HANASourceTaskConversionTest.scala +++ b/src/test/scala/com/sap/kafka/connect/source/HANASourceTaskConversionTest.scala @@ -1,6 +1,7 @@ package com.sap.kafka.connect.source import com.sap.kafka.client.MetaSchema +import com.sap.kafka.connect.source.hana.HANASourceConnector import org.apache.kafka.connect.data.Schema.Type import org.apache.kafka.connect.data.{Field, Schema, Struct} import org.apache.kafka.connect.source.SourceRecord @@ -11,11 +12,14 @@ class HANASourceTaskConversionTest extends HANASourceTaskTestBase { override def beforeAll(): Unit = { super.beforeAll() - task.start(singleTableConfig()) + connector = new HANASourceConnector + connector.start(singleTableConfig()) + task.start(connector.taskConfigs(1).get(0)) } override def afterAll(): Unit = { task.stop() + connector.stop() super.afterAll() } diff --git a/src/test/scala/com/sap/kafka/connect/source/HANASourceTaskTestBase.scala b/src/test/scala/com/sap/kafka/connect/source/HANASourceTaskTestBase.scala index 2d1f9a2..8b41303 100644 --- a/src/test/scala/com/sap/kafka/connect/source/HANASourceTaskTestBase.scala +++ b/src/test/scala/com/sap/kafka/connect/source/HANASourceTaskTestBase.scala @@ -1,11 +1,10 @@ package com.sap.kafka.connect.source import java.util - import com.sap.kafka.client.MetaSchema import com.sap.kafka.connect.MockJdbcClient import com.sap.kafka.connect.config.hana.HANAParameters -import com.sap.kafka.connect.source.hana.HANASourceTask +import com.sap.kafka.connect.source.hana.{HANASourceConnector, HANASourceTask} import org.apache.kafka.common.utils.Time import org.apache.kafka.connect.source.SourceTaskContext import org.apache.kafka.connect.storage.OffsetStorageReader @@ -58,6 +57,7 @@ class HANASourceTaskTestBase extends FunSuite protected val SECOND_TOPIC = "test-second-topic" protected var time: Time = _ protected var taskContext: SourceTaskContext = _ + protected var connector: HANASourceConnector = _ protected var task: HANASourceTask = _ protected var jdbcClient: MockJdbcClient = _ diff --git a/src/test/scala/com/sap/kafka/connect/source/HANASourceTaskUpdateTest.scala b/src/test/scala/com/sap/kafka/connect/source/HANASourceTaskUpdateTest.scala index 4afd64b..b03985c 100644 --- a/src/test/scala/com/sap/kafka/connect/source/HANASourceTaskUpdateTest.scala +++ b/src/test/scala/com/sap/kafka/connect/source/HANASourceTaskUpdateTest.scala @@ -1,9 +1,8 @@ package com.sap.kafka.connect.source import java.util - import com.sap.kafka.client.MetaSchema -import com.sap.kafka.connect.source.hana.HANASourceTask +import com.sap.kafka.connect.source.hana.{HANASourceConnector, HANASourceTask} import org.apache.kafka.connect.data.{Field, Schema, SchemaBuilder, Struct} import org.apache.kafka.connect.source.SourceRecord import org.scalatest.BeforeAndAfterEach @@ -101,6 +100,7 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase test("bulk periodic load") { val connection = jdbcClient.getConnection + try { connection.setAutoCommit(true) val stmt = connection.createStatement() @@ -108,7 +108,9 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase val expectedSchema = SchemaBuilder.struct().name("expected schema") .field("id", Schema.INT32_SCHEMA) - task.start(singleTableConfig()) + val connector = 
+      connector.start(singleTableConfig())
+      task.start(connector.taskConfigs(1).get(0))
 
      var expectedData = new Struct(expectedSchema)
        .put("id", 1)
@@ -140,6 +142,9 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
          expectedSchema)
        count = count + 1
      })
+
+      task.stop()
+      connector.stop()
    } finally {
      connection.close()
    }
@@ -162,7 +167,9 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
        .name("expected schema for second table")
        .field("id", Schema.INT32_SCHEMA)
 
-      multiTableLoadTask.start(multiTableConfig())
+      val connector = new HANASourceConnector
+      connector.start(multiTableConfig())
+      multiTableLoadTask.start(connector.taskConfigs(1).get(0))
 
      val expectedDataForFirstTable = new Struct(expectedSchemaForSingleTable)
        .put("id", 1)
@@ -187,6 +194,8 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
            expectedSchemaForSecondTable)
        }
      })
+      task.stop()
+      connector.stop()
    } finally {
      connection.close()
    }
@@ -201,7 +210,9 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
      val expectedSchema = SchemaBuilder.struct().name("expected schema")
        .field("id", Schema.INT32_SCHEMA)
 
-      queryLoadTask.start(singleTableQueryConfig())
+      val connector = new HANASourceConnector
+      connector.start(singleTableQueryConfig())
+      queryLoadTask.start(connector.taskConfigs(1).get(0))
 
      var expectedData = new Struct(expectedSchema)
        .put("id", 1)
@@ -232,6 +243,8 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
          expectedSchema)
        count = count + 1
      })
+      task.stop()
+      connector.stop()
    } finally {
      connection.close()
    }
@@ -248,7 +261,9 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
        .field("id", Schema.INT32_SCHEMA)
        .field("name", Schema.STRING_SCHEMA)
      incrLoadTask.initialize(taskContext)
-      incrLoadTask.start(singleTableConfigInIncrementalMode(SINGLE_TABLE_NAME_FOR_INCR_LOAD, "id"))
+      val connector = new HANASourceConnector
+      connector.start(singleTableConfigInIncrementalMode(SINGLE_TABLE_NAME_FOR_INCR_LOAD, "id"))
+      incrLoadTask.start(connector.taskConfigs(1).get(0))
      var expectedData = new Struct(expectedSchema)
        .put("id", 1)
        .put("name", "Lukas")
@@ -280,6 +295,8 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
        compareData(expectedData, record.value().asInstanceOf[Struct],
          expectedSchema)
      })
+      task.stop()
+      connector.stop()
    } finally {
      connection.close()
    }
@@ -295,8 +312,10 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
      val expectedSchema = SchemaBuilder.struct().name("expected schema")
        .field("id", Schema.STRING_SCHEMA)
        .field("name", Schema.STRING_SCHEMA)
+      val connector = new HANASourceConnector
+      connector.start(singleTableConfigInIncrementalMode(SINGLE_TABLE_NAME_FOR_INCR2_LOAD, "id"))
      incr2LoadTask.initialize(taskContext)
-      incr2LoadTask.start(singleTableConfigInIncrementalMode(SINGLE_TABLE_NAME_FOR_INCR2_LOAD, "id"))
+      incr2LoadTask.start(connector.taskConfigs(1).get(0))
      var expectedData = new Struct(expectedSchema)
        .put("id", "1")
        .put("name", "Lukas")
@@ -328,6 +347,8 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
        compareData(expectedData, record.value().asInstanceOf[Struct],
          expectedSchema)
      })
+      task.stop()
+      connector.stop()
    } finally {
      connection.close()
    }
@@ -343,8 +364,10 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
      val expectedSchema = SchemaBuilder.struct().name("expected schema")
        .field("id", Schema.INT32_SCHEMA)
        .field("name", Schema.STRING_SCHEMA)
+      val connector = new HANASourceConnector
+      connector.start(singleTableConfigInIncrementalQueryMode())
      incrQueryLoadTask.initialize(taskContext)
-      incrQueryLoadTask.start(singleTableConfigInIncrementalQueryMode())
+      incrQueryLoadTask.start(connector.taskConfigs(1).get(0))
      var expectedData = new Struct(expectedSchema)
        .put("id", 1)
        .put("name", "Lukas")
@@ -376,6 +399,8 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
        compareData(expectedData, record.value().asInstanceOf[Struct],
          expectedSchema)
      })
+      task.stop()
+      connector.stop()
    } finally {
      connection.close()
    }
@@ -393,7 +418,9 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
 
      val expectedSchema = SchemaBuilder.struct().name("expected schema")
        .field("id", Schema.INT32_SCHEMA)
-      maxrowsLoadTask.start(singleTableMaxRowsConfig("2"))
+      val connector = new HANASourceConnector
+      connector.start(singleTableMaxRowsConfig("2"))
+      maxrowsLoadTask.start(connector.taskConfigs(1).get(0))
 
      var expectedData = new Struct(expectedSchema)
        .put("id", 1)
@@ -418,6 +445,8 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
        assert(records.size() === 2)
        verifyRecords(i-1, 2, records, expectedSchema)
      }
+      task.stop()
+      connector.stop()
    } finally {
      connection.close()
    }
@@ -435,8 +464,10 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
 
      val expectedSchema = SchemaBuilder.struct().name("expected schema")
        .field("id", Schema.INT32_SCHEMA)
+      val connector = new HANASourceConnector
+      connector.start(singleTableMaxRowsConfigInIncrementalMode("2"))
      maxrowsIncrLoadTask.initialize(taskContext)
-      maxrowsIncrLoadTask.start(singleTableMaxRowsConfigInIncrementalMode("2"))
+      maxrowsIncrLoadTask.start(connector.taskConfigs(1).get(0))
 
      var expectedData = new Struct(expectedSchema)
        .put("id", 1)
@@ -457,6 +488,8 @@ class HANASourceTaskUpdateTest extends HANASourceTaskTestBase
 
      assert(records.size() === 1)
      assert(maxrowsIncrLoadTask.poll() === null)
+      task.stop()
+      connector.stop()
    } finally {
      connection.close()
    }
diff --git a/src/test/scala/com/sap/kafka/connect/source/hana/HANASourceConnectorTest.scala b/src/test/scala/com/sap/kafka/connect/source/hana/HANASourceConnectorTest.scala
new file mode 100644
index 0000000..aabd475
--- /dev/null
+++ b/src/test/scala/com/sap/kafka/connect/source/hana/HANASourceConnectorTest.scala
@@ -0,0 +1,182 @@
+package com.sap.kafka.connect.source.hana
+
+import com.sap.kafka.client.MetaSchema
+import com.sap.kafka.connect.MockJdbcClient
+import com.sap.kafka.connect.config.hana.HANAParameters
+import org.apache.kafka.connect.data.{Field, Schema}
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import java.util
+
+class HANASourceConnectorTest extends AnyFunSuite with BeforeAndAfterAll {
+  val tmpdir = System.getProperty("java.io.tmpdir")
+  protected val TEST_CONNECTION_URL= s"jdbc:h2:file:$tmpdir/test2;INIT=CREATE SCHEMA IF NOT EXISTS TEST;DB_CLOSE_DELAY=-1"
+  private val TOPIC1 = "test-topic1"
+  private val TOPIC2 = "test-topic2"
+  private val TABLE_NAME_ONE_SOURCE = "\"TEST\".\"ONE_SOURCE\""
+  private val TABLE_CONFIG_SINGLE = tableConfigSingle
+  private val TABLE_NAME_TWO_SOURCE = "\"TEST\".\"TWO_SOURCE\""
+  private val TABLE_CONFIG_MULTIPLE = tableConfigMultiple
+  protected var jdbcClient: MockJdbcClient = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    jdbcClient = new MockJdbcClient(HANAParameters.getConfig(TABLE_CONFIG_SINGLE))
+    val connection = jdbcClient.getConnection
+    try {
+      connection.setAutoCommit(true)
+      val stmt = connection.createStatement
+
+      stmt.execute("DROP ALL OBJECTS DELETE FILES")
OBJECTS DELETE FILES") + stmt.execute("CREATE SCHEMA IF NOT EXISTS SYS") + + val fields = Seq(new Field("SCHEMA_NAME", 1, Schema.STRING_SCHEMA), + new Field("TABLE_NAME", 2, Schema.STRING_SCHEMA), + new Field("PARTITION", 3, Schema.INT32_SCHEMA)) + + jdbcClient.createTable(Some("SYS"), "M_CS_PARTITIONS", MetaSchema(null, fields), 3000) + stmt.execute("insert into \"SYS\".\"M_CS_PARTITIONS\" values('TEST', 'ONE_SOURCE', 0)") + stmt.execute("insert into \"SYS\".\"M_CS_PARTITIONS\" values('TEST', 'TWO_SOURCE', 1)") + stmt.execute("insert into \"SYS\".\"M_CS_PARTITIONS\" values('TEST', 'TWO_SOURCE', 2)") + } finally { + connection.close() + } + } + + override def afterAll(): Unit = { + val connection = jdbcClient.getConnection + try { + connection.setAutoCommit(true) + val stmt = connection.createStatement + stmt.execute("drop table \"SYS\".\"M_CS_PARTITIONS\"") + + stmt.execute("DROP ALL OBJECTS DELETE FILES") + } finally { + connection.close() + } + super.afterAll() + } + + test("one task for non-partitioned single table") { + val connector = new HANASourceConnector + try { + connector.start(TABLE_CONFIG_SINGLE) + val taskConfigs = connector.taskConfigs(1) + assert(taskConfigs.size() === 1) + verifySourceTasksConfigs(taskConfigs, + List( + List(Tuple3("\"TEST\".\"ONE_SOURCE\"", 0, "test-topic1")))) + } finally { + connector.stop() + } + } + + test("too many tasks for non-partitioned single table") { + val connector = new HANASourceConnector + try { + connector.start(TABLE_CONFIG_SINGLE) + val taskConfigs = connector.taskConfigs(3) + assert(taskConfigs.size === 1) + verifySourceTasksConfigs(taskConfigs, + List( + List(Tuple3("\"TEST\".\"ONE_SOURCE\"", 0, "test-topic1")))) + } finally { + connector.stop() + } + } + + test("multiple tasks for partitioned multiple tables") { + val connector = new HANASourceConnector + try { + connector.start(TABLE_CONFIG_MULTIPLE) + val taskConfigs = connector.taskConfigs(3) + assert(taskConfigs.size === 3) + verifySourceTasksConfigs(taskConfigs, + List( + List(Tuple3("\"TEST\".\"ONE_SOURCE\"", 0, "test-topic1")), + List(Tuple3("\"TEST\".\"TWO_SOURCE\"", 1, "test-topic2")), + List(Tuple3("\"TEST\".\"TWO_SOURCE\"", 2, "test-topic2")))) + } finally { + connector.stop() + } + } + + test("too many tasks for partitioned multiple tables") { + val connector = new HANASourceConnector + try { + connector.start(TABLE_CONFIG_MULTIPLE) + val taskConfigs = connector.taskConfigs(5) + assert(taskConfigs.size === 3) + verifySourceTasksConfigs(taskConfigs, + List( + List(Tuple3("\"TEST\".\"ONE_SOURCE\"", 0, "test-topic1")), + List(Tuple3("\"TEST\".\"TWO_SOURCE\"", 1, "test-topic2")), + List(Tuple3("\"TEST\".\"TWO_SOURCE\"", 2, "test-topic2")))) + } finally { + connector.stop() + } + } + + test("less tasks for partitioned multiple table") { + val connector = new HANASourceConnector + try { + connector.start(TABLE_CONFIG_MULTIPLE) + val taskConfigs = connector.taskConfigs(2) + assert(taskConfigs.size === 2) + verifySourceTasksConfigs(taskConfigs, + List( + List(Tuple3("\"TEST\".\"ONE_SOURCE\"", 0, "test-topic1"), Tuple3("\"TEST\".\"TWO_SOURCE\"", 1, "test-topic2")), + List(Tuple3("\"TEST\".\"TWO_SOURCE\"", 2, "test-topic2")))) + } finally { + connector.stop() + } + } + + def verifySourceTasksConfigs(taskConfigs: util.List[util.Map[String, String]], expected: List[List[Tuple3[String, Int, String]]]) : Unit = { + val tcit = taskConfigs.iterator + for (tq <- expected) { + if (tcit.hasNext) { + val task = new HANASourceTask + try { + task.start(tcit.next) + assert(tq === 
+        } finally {
+          task.stop()
+        }
+      } else {
+        fail("Unexpected number of tasks")
+      }
+    }
+  }
+
+  def tableConfigBase(): util.Map[String, String] = {
+    val props = new util.HashMap[String, String]()
+    props.put("connection.url", TEST_CONNECTION_URL)
+    props.put("connection.user", "sa")
+    props.put("connection.password", "sa")
+    props.put("mode", "bulk")
+    props
+  }
+
+  def tableConfigSingle(): util.Map[String, String] = {
+    val props = tableConfigBase
+    props.put("topics", TOPIC1)
+    props.put(s"$TOPIC1.table.name", TABLE_NAME_ONE_SOURCE)
+    props.put(s"$TOPIC1.partition.count", "1")
+    props.put(s"$TOPIC1.poll.interval.ms", "60000")
+    props
+  }
+
+  def tableConfigMultiple(): util.Map[String, String] = {
+    val props = tableConfigBase
+    props.put("topics", s"$TOPIC1,$TOPIC2")
+    props.put(s"$TOPIC1.table.name", TABLE_NAME_ONE_SOURCE)
+    props.put(s"$TOPIC1.partition.count", "1")
+    props.put(s"$TOPIC1.poll.interval.ms", "60000")
+    props.put(s"$TOPIC2.table.name", TABLE_NAME_TWO_SOURCE)
+    props.put(s"$TOPIC2.partition.count", "1")
+    props.put(s"$TOPIC2.poll.interval.ms", "60000")
+    props
+  }
+}