[SPARK-51156][CONNECT] Static token authentication support in Spark Connect

### What changes were proposed in this pull request?

Adds static token authentication support to Spark Connect, which is used by default for automatically started servers locally.
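
In the local case the handshake is simple: the driver generates (or reuses) a token, exports it through `SPARK_CONNECT_AUTHENTICATE_TOKEN`, and the client picks the same value up from the environment. A minimal sketch of that flow, mirroring the `session.py` change below:

```python
import os
import uuid

# If the user has not supplied a token, generate one and publish it through the
# environment so the launched server and the client share the same secret.
if "SPARK_CONNECT_AUTHENTICATE_TOKEN" not in os.environ:
    os.environ["SPARK_CONNECT_AUTHENTICATE_TOKEN"] = str(uuid.uuid4())
```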

### Why are the changes needed?

To add authentication support to Spark Connect, so that an automatically started local server is not inadvertently accessible to other users.

### Does this PR introduce _any_ user-facing change?

Local authentication is transparent to users, but this also adds the option for users manually starting Connect servers to specify an authentication token.
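
For example, a client connecting to a manually started server can pass the token in the connection string; a minimal sketch, assuming a server was started with `spark.connect.authenticate.token` set to the placeholder value `my-secret`:

```python
from pyspark.sql import SparkSession

# The token is sent as a bearer token on every RPC; "my-secret" is a placeholder.
spark = (
    SparkSession.builder
    .remote("sc://localhost:15002/;token=my-secret")
    .getOrCreate()
)
spark.range(3).collect()
```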

### How was this patch tested?

Existing UTs

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #50006 from Kimahriman/spark-connect-local-auth.

Lead-authored-by: Adam Binford <adamq43@gmail.com>
Co-authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit 7e9547c)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
Kimahriman and HyukjinKwon committed Feb 23, 2025
1 parent 79ffa1d commit cd513f5
Showing 15 changed files with 197 additions and 33 deletions.
29 changes: 19 additions & 10 deletions python/pyspark/sql/connect/client/core.py
@@ -220,7 +220,9 @@ def userId(self) -> Optional[str]:
 
     @property
     def token(self) -> Optional[str]:
-        return self._params.get(ChannelBuilder.PARAM_TOKEN, None)
+        return self._params.get(
+            ChannelBuilder.PARAM_TOKEN, os.environ.get("SPARK_CONNECT_AUTHENTICATE_TOKEN")
+        )
 
     def metadata(self) -> Iterable[Tuple[str, str]]:
         """
@@ -410,10 +412,11 @@ def _extract_attributes(self) -> None:
 
     @property
     def secure(self) -> bool:
-        return (
-            self.getDefault(ChannelBuilder.PARAM_USE_SSL, "").lower() == "true"
-            or self.token is not None
-        )
+        return self.use_ssl or self.token is not None
+
+    @property
+    def use_ssl(self) -> bool:
+        return self.getDefault(ChannelBuilder.PARAM_USE_SSL, "").lower() == "true"
 
     @property
     def host(self) -> str:
@@ -439,14 +442,20 @@ def toChannel(self) -> grpc.Channel:
 
         if not self.secure:
             return self._insecure_channel(self.endpoint)
+        elif not self.use_ssl and self._host == "localhost":
+            creds = grpc.local_channel_credentials()
+
+            if self.token is not None:
+                creds = grpc.composite_channel_credentials(
+                    creds, grpc.access_token_call_credentials(self.token)
+                )
+            return self._secure_channel(self.endpoint, creds)
         else:
-            ssl_creds = grpc.ssl_channel_credentials()
+            creds = grpc.ssl_channel_credentials()
 
-            if self.token is None:
-                creds = ssl_creds
-            else:
+            if self.token is not None:
                 creds = grpc.composite_channel_credentials(
-                    ssl_creds, grpc.access_token_call_credentials(self.token)
+                    creds, grpc.access_token_call_credentials(self.token)
                 )
 
             return self._secure_channel(self.endpoint, creds)
11 changes: 10 additions & 1 deletion python/pyspark/sql/connect/session.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import uuid
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -1030,6 +1031,8 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
         2. Starts a regular Spark session that automatically starts a Spark Connect server
            via ``spark.plugins`` feature.
+
+        Returns the authentication token that should be used to connect to this session.
         """
         from pyspark import SparkContext, SparkConf
 
@@ -1049,6 +1052,13 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
         if "spark.api.mode" in overwrite_conf:
             del overwrite_conf["spark.api.mode"]
 
+        # Check for a user-provided authentication token, creating a new one if not,
+        # and make sure it's set in the environment.
+        if "SPARK_CONNECT_AUTHENTICATE_TOKEN" not in os.environ:
+            os.environ["SPARK_CONNECT_AUTHENTICATE_TOKEN"] = opts.get(
+                "spark.connect.authenticate.token", str(uuid.uuid4())
+            )
+
         # Configurations to be set if unset.
         default_conf = {
             "spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin",
@@ -1081,7 +1091,6 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
             new_opts = {k: opts[k] for k in opts if k in runtime_conf_keys}
             opts.clear()
             opts.update(new_opts)
-
         finally:
             if origin_remote is not None:
                 os.environ["SPARK_REMOTE"] = origin_remote
18 changes: 17 additions & 1 deletion python/pyspark/sql/tests/connect/test_connect_session.py
@@ -43,6 +43,7 @@
 from pyspark.errors.exceptions.connect import (
     AnalysisException,
     SparkConnectException,
+    SparkConnectGrpcException,
     SparkUpgradeException,
 )
 
@@ -237,7 +238,13 @@ def test_custom_channel_builder(self):
 
         class CustomChannelBuilder(ChannelBuilder):
             def toChannel(self):
-                return self._insecure_channel(endpoint)
+                creds = grpc.local_channel_credentials()
+
+                if self.token is not None:
+                    creds = grpc.composite_channel_credentials(
+                        creds, grpc.access_token_call_credentials(self.token)
+                    )
+                return self._secure_channel(endpoint, creds)
 
         session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create()
         session.sql("select 1 + 1")
@@ -290,6 +297,15 @@ def test_api_mode(self):
         self.assertEqual(session.range(1).first()[0], 0)
         self.assertIsInstance(session, RemoteSparkSession)
 
+    def test_authentication(self):
+        # All servers start with a default token of "deadbeef", so supply an invalid one
+        session = RemoteSparkSession.builder.remote("sc://localhost/;token=invalid").create()
+
+        with self.assertRaises(SparkConnectGrpcException) as e:
+            session.range(3).collect()
+
+        self.assertTrue("Invalid authentication token" in str(e.exception))
+
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SparkConnectSessionWithOptionsTest(unittest.TestCase):
@@ -24,7 +24,6 @@
 import unittest
 from typing import cast
 
-from pyspark import SparkConf
 from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
 from pyspark.sql.types import (
     LongType,
@@ -56,7 +55,7 @@
 class GroupedApplyInPandasWithStateTestsMixin:
     @classmethod
     def conf(cls):
-        cfg = SparkConf()
+        cfg = super().conf()
         cfg.set("spark.sql.shuffle.partitions", "5")
         return cfg
 
3 changes: 3 additions & 0 deletions python/pyspark/testing/connectutils.py
@@ -155,6 +155,9 @@ def conf(cls):
         conf._jconf.remove("spark.master")
         conf.set("spark.connect.execute.reattachable.senderMaxStreamDuration", "1s")
         conf.set("spark.connect.execute.reattachable.senderMaxStreamSize", "123")
+        # Set a static token for all tests so the parallelism doesn't overwrite each
+        # test's environment variables
+        conf.set("spark.connect.authenticate.token", "deadbeef")
         return conf
 
     @classmethod
@@ -125,7 +125,7 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite {
     assert(builder.host === "localhost")
     assert(builder.port === 15002)
     assert(builder.userAgent.contains("_SPARK_CONNECT_SCALA"))
-    assert(builder.sslEnabled)
+    assert(!builder.sslEnabled)
     assert(builder.token.contains("thisismysecret"))
     assert(builder.userId.isEmpty)
     assert(builder.userName.isEmpty)
@@ -299,8 +299,8 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     TestPackURI(s"sc://host:123/;session_id=${UUID.randomUUID().toString}", isCorrect = true),
     TestPackURI("sc://host:123/;use_ssl=true;token=mySecretToken", isCorrect = true),
     TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=true", isCorrect = true),
-    TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = false),
-    TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect = false),
+    TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = true),
+    TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=false", isCorrect = true),
     TestPackURI("sc://host:123/;param1=value1;param2=value2", isCorrect = true),
     TestPackURI(
       "sc://SPARK-45486",
@@ -762,6 +762,7 @@ object SparkSession extends SparkSessionCompanion with Logging {
         (remoteString.exists(_.startsWith("local")) ||
           (remoteString.isDefined && isAPIModeConnect)) &&
         maybeConnectStartScript.exists(Files.exists(_))) {
+      val token = java.util.UUID.randomUUID().toString()
       val serverId = UUID.randomUUID().toString
       server = Some {
         val args =
@@ -779,6 +780,7 @@ object SparkSession extends SparkSessionCompanion with Logging {
           pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
           pb.environment().put("SPARK_IDENT_STRING", serverId)
           pb.environment().put("HOSTNAME", "local")
+          pb.environment().put("SPARK_CONNECT_AUTHENTICATE_TOKEN", token)
           pb.start()
         }
 
@@ -800,7 +802,7 @@ object SparkSession extends SparkSessionCompanion with Logging {
         }
       }
 
-      System.setProperty("spark.remote", "sc://localhost")
+      System.setProperty("spark.remote", s"sc://localhost/;token=$token")
 
       // scalastyle:off runtimeaddshutdownhook
       Runtime.getRuntime.addShutdownHook(new Thread() {
@@ -468,20 +468,14 @@ object SparkConnectClient {
     * sc://localhost/;token=aaa;use_ssl=true
     * }}}
     *
-    * Throws exception if the token is set but use_ssl=false.
-    *
     * @param inputToken
     *   the user token.
     * @return
     *   this builder.
     */
    def token(inputToken: String): Builder = {
      require(inputToken != null && inputToken.nonEmpty)
-      if (_configuration.isSslEnabled.contains(false)) {
-        throw new IllegalArgumentException(AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
-      }
-      _configuration =
-        _configuration.copy(token = Option(inputToken), isSslEnabled = Option(true))
+      _configuration = _configuration.copy(token = Option(inputToken))
      this
    }
 
@@ -499,7 +493,6 @@ object SparkConnectClient {
     *   this builder.
     */
    def disableSsl(): Builder = {
-      require(token.isEmpty, AUTH_TOKEN_ON_INSECURE_CONN_ERROR_MSG)
      _configuration = _configuration.copy(isSslEnabled = Option(false))
      this
    }
@@ -737,6 +730,8 @@ object SparkConnectClient {
      grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE,
      grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT) {
 
+    private def isLocal = host.equals("localhost")
+
    def userContext: proto.UserContext = {
      val builder = proto.UserContext.newBuilder()
      if (userId != null) {
@@ -749,7 +744,7 @@ object SparkConnectClient {
    }
 
    def credentials: ChannelCredentials = {
-      if (isSslEnabled.contains(true)) {
+      if (isSslEnabled.contains(true) || (token.isDefined && !isLocal)) {
        token match {
          case Some(t) =>
            // With access token added in the http header.
@@ -765,10 +760,18 @@ object SparkConnectClient {
    }
 
    def createChannel(): ManagedChannel = {
-      val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, credentials)
+      val creds = credentials
+      val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, creds)
 
-      if (metadata.nonEmpty) {
-        channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata))
+      // Workaround until LocalChannelCredentials are added in
+      // https://github.com/grpc/grpc-java/issues/9900
+      var metadataWithOptionalToken = metadata
+      if (!isSslEnabled.contains(true) && isLocal && token.isDefined) {
+        metadataWithOptionalToken = metadata + (("Authorization", s"Bearer ${token.get}"))
+      }
+
+      if (metadataWithOptionalToken.nonEmpty) {
+        channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadataWithOptionalToken))
      }
 
      interceptors.foreach(channelBuilder.intercept(_))
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.config
 
 import java.util.concurrent.TimeUnit
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.internal.SQLConf
@@ -313,4 +314,21 @@ object Connect {
      .internal()
      .booleanConf
      .createWithDefault(true)
+
+  val CONNECT_AUTHENTICATE_TOKEN =
+    buildStaticConf("spark.connect.authenticate.token")
+      .doc("A pre-shared token that will be used to authenticate clients. This secret must be" +
+        " passed as a bearer token for clients to connect.")
+      .version("4.0.0")
+      .internal()
+      .stringConf
+      .createOptional
+
+  val CONNECT_AUTHENTICATE_TOKEN_ENV = "SPARK_CONNECT_AUTHENTICATE_TOKEN"
+
+  def getAuthenticateToken: Option[String] = {
+    SparkEnv.get.conf.get(CONNECT_AUTHENTICATE_TOKEN).orElse {
+      Option(System.getenv.get(CONNECT_AUTHENTICATE_TOKEN_ENV))
+    }
+  }
 }
@@ -31,6 +31,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, QUERY_ID, RUN_ID_STRING,
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
 import org.apache.spark.sql.connect.common.ForeachWriterPacket
+import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.connect.service.SparkConnectService
 import org.apache.spark.sql.streaming.StreamingQuery
@@ -131,7 +132,10 @@ object StreamingForeachBatchHelper extends Logging {
      sessionHolder: SessionHolder): (ForeachBatchFnType, AutoCloseable) = {
 
    val port = SparkConnectService.localPort
-    val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+    var connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+    Connect.getAuthenticateToken.foreach { token =>
+      connectUrl = s"$connectUrl;token=$token"
+    }
    val runner = StreamingPythonRunner(
      pythonFn,
      connectUrl,
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.api.python.{PythonException, PythonWorkerUtils, SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.FUNCTION_NAME
+import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService}
 import org.apache.spark.sql.streaming.StreamingQueryListener
 
@@ -36,7 +37,10 @@ class PythonStreamingQueryListener(listener: SimplePythonFunction, sessionHolder
    with Logging {
 
  private val port = SparkConnectService.localPort
-  private val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+  private var connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+  Connect.getAuthenticateToken.foreach { token =>
+    connectUrl = s"$connectUrl;token=$token"
+  }
  // Scoped for testing
  private[connect] val runner = StreamingPythonRunner(
    listener,
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor, Status}
+
+class PreSharedKeyAuthenticationInterceptor(token: String) extends ServerInterceptor {
+
+  val authorizationMetadataKey =
+    Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)
+
+  val expectedValue = s"Bearer $token"
+
+  override def interceptCall[ReqT, RespT](
+      call: ServerCall[ReqT, RespT],
+      metadata: Metadata,
+      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+    val authHeaderValue = metadata.get(authorizationMetadataKey)
+
+    if (authHeaderValue == null) {
+      val status = Status.UNAUTHENTICATED.withDescription("No authentication token provided")
+      call.close(status, new Metadata())
+      new ServerCall.Listener[ReqT]() {}
+    } else if (authHeaderValue != expectedValue) {
+      val status = Status.UNAUTHENTICATED.withDescription("Invalid authentication token")
+      call.close(status, new Metadata())
+      new ServerCall.Listener[ReqT]() {}
+    } else {
+      next.startCall(call, metadata)
+    }
+  }
+}
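
For comparison, the check above boils down to matching the `Authorization` header against `Bearer <token>`. A rough Python equivalent of the interceptor, using gRPC's `ServerInterceptor` API (the class name and wiring here are illustrative, not part of this change):

```python
import grpc

class PreSharedKeyAuthInterceptor(grpc.ServerInterceptor):
    def __init__(self, token: str):
        self._expected = f"Bearer {token}"

        def deny(request, context):
            # Mirrors Status.UNAUTHENTICATED in the Scala interceptor.
            context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid authentication token")

        self._deny_handler = grpc.unary_unary_rpc_method_handler(deny)

    def intercept_service(self, continuation, handler_call_details):
        # gRPC normalizes metadata keys to lowercase on the wire.
        headers = dict(handler_call_details.invocation_metadata)
        if headers.get("authorization") != self._expected:
            return self._deny_handler
        return continuation(handler_call_details)
```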