Skip to content

Commit

Permalink
Add Snowflake plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
sethjones348 committed Jul 15, 2024
1 parent a6f63b1 commit 133f2c1
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 0 deletions.
5 changes: 5 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@
<artifactId>mongo-spark-connector_${scala.binary.version}</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>net.snowflake</groupId>
<artifactId>spark-snowflake_${scala.binary.version}</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.elasticsearch</groupId>
<artifactId>elasticsearch-hadoop</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2021 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.spline.harvester.plugin.embedded

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand
import za.co.absa.spline.commons.reflect.extractors.SafeTypeMatchingExtractor
import za.co.absa.spline.harvester.builder.SourceIdentifier
import za.co.absa.spline.harvester.plugin.Plugin.{Precedence, WriteNodeInfo}
import za.co.absa.spline.harvester.plugin.embedded.SnowflakePlugin._
import za.co.absa.spline.harvester.plugin.{Plugin, RelationProviderProcessing}

import javax.annotation.Priority
import scala.language.reflectiveCalls


@Priority(Precedence.Normal)
class SnowflakePlugin(spark: SparkSession)
  extends Plugin
    with RelationProviderProcessing {

  import za.co.absa.spline.commons.ExtractorImplicits._

  /**
   * Recognizes writes performed through the Snowflake Spark connector and extracts
   * lineage information from the captured `SaveIntoDataSourceCommand`.
   *
   * The case matches when the relation provider is either the fully-qualified class
   * name of the Snowflake `DefaultSource`, or an instance of that class.
   */
  override def relationProviderProcessor: PartialFunction[(AnyRef, SaveIntoDataSourceCommand), WriteNodeInfo] = {
    case (rp, cmd) if rp == "net.snowflake.spark.snowflake.DefaultSource" || SnowflakeSourceExtractor.matches(rp) =>

      // Fail fast with a descriptive message instead of a bare key-lookup error
      // (`Map.apply`) when a mandatory connector option is absent from the command.
      // The exception type is kept as NoSuchElementException for compatibility.
      def requiredOption(key: String): String =
        cmd.options.getOrElse(key, throw new NoSuchElementException(
          s"Mandatory Snowflake option '$key' is missing from the write command"))

      val url: String = requiredOption("sfUrl")
      val warehouse: String = requiredOption("sfWarehouse")
      val database: String = requiredOption("sfDatabase")
      val schema: String = requiredOption("sfSchema")
      val table: String = requiredOption("dbtable")

      WriteNodeInfo(asSourceId(url, warehouse, database, schema, table), cmd.mode, cmd.query, cmd.options)
  }
}

object SnowflakePlugin {

  // Safely matches relation-provider instances of the Snowflake connector's
  // DefaultSource without a hard link-time dependency on the class being present.
  private object SnowflakeSourceExtractor extends SafeTypeMatchingExtractor(classOf[net.snowflake.spark.snowflake.DefaultSource])

  // Builds a unique source identifier of the form
  // "snowflake://<url>.<warehouse>.<database>.<schema>.<table>".
  private def asSourceId(url: String, warehouse: String, database: String, schema: String, table: String) = {
    val uri = Seq(url, warehouse, database, schema, table).mkString("snowflake://", ".", "")
    SourceIdentifier(Some("snowflake"), uri)
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package za.co.absa.spline.harvester.plugin.embedded

import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar
import za.co.absa.spline.harvester.plugin.Plugin.WriteNodeInfo
import za.co.absa.spline.harvester.builder.SourceIdentifier
import org.mockito.Mockito._

class SnowflakePluginSpec extends AnyFlatSpec with Matchers with MockitoSugar {

  behavior of "SnowflakePlugin"

  it should "process Snowflake relation providers" in {
    val sparkMock = mock[SparkSession]
    val underTest = new SnowflakePlugin(sparkMock)

    // Connector options as they would appear on a Snowflake write command.
    val writeOptions = Map(
      "sfUrl" -> "test-url",
      "sfWarehouse" -> "test-warehouse",
      "sfDatabase" -> "test-database",
      "sfSchema" -> "test-schema",
      "dbtable" -> "test-table"
    )

    val commandMock = mock[SaveIntoDataSourceCommand]
    when(commandMock.options).thenReturn(writeOptions)
    when(commandMock.mode).thenReturn(SaveMode.Overwrite)
    when(commandMock.query).thenReturn(null)

    // The fully-qualified class name of the Snowflake connector acts as the relation provider.
    val providerName = "net.snowflake.spark.snowflake.DefaultSource"

    val result = underTest.relationProviderProcessor((providerName, commandMock))

    // The source identifier is assembled from all five connector options.
    val expectedSourceId = SourceIdentifier(Some("snowflake"), "snowflake://test-url.test-warehouse.test-database.test-schema.test-table")
    result shouldEqual WriteNodeInfo(expectedSourceId, SaveMode.Overwrite, null, writeOptions)
  }

  it should "not process non-Snowflake relation providers" in {
    val sparkMock = mock[SparkSession]
    val underTest = new SnowflakePlugin(sparkMock)

    val commandMock = mock[SaveIntoDataSourceCommand]

    // An arbitrary provider name that the plugin must not recognize.
    val providerName = "some.other.datasource"

    // The partial function is undefined for foreign providers, so applying it must fail.
    assertThrows[MatchError] {
      underTest.relationProviderProcessor((providerName, commandMock))
    }
  }
}
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,11 @@
<artifactId>mongo-spark-connector_${scala.binary.version}</artifactId>
<version>2.4.1</version>
</dependency>
<dependency>
<groupId>net.snowflake</groupId>
<artifactId>spark-snowflake_${scala.binary.version}</artifactId>
<version>2.16.0-spark_3.3</version>
</dependency>
<dependency>
<groupId>org.elasticsearch</groupId>
<artifactId>elasticsearch-hadoop</artifactId>
Expand Down

0 comments on commit 133f2c1

Please sign in to comment.