[SPARK-51267][CONNECT] Match local Spark Connect server logic between Python and Scala

### What changes were proposed in this pull request?

This PR proposes to match the local Spark Connect server logic between Python and Scala. It includes:

1. Synchronize access to the local server and terminate it on `SparkSession.stop()` in Scala (see the sketch after this list).
2. Remove the internal `SPARK_LOCAL_CONNECT` environment variable and the `spark.local.connect` configuration, handle the equivalent logic in `SparkSubmitCommandBuilder.buildSparkSubmitArgs`, and stop passing `spark.remote` and `spark.api.mode` to the locally launched Spark Connect server.
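Conceptually, the Scala-side lifecycle looks like the simplified sketch below. The object and method names (`LocalConnectServerLifecycle`, `start`, `stop`) are illustrative only; the actual logic lives in `SparkSession.withLocalConnectServer` and `SparkSession.stop()` in the diff.

```scala
import java.nio.file.{Files, Paths}

// Simplified, illustrative sketch of the local Spark Connect server lifecycle
// this PR introduces on the Scala side; not the real class layout.
object LocalConnectServerLifecycle {
  // The locally launched Spark Connect server process, if any.
  private var server: Option[Process] = None

  private val startScript =
    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
  private val stopScript =
    Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))

  // Start the local server, dropping configurations that must not be forwarded to it
  // (change 2: `spark.remote` and `spark.api.mode` stay on the client side).
  def start(master: String, sparkOptions: Map[String, String]): Unit = synchronized {
    if (server.isEmpty && startScript.exists(Files.exists(_))) {
      val confArgs = sparkOptions
        .filter { case (k, _) =>
          !k.startsWith("spark.remote") && !k.startsWith("spark.api.mode")
        }
        .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
      val args = Seq(startScript.get.toString, "--master", master) ++ confArgs
      server = Some(new ProcessBuilder(args: _*).inheritIO().start())
    }
  }

  // Terminate the local server, mirroring what `SparkSession.stop()` now does (change 1).
  def stop(): Unit = synchronized {
    if (server.isDefined) {
      stopScript.foreach(script => new ProcessBuilder(script.toString).start())
      server = None
    }
  }
}
```

On the launcher side, the diff below adds an `includeRemote` flag to `SparkSubmitCommandBuilder.buildSparkSubmitArgs` so that the PySpark shell's `PYSPARK_SUBMIT_ARGS` is built without `--remote`, `spark.remote`, and `spark.api.mode`.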

### Why are the changes needed?

To have consistent behaviour between the Python and Scala Spark Connect clients.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Manually:

```
./bin/spark-shell --master "local" --conf spark.api.mode=connect
```

```
./bin/spark-shell --remote "local[*]"
```

```
./bin/spark-shell --master "local" --conf spark.api.mode=classic
```
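As a quick, illustrative check inside each spark-shell above (not part of the original test instructions), the active implementation can be distinguished by the session's class, and a trivial query confirms the session works end to end:

```scala
// Run inside the spark-shell sessions started above (illustrative check).
// With spark.api.mode=connect or --remote, `spark` is the Spark Connect client's
// SparkSession; with spark.api.mode=classic it is the Classic SparkSession.
println(spark.getClass.getName)

// A trivial query should succeed in both modes.
spark.range(5).collect().foreach(println)
```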

```
git clone https://github.com/HyukjinKwon/spark-connect-example
cd spark-connect-example
build/sbt package
cd ..
git clone https://github.com/apache/spark.git
cd spark
build/sbt package
sbin/start-connect-server.sh
bin/spark-submit --name "testApp" --remote "sc://localhost" --class com.hyukjinkwon.SparkConnectExample ../spark-connect-example/target/scala-2.13/spark-connect-example_2.13-0.0.1.jar
```

```
./bin/pyspark --master "local" --conf spark.api.mode=connect
```

```
./bin/pyspark --remote "local"
```

```
./bin/pyspark --conf spark.api.mode=classic
```

```
./bin/pyspark --conf spark.api.mode=connect
```

There is also an existing unit test with YARN.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50017 from HyukjinKwon/fix-connect-repl.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit 46e12a4)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon committed Feb 21, 2025
1 parent 7128f1c commit 913bf0e
Showing 9 changed files with 77 additions and 52 deletions.
13 changes: 4 additions & 9 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -249,13 +249,12 @@ private[spark] class SparkSubmit extends Logging {
val childArgs = new ArrayBuffer[String]()
val childClasspath = new ArrayBuffer[String]()
val sparkConf = args.toSparkConf()
if (sparkConf.contains("spark.local.connect")) sparkConf.remove("spark.remote")
var childMainClass = ""

// Set the cluster manager
val clusterManager: Int = args.maybeMaster match {
case Some(v) =>
assert(args.maybeRemote.isEmpty || sparkConf.contains("spark.local.connect"))
assert(args.maybeRemote.isEmpty)
v match {
case "yarn" => YARN
case m if m.startsWith("spark") => STANDALONE
@@ -643,14 +642,11 @@ private[spark] class SparkSubmit extends Logging {
// All cluster managers
OptionAssigner(
// If remote is not set, sets the master,
// In local remote mode, starts the default master to to start the server.
if (args.maybeRemote.isEmpty || sparkConf.contains("spark.local.connect")) args.master
if (args.maybeRemote.isEmpty) args.master
else args.maybeMaster.orNull,
ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.master"),
OptionAssigner(
// In local remote mode, do not set remote.
if (sparkConf.contains("spark.local.connect")) null
else args.maybeRemote.orNull, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.remote"),
args.maybeRemote.orNull, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.remote"),
OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
confKey = SUBMIT_DEPLOY_MODE.key),
OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.app.name"),
@@ -767,8 +763,7 @@ private[spark] class SparkSubmit extends Logging {
// In case of shells, spark.ui.showConsoleProgress can be true by default or by user. Except,
// when Spark Connect is in local mode, because Spark Connect support its own progress
// reporting.
if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS) &&
!sparkConf.contains("spark.local.connect")) {
if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS)) {
sparkConf.set(UI_SHOW_CONSOLE_PROGRESS, true)
}

@@ -253,8 +253,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
if (args.length == 0) {
printUsageAndExit(-1)
}
if (!sparkProperties.contains("spark.local.connect") &&
maybeRemote.isDefined && (maybeMaster.isDefined || deployMode != null)) {
if (maybeRemote.isDefined && (maybeMaster.isDefined || deployMode != null)) {
error("Remote cannot be specified with master and/or deploy mode.")
}
if (primaryResource == null) {
@@ -235,9 +235,8 @@ List<String> buildClassPath(String appClassPath) throws IOException {
addToClassPath(cp, f.toString());
}
}
// If we're in 'spark.local.connect', it should create a Spark Classic Spark Context
// that launches Spark Connect server.
if (isRemote && System.getenv("SPARK_LOCAL_CONNECT") == null) {

if (isRemote) {
for (File f: new File(jarsDir).listFiles()) {
// Exclude Spark Classic SQL and Spark Connect server jars
// if we're in Spark Connect Shell. Also exclude Spark SQL API and
@@ -47,7 +47,6 @@ public class SparkLauncher extends AbstractLauncher<SparkLauncher> {

/** The Spark remote. */
public static final String SPARK_REMOTE = "spark.remote";
public static final String SPARK_LOCAL_REMOTE = "spark.local.connect";

/** The Spark API mode. */
public static final String SPARK_API_MODE = "spark.api.mode";
@@ -184,6 +184,10 @@ public List<String> buildCommand(Map<String, String> env)
}

List<String> buildSparkSubmitArgs() {
return buildSparkSubmitArgs(true);
}

List<String> buildSparkSubmitArgs(boolean includeRemote) {
List<String> args = new ArrayList<>();
OptionParser parser = new OptionParser(false);
final boolean isSpecialCommand;
@@ -210,7 +214,7 @@ List<String> buildSparkSubmitArgs() {
args.add(master);
}

if (remote != null) {
if (includeRemote && remote != null) {
args.add(parser.REMOTE);
args.add(remote);
}
@@ -226,8 +230,12 @@ List<String> buildSparkSubmitArgs() {
}

for (Map.Entry<String, String> e : conf.entrySet()) {
args.add(parser.CONF);
args.add(String.format("%s=%s", e.getKey(), e.getValue()));
if (includeRemote ||
(!e.getKey().equalsIgnoreCase("spark.api.mode") &&
!e.getKey().equalsIgnoreCase("spark.remote"))) {
args.add(parser.CONF);
args.add(String.format("%s=%s", e.getKey(), e.getValue()));
}
}

if (propertiesFile != null) {
@@ -368,7 +376,8 @@ private List<String> buildPySparkShellCommand(Map<String, String> env) throws IO
// When launching the pyspark shell, the spark-submit arguments should be stored in the
// PYSPARK_SUBMIT_ARGS env variable.
appResource = PYSPARK_SHELL_RESOURCE;
constructEnvVarArgs(env, "PYSPARK_SUBMIT_ARGS");
// Do not pass remote configurations to Spark Connect server via Py4J.
constructEnvVarArgs(env, "PYSPARK_SUBMIT_ARGS", false);

// Will pick up the binary executable in the following order
// 1. conf spark.pyspark.driver.python
@@ -391,8 +400,7 @@ private List<String> buildPySparkShellCommand(Map<String, String> env) throws IO
String masterStr = firstNonEmpty(master, conf.getOrDefault(SparkLauncher.SPARK_MASTER, null));
String deployStr = firstNonEmpty(
deployMode, conf.getOrDefault(SparkLauncher.DEPLOY_MODE, null));
if (!conf.containsKey(SparkLauncher.SPARK_LOCAL_REMOTE) &&
remoteStr != null && (masterStr != null || deployStr != null)) {
if (remoteStr != null && (masterStr != null || deployStr != null)) {
throw new IllegalStateException("Remote cannot be specified with master and/or deploy mode.");
}

@@ -423,7 +431,7 @@ private List<String> buildSparkRCommand(Map<String, String> env) throws IOExcept
// When launching the SparkR shell, store the spark-submit arguments in the SPARKR_SUBMIT_ARGS
// env variable.
appResource = SPARKR_SHELL_RESOURCE;
constructEnvVarArgs(env, "SPARKR_SUBMIT_ARGS");
constructEnvVarArgs(env, "SPARKR_SUBMIT_ARGS", true);

// Set shell.R as R_PROFILE_USER to load the SparkR package when the shell comes up.
String sparkHome = System.getenv("SPARK_HOME");
@@ -438,12 +446,13 @@ private List<String> buildSparkRCommand(Map<String, String> env) throws IOExcept

private void constructEnvVarArgs(
Map<String, String> env,
String submitArgsEnvVariable) throws IOException {
String submitArgsEnvVariable,
boolean includeRemote) throws IOException {
mergeEnvPathList(env, getLibPathEnvName(),
getEffectiveConfig().get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH));

StringBuilder submitArgs = new StringBuilder();
for (String arg : buildSparkSubmitArgs()) {
for (String arg : buildSparkSubmitArgs(includeRemote)) {
if (submitArgs.length() > 0) {
submitArgs.append(" ");
}
3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
@@ -265,6 +265,9 @@ object MimaExcludes {
ProblemFilters.exclude[Problem]("org.sparkproject.spark_protobuf.protobuf.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.protobuf.utils.SchemaConverters.*"),

// SPARK-51267: Match local Spark Connect server logic between Python and Scala
ProblemFilters.exclude[MissingFieldProblem]("org.apache.spark.launcher.SparkLauncher.SPARK_LOCAL_REMOTE"),

(problem: Problem) => problem match {
case MissingClassProblem(cls) => !cls.fullName.startsWith("org.sparkproject.jpmml") &&
!cls.fullName.startsWith("org.sparkproject.dmg.pmml")
3 changes: 1 addition & 2 deletions python/pyspark/sql/connect/client/core.py
@@ -326,8 +326,7 @@ def default_port() -> int:
# This is only used in the test/development mode.
session = PySparkSession._instantiatedSession

# 'spark.local.connect' is set when we use the local mode in Spark Connect.
if session is not None and session.conf.get("spark.local.connect", "0") == "1":
if session is not None:
jvm = PySparkSession._instantiatedSession._jvm # type: ignore[union-attr]
return getattr(
getattr(
7 changes: 4 additions & 3 deletions python/pyspark/sql/connect/session.py
@@ -1044,8 +1044,10 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
# Configurations to be overwritten
overwrite_conf = opts
overwrite_conf["spark.master"] = master
overwrite_conf["spark.local.connect"] = "1"
os.environ["SPARK_LOCAL_CONNECT"] = "1"
if "spark.remote" in overwrite_conf:
del overwrite_conf["spark.remote"]
if "spark.api.mode" in overwrite_conf:
del overwrite_conf["spark.api.mode"]

# Configurations to be set if unset.
default_conf = {
@@ -1083,7 +1085,6 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
finally:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
del os.environ["SPARK_LOCAL_CONNECT"]
else:
raise PySparkRuntimeError(
errorClass="SESSION_OR_CONTEXT_EXISTS",
@@ -627,6 +627,17 @@ class SparkSession private[sql] (
}
allocator.close()
SparkSession.onSessionClose(this)
SparkSession.server.synchronized {
if (SparkSession.server.isDefined) {
// When local mode is in use, follow the regular Spark session's
// behavior by terminating the Spark Connect server,
// meaning that you can stop local mode, and restart the Spark Connect
// client with a different remote address.
new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
.start()
SparkSession.server = None
}
}
}

/** @inheritdoc */
Expand Down Expand Up @@ -679,6 +690,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
private var server: Option[Process] = None
private val maybeConnectStartScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
private val maybeConnectStopScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
private[sql] val sparkOptions = sys.props.filter { p =>
p._1.startsWith("spark.") && p._2.nonEmpty
}.toMap
@@ -695,34 +710,37 @@ object SparkSession extends SparkSessionCompanion with Logging {
* Create a new Spark Connect server to connect locally.
*/
private[sql] def withLocalConnectServer[T](f: => T): T = {
synchronized {
lazy val isAPIModeConnect =
Option(System.getProperty(org.apache.spark.sql.SparkSessionBuilder.API_MODE_KEY))
.getOrElse("classic")
.toLowerCase(Locale.ROOT) == "connect"
val remoteString = sparkOptions
.get("spark.remote")
.orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit
.orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))
.orElse {
if (isAPIModeConnect) {
sparkOptions.get("spark.master").orElse(sys.env.get("MASTER"))
} else {
None
}
lazy val isAPIModeConnect =
Option(System.getProperty(org.apache.spark.sql.SparkSessionBuilder.API_MODE_KEY))
.getOrElse("classic")
.toLowerCase(Locale.ROOT) == "connect"
val remoteString = sparkOptions
.get("spark.remote")
.orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit
.orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))
.orElse {
if (isAPIModeConnect) {
sparkOptions.get("spark.master").orElse(sys.env.get("MASTER"))
} else {
None
}
}

val maybeConnectScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))

server.synchronized {
if (server.isEmpty &&
(remoteString.exists(_.startsWith("local")) ||
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectScript.exists(Files.exists(_))) {
maybeConnectStartScript.exists(Files.exists(_))) {
server = Some {
val args =
Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
Seq(
maybeConnectStartScript.get.toString,
"--master",
remoteString.get) ++ (sparkOptions ++ Map(
"spark.sql.artifact.isolation.enabled" -> "true",
"spark.sql.artifact.isolation.alwaysApplyClassloader" -> "true"))
.filter(p => !p._1.startsWith("spark.remote"))
.filter(p => !p._1.startsWith("spark.api.mode"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
Expand All @@ -737,14 +755,17 @@ object SparkSession extends SparkSessionCompanion with Logging {

// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectScript.get.toString)
.start()
override def run(): Unit = server.synchronized {
if (server.isDefined) {
new ProcessBuilder(maybeConnectStopScript.get.toString)
.start()
}
}
})
// scalastyle:on runtimeaddshutdownhook
}
}

f
}

