diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 918540fa756be..360f391de6c1c 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -220,7 +220,9 @@ def userId(self) -> Optional[str]:
 
     @property
     def token(self) -> Optional[str]:
-        return self._params.get(ChannelBuilder.PARAM_TOKEN, None)
+        return self._params.get(
+            ChannelBuilder.PARAM_TOKEN, os.environ.get("SPARK_CONNECT_AUTHENTICATE_TOKEN")
+        )
 
     def metadata(self) -> Iterable[Tuple[str, str]]:
         """
@@ -410,10 +412,11 @@ def _extract_attributes(self) -> None:
 
     @property
     def secure(self) -> bool:
-        return (
-            self.getDefault(ChannelBuilder.PARAM_USE_SSL, "").lower() == "true"
-            or self.token is not None
-        )
+        return self.use_ssl or self.token is not None
+
+    @property
+    def use_ssl(self) -> bool:
+        return self.getDefault(ChannelBuilder.PARAM_USE_SSL, "").lower() == "true"
 
     @property
     def host(self) -> str:
@@ -439,14 +442,20 @@ def toChannel(self) -> grpc.Channel:
 
         if not self.secure:
             return self._insecure_channel(self.endpoint)
+        elif not self.use_ssl and self._host == "localhost":
+            creds = grpc.local_channel_credentials()
+
+            if self.token is not None:
+                creds = grpc.composite_channel_credentials(
+                    creds, grpc.access_token_call_credentials(self.token)
+                )
+            return self._secure_channel(self.endpoint, creds)
         else:
-            ssl_creds = grpc.ssl_channel_credentials()
+            creds = grpc.ssl_channel_credentials()
 
-            if self.token is None:
-                creds = ssl_creds
-            else:
+            if self.token is not None:
                 creds = grpc.composite_channel_credentials(
-                    ssl_creds, grpc.access_token_call_credentials(self.token)
+                    creds, grpc.access_token_call_credentials(self.token)
                 )
 
             return self._secure_channel(self.endpoint, creds)
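For orientation, the client-side behaviour that the new `toChannel` branches implement can be sketched in plain `grpcio` terms. This is an illustrative sketch rather than the pyspark code itself; the `localhost:15002` endpoint and the environment-variable fallback are assumptions for the example:

```python
import os

import grpc

# Sketch of the new plaintext-localhost path: local channel credentials, with the
# pre-shared token (here taken from the environment-variable fallback) attached as
# per-call bearer credentials when one is available.
token = os.environ.get("SPARK_CONNECT_AUTHENTICATE_TOKEN")

creds = grpc.local_channel_credentials()
if token is not None:
    creds = grpc.composite_channel_credentials(
        creds, grpc.access_token_call_credentials(token)
    )
channel = grpc.secure_channel("localhost:15002", creds)
```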
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index c863af3265dc1..4918762d240ec 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import uuid
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -1030,6 +1031,8 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
 
         2. Starts a regular Spark session that automatically starts a Spark Connect
            server via ``spark.plugins`` feature.
+
+        Returns the authentication token that should be used to connect to this session.
         """
         from pyspark import SparkContext, SparkConf
 
@@ -1049,6 +1052,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
             if "spark.api.mode" in overwrite_conf:
                 del overwrite_conf["spark.api.mode"]
 
+            # Check for a user-provided authentication token, creating a new one if none
+            # is given, and make sure it is set in the environment.
+            if "SPARK_CONNECT_AUTHENTICATE_TOKEN" not in os.environ:
+                os.environ["SPARK_CONNECT_AUTHENTICATE_TOKEN"] = opts.get(
+                    "spark.connect.authenticate.token", str(uuid.uuid4())
+                )
+
             # Configurations to be set if unset.
             default_conf = {
                 "spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin",
@@ -1081,7 +1091,6 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
                 new_opts = {k: opts[k] for k in opts if k in runtime_conf_keys}
                 opts.clear()
                 opts.update(new_opts)
-
         finally:
             if origin_remote is not None:
                 os.environ["SPARK_REMOTE"] = origin_remote
diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py
index 1ab069a4025c4..1fd59609d450c 100644
--- a/python/pyspark/sql/tests/connect/test_connect_session.py
+++ b/python/pyspark/sql/tests/connect/test_connect_session.py
@@ -43,6 +43,7 @@
 from pyspark.errors.exceptions.connect import (
     AnalysisException,
     SparkConnectException,
+    SparkConnectGrpcException,
     SparkUpgradeException,
 )
 
@@ -237,7 +238,13 @@ def test_custom_channel_builder(self):
 
         class CustomChannelBuilder(ChannelBuilder):
             def toChannel(self):
-                return self._insecure_channel(endpoint)
+                creds = grpc.local_channel_credentials()
+
+                if self.token is not None:
+                    creds = grpc.composite_channel_credentials(
+                        creds, grpc.access_token_call_credentials(self.token)
+                    )
+                return self._secure_channel(endpoint, creds)
 
         session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create()
         session.sql("select 1 + 1")
@@ -290,6 +297,15 @@ def test_api_mode(self):
             self.assertEqual(session.range(1).first()[0], 0)
             self.assertIsInstance(session, RemoteSparkSession)
 
+    def test_authentication(self):
+        # All servers start with a default token of "deadbeef", so supply an invalid one
+        session = RemoteSparkSession.builder.remote("sc://localhost/;token=invalid").create()
+
+        with self.assertRaises(SparkConnectGrpcException) as e:
+            session.range(3).collect()
+
+        self.assertTrue("Invalid authentication token" in str(e.exception))
+
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SparkConnectSessionWithOptionsTest(unittest.TestCase):
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
index 47f7d672cc8c2..e1b8d7c76d183 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
@@ -24,7 +24,6 @@
 import unittest
 from typing import cast
 
-from pyspark import SparkConf
 from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
 from pyspark.sql.types import (
     LongType,
@@ -56,7 +55,7 @@ class GroupedApplyInPandasWithStateTestsMixin:
 
     @classmethod
     def conf(cls):
-        cfg = SparkConf()
+        cfg = super().conf()
         cfg.set("spark.sql.shuffle.partitions", "5")
         return cfg
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 7dea8a2103c3d..423a717e8ab5e 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -155,6 +155,9 @@ def conf(cls):
         conf._jconf.remove("spark.master")
         conf.set("spark.connect.execute.reattachable.senderMaxStreamDuration", "1s")
         conf.set("spark.connect.execute.reattachable.senderMaxStreamSize", "123")
+        # Set a static token for all tests so that test parallelism doesn't overwrite each
+        # test's environment variables
+        conf.set("spark.connect.authenticate.token", "deadbeef")
         return conf
 
     @classmethod
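As a usage illustration of the Python-side changes above (not part of the patch): when the local launcher starts a Spark Connect server, a caller can now pin the pre-shared token through `spark.connect.authenticate.token` instead of letting `_start_connect_server` generate a random UUID. The `local[2]` master, the `deadbeef` token, and the default port mentioned below are placeholder assumptions:

```python
from pyspark.sql import SparkSession

# Hypothetical sketch: supply a fixed pre-shared token to the local launcher.
spark = (
    SparkSession.builder.remote("local[2]")
    .config("spark.connect.authenticate.token", "deadbeef")
    .getOrCreate()
)

# Other clients would then include the same token in their connection string,
# e.g. "sc://localhost:15002/;token=deadbeef" (15002 being the default port).
spark.range(3).collect()
```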
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
index b342d5b415692..a3c0220665324 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
@@ -125,7 +125,7 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite {
     assert(builder.host === "localhost")
     assert(builder.port === 15002)
     assert(builder.userAgent.contains("_SPARK_CONNECT_SCALA"))
-    assert(builder.sslEnabled)
+    assert(!builder.sslEnabled)
     assert(builder.token.contains("thisismysecret"))
     assert(builder.userId.isEmpty)
     assert(builder.userName.isEmpty)
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index acee1b2775f17..3d1ba71b9f90a 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -299,8 +299,8 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     TestPackURI(s"sc://host:123/;session_id=${UUID.randomUUID().toString}", isCorrect = true),
     TestPackURI("sc://host:123/;use_ssl=true;token=mySecretToken", isCorrect = true),
     TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=true", isCorrect = true),
-    TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = false),
-    TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect = false),
+    TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = true),
+    TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect = true),
     TestPackURI("sc://host:123/;param1=value1;param2=value2", isCorrect = true),
     TestPackURI(
       "sc://SPARK-45486",
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index c4067ea3ac330..0af7c7b6d97a7 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -762,6 +762,7 @@ object SparkSession extends SparkSessionCompanion with Logging {
         (remoteString.exists(_.startsWith("local")) ||
           (remoteString.isDefined && isAPIModeConnect)) &&
         maybeConnectStartScript.exists(Files.exists(_))) {
+      val token = java.util.UUID.randomUUID().toString()
       val serverId = UUID.randomUUID().toString
       server = Some {
         val args =
@@ -779,6 +780,7 @@ object SparkSession extends SparkSessionCompanion with Logging {
           pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
           pb.environment().put("SPARK_IDENT_STRING", serverId)
           pb.environment().put("HOSTNAME", "local")
+          pb.environment().put("SPARK_CONNECT_AUTHENTICATE_TOKEN", token)
           pb.start()
         }
 
@@ -800,7 +802,7 @@ object SparkSession extends SparkSessionCompanion with Logging {
         }
       }
 
-      System.setProperty("spark.remote", "sc://localhost")
+      System.setProperty("spark.remote", s"sc://localhost/;token=$token")
 
       // scalastyle:off runtimeaddshutdownhook
       Runtime.getRuntime.addShutdownHook(new Thread() {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index dd241c50c9340..57ed454183166 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -468,8 +468,6 @@ object SparkConnectClient {
      * sc://localhost/;token=aaa;use_ssl=true
      * }}}
      *
-     * Throws exception if the token is set but use_ssl=false.
-     *
      * @param inputToken
      *   the user token.
      * @return
@@ -477,11 +475,7 @@ object SparkConnectClient {
      */
     def token(inputToken: String): Builder = {
       require(inputToken != null && inputToken.nonEmpty)
-      if (_configuration.isSslEnabled.contains(false)) {
-        throw new IllegalArgumentException(AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
-      }
-      _configuration =
-        _configuration.copy(token = Option(inputToken), isSslEnabled = Option(true))
+      _configuration = _configuration.copy(token = Option(inputToken))
       this
     }
 
@@ -499,7 +493,6 @@ object SparkConnectClient {
      * this builder.
      */
     def disableSsl(): Builder = {
-      require(token.isEmpty, AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
       _configuration = _configuration.copy(isSslEnabled = Option(false))
       this
     }
@@ -737,6 +730,8 @@ object SparkConnectClient {
       grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE,
       grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT) {
 
+    private def isLocal = host.equals("localhost")
+
     def userContext: proto.UserContext = {
       val builder = proto.UserContext.newBuilder()
       if (userId != null) {
@@ -749,7 +744,7 @@ object SparkConnectClient {
     }
 
     def credentials: ChannelCredentials = {
-      if (isSslEnabled.contains(true)) {
+      if (isSslEnabled.contains(true) || (token.isDefined && !isLocal)) {
         token match {
           case Some(t) =>
             // With access token added in the http header.
@@ -765,10 +760,18 @@ object SparkConnectClient {
     }
 
     def createChannel(): ManagedChannel = {
-      val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, credentials)
+      val creds = credentials
+      val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, creds)
+
+      // Workaround until LocalChannelCredentials are added in
+      // https://github.com/grpc/grpc-java/issues/9900
+      var metadataWithOptionalToken = metadata
+      if (!isSslEnabled.contains(true) && isLocal && token.isDefined) {
+        metadataWithOptionalToken = metadata + (("Authorization", s"Bearer ${token.get}"))
+      }
 
-      if (metadata.nonEmpty) {
-        channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata))
+      if (metadataWithOptionalToken.nonEmpty) {
+        channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadataWithOptionalToken))
       }
 
       interceptors.foreach(channelBuilder.intercept(_))
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index b0c5a2a055b56..9f884b683079c 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.config
 
 import java.util.concurrent.TimeUnit
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.SQLConf
@@ -313,4 +314,21 @@ object Connect {
       .internal()
       .booleanConf
       .createWithDefault(true)
+
+  val CONNECT_AUTHENTICATE_TOKEN =
+    buildStaticConf("spark.connect.authenticate.token")
+      .doc("A pre-shared token that will be used to authenticate clients. This secret must be" +
+        " passed as a bearer token for clients to connect.")
+      .version("4.0.0")
+      .internal()
+      .stringConf
+      .createOptional
+
+  val CONNECT_AUTHENTICATE_TOKEN_ENV = "SPARK_CONNECT_AUTHENTICATE_TOKEN"
+
+  def getAuthenticateToken: Option[String] = {
+    SparkEnv.get.conf.get(CONNECT_AUTHENTICATE_TOKEN).orElse {
+      Option(System.getenv.get(CONNECT_AUTHENTICATE_TOKEN_ENV))
+    }
+  }
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index 07c5da9744cc6..cc6b58216ed7a 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, QUERY_ID, RUN_ID_STRING,
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
 import org.apache.spark.sql.connect.common.ForeachWriterPacket
+import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.connect.service.SparkConnectService
 import org.apache.spark.sql.streaming.StreamingQuery
@@ -131,7 +132,10 @@ object StreamingForeachBatchHelper extends Logging {
       sessionHolder: SessionHolder): (ForeachBatchFnType, AutoCloseable) = {
 
     val port = SparkConnectService.localPort
-    val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+    var connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+    Connect.getAuthenticateToken.foreach { token =>
+      connectUrl = s"$connectUrl;token=$token"
+    }
     val runner = StreamingPythonRunner(
       pythonFn,
       connectUrl,
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
index c342050a212ef..42c090d43f065 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.api.python.{PythonException, PythonWorkerUtils, SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.FUNCTION_NAME
+import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService}
 import org.apache.spark.sql.streaming.StreamingQueryListener
 
@@ -36,7 +37,10 @@ class PythonStreamingQueryListener(listener: SimplePythonFunction, sessionHolder
     with Logging {
 
   private val port = SparkConnectService.localPort
-  private val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+  private var connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+  Connect.getAuthenticateToken.foreach { token =>
+    connectUrl = s"$connectUrl;token=$token"
+  }
   // Scoped for testing
   private[connect] val runner = StreamingPythonRunner(
     listener,
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala
new file mode 100644
index 0000000000000..5d7cc65358eb3
--- /dev/null
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor, Status}
+
+class PreSharedKeyAuthenticationInterceptor(token: String) extends ServerInterceptor {
+
+  val authorizationMetadataKey =
+    Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)
+
+  val expectedValue = s"Bearer $token"
+
+  override def interceptCall[ReqT, RespT](
+      call: ServerCall[ReqT, RespT],
+      metadata: Metadata,
+      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+    val authHeaderValue = metadata.get(authorizationMetadataKey)
+
+    if (authHeaderValue == null) {
+      val status = Status.UNAUTHENTICATED.withDescription("No authentication token provided")
+      call.close(status, new Metadata())
+      new ServerCall.Listener[ReqT]() {}
+    } else if (authHeaderValue != expectedValue) {
+      val status = Status.UNAUTHENTICATED.withDescription("Invalid authentication token")
+      call.close(status, new Metadata())
+      new ServerCall.Listener[ReqT]() {}
+    } else {
+      next.startCall(call, metadata)
+    }
+  }
+}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index e62c19b66c8e5..8fa64ddcce49e 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -39,7 +39,7 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.HOST
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
+import org.apache.spark.sql.connect.config.Connect.{getAuthenticateToken, CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
 import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
 import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -381,6 +381,10 @@ object SparkConnectService extends Logging {
     sb.maxInboundMessageSize(SparkEnv.get.conf.get(CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE).toInt)
       .addService(sparkConnectService)
 
+    getAuthenticateToken.foreach { token =>
+      sb.intercept(new PreSharedKeyAuthenticationInterceptor(token))
+    }
+
     // Add all registered interceptors to the server builder.
     SparkConnectInterceptorRegistry.chainInterceptors(sb, configuredInterceptors)
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectAuthSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectAuthSuite.scala
new file mode 100644
index 0000000000000..30f186ab7c2b1
--- /dev/null
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectAuthSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+// import io.grpc.StatusRuntimeException
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.connect.{SparkConnectServerTest, SparkSession}
+import org.apache.spark.sql.connect.service.SparkConnectService
+
+class SparkConnectAuthSuite extends SparkConnectServerTest {
+  override protected def sparkConf = {
+    super.sparkConf.set("spark.connect.authenticate.token", "deadbeef")
+  }
+
+  test("Test local authentication") {
+    val session = SparkSession
+      .builder()
+      .remote(s"sc://localhost:${SparkConnectService.localPort}/;token=deadbeef")
+      .create()
+    session.range(5).collect()
+
+    val invalidSession = SparkSession
+      .builder()
+      .remote(s"sc://localhost:${SparkConnectService.localPort}/;token=invalid")
+      .create()
+    val exception = intercept[SparkException] {
+      invalidSession.range(5).collect()
+    }
+    assert(exception.getMessage.contains("Invalid authentication token"))
+  }
+}
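Finally, a hedged sketch of the failure mode asserted by the new Python and Scala tests, as a client would observe it; the endpoint and token are placeholders for a server that was started with a different pre-shared token:

```python
from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
from pyspark.errors.exceptions.connect import SparkConnectGrpcException

# Connecting with the wrong token: the server-side interceptor rejects the call
# with UNAUTHENTICATED before it reaches the Spark Connect service.
session = RemoteSparkSession.builder.remote("sc://localhost:15002/;token=wrong-token").create()
try:
    session.range(3).collect()
except SparkConnectGrpcException as e:
    assert "Invalid authentication token" in str(e)
```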