[SPARK-51275][PYTHON][ML][CONNECT] Session propagation in python readwrite

### What changes were proposed in this pull request?
Propagate the caller's SparkSession through the Python ML read/write helpers.

### Why are the changes needed?
To avoid recreating a SparkSession when the caller already has one active (notably for Spark Connect clients).

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
Existing tests should cover this change.

### Was this patch authored or co-authored using generative AI tooling?
no
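
For context, the pattern this commit leans on is already public API: an MLWriter can be pinned to an existing session via MLWriter.session(), and with this change the nested per-stage writers inherit that session instead of resolving one of their own. A rough usage sketch (the local session and the save path are illustrative, not from the commit):

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import Tokenizer
    from pyspark.sql import SparkSession

    # Assumes a session already exists (e.g. a Spark Connect client session).
    spark = SparkSession.builder.getOrCreate()

    pipeline = Pipeline(stages=[Tokenizer(inputCol="text", outputCol="words")])

    # Pin the writer to the caller's session; the per-stage writers now
    # reuse it rather than recreating a session during save().
    pipeline.write().session(spark).save("/tmp/demo-pipeline")
    loaded = Pipeline.load("/tmp/demo-pipeline")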

Closes #50035 from zhengruifeng/py_ml_sc_session.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
zhengruifeng committed Feb 21, 2025
1 parent 46e12a4 commit 3d76e0b
Showing 2 changed files with 9 additions and 7 deletions.
python/pyspark/ml/pipeline.py (5 additions, 3 deletions)
@@ -420,10 +420,11 @@ def saveImpl(
         """
         stageUids = [stage.uid for stage in stages]
         jsonParams = {"stageUids": stageUids, "language": "Python"}
-        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
+        spark = cast(SparkSession, sc) if hasattr(sc, "createDataFrame") else SparkSession.active()
+        DefaultParamsWriter.saveMetadata(instance, path, spark, paramMap=jsonParams)
         stagesDir = os.path.join(path, "stages")
         for index, stage in enumerate(stages):
-            cast(MLWritable, stage).write().save(
+            cast(MLWritable, stage).write().session(spark).save(
                 PipelineSharedReadWrite.getStagePath(stage.uid, index, len(stages), stagesDir)
             )

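A note on the new expression (my reading; the commit message does not spell it out): a Spark Connect SparkSession is a separate class from the classic one, so an isinstance check would send Connect sessions down the session-creating fallback, while duck-typing on createDataFrame accepts both flavors. A minimal sketch with a hypothetical helper name:

    from pyspark.sql import SparkSession

    def _resolve_session(sc):
        # Hypothetical helper mirroring the inline expression in the diff:
        # anything exposing createDataFrame (classic or Connect session) is
        # used as-is; a bare SparkContext falls back to the active session.
        if hasattr(sc, "createDataFrame"):
            return sc
        return SparkSession.active()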
@@ -443,12 +444,13 @@ def load(
         """
         stagesDir = os.path.join(path, "stages")
         stageUids = metadata["paramMap"]["stageUids"]
+        spark = cast(SparkSession, sc) if hasattr(sc, "createDataFrame") else SparkSession.active()
         stages = []
         for index, stageUid in enumerate(stageUids):
             stagePath = PipelineSharedReadWrite.getStagePath(
                 stageUid, index, len(stageUids), stagesDir
             )
-            stage: "PipelineStage" = DefaultParamsReader.loadParamsInstance(stagePath, sc)
+            stage: "PipelineStage" = DefaultParamsReader.loadParamsInstance(stagePath, spark)
             stages.append(stage)
         return (metadata["uid"], stages)

python/pyspark/ml/util.py (4 additions, 4 deletions)
@@ -462,7 +462,7 @@ def sparkSession(self) -> SparkSession:
         Returns the user-specified Spark Session or the default.
         """
         if self._sparkSession is None:
-            self._sparkSession = SparkSession._getActiveSessionOrCreate()
+            self._sparkSession = SparkSession.active()
         assert self._sparkSession is not None
         return self._sparkSession

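For context on this property change: SparkSession.active() (public since Spark 3.5) returns the session already bound to the current environment and raises if none exists, whereas the internal _getActiveSessionOrCreate() may construct a brand-new classic session, which is the recreation the description refers to. A small illustration, assuming a local session is already running:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # active() hands back the existing session instead of building one.
    assert SparkSession.active() is spark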
@@ -809,10 +809,10 @@ def saveMetadata(
         If given, this is saved in the "paramMap" field.
         """
         metadataPath = os.path.join(path, "metadata")
+        spark = cast(SparkSession, sc) if hasattr(sc, "createDataFrame") else SparkSession.active()
         metadataJson = DefaultParamsWriter._get_metadata_to_save(
-            instance, sc, extraMetadata, paramMap
+            instance, spark, extraMetadata, paramMap
         )
-        spark = sc if isinstance(sc, SparkSession) else SparkSession._getActiveSessionOrCreate()
         spark.createDataFrame([(metadataJson,)], schema=["value"]).coalesce(1).write.text(
             metadataPath
         )
@@ -932,7 +932,7 @@ def loadMetadata(
         If non empty, this is checked against the loaded metadata.
         """
         metadataPath = os.path.join(path, "metadata")
-        spark = sc if isinstance(sc, SparkSession) else SparkSession._getActiveSessionOrCreate()
+        spark = cast(SparkSession, sc) if hasattr(sc, "createDataFrame") else SparkSession.active()
         metadataStr = spark.read.text(metadataPath).first()[0]  # type: ignore[index]
         loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
         return loadedVals
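
To make the read path concrete, this is roughly what loadMetadata does with the resolved session; the path and the printed fields are illustrative:

    import json
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Read the single JSON line written by saveMetadata and parse it,
    # mirroring DefaultParamsReader.loadMetadata.
    metadataStr = spark.read.text("/tmp/demo-pipeline/metadata").first()[0]
    metadata = json.loads(metadataStr)
    print(metadata["class"], metadata["paramMap"])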
