Skip to content

Commit

Permalink
[SPARK-49530][PYTHON] Get active session from dataframes
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Get active session from dataframes

### Why are the changes needed?
we can directly get session from dataframes

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
ci

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #48735 from zhengruifeng/py_plot_session.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
zhengruifeng authored and HyukjinKwon committed Nov 1, 2024
1 parent 2cb7232 commit 362d1c7
Showing 1 changed file with 4 additions and 19 deletions.
23 changes: 4 additions & 19 deletions python/pyspark/sql/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@

from typing import Any, TYPE_CHECKING, List, Optional, Union, Sequence
from types import ModuleType
from pyspark.errors import (
PySparkRuntimeError,
PySparkTypeError,
PySparkValueError,
)
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.sql import Column, functions as F
from pyspark.sql.internal import InternalFunction as SF
from pyspark.sql.pandas.utils import require_minimum_pandas_version
Expand All @@ -38,14 +34,8 @@

class PySparkTopNPlotBase:
def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame":
from pyspark.sql import SparkSession

session = SparkSession.getActiveSession()
if session is None:
raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())

max_rows = int(
session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type]
sdf._session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type]
)
pdf = sdf.limit(max_rows + 1).toPandas()

Expand All @@ -59,16 +49,11 @@ def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame":

class PySparkSampledPlotBase:
def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame":
from pyspark.sql import SparkSession, Observation, functions as F

session = SparkSession.getActiveSession()
if session is None:
raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())
from pyspark.sql import Observation, functions as F

max_rows = int(
session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type]
sdf._session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type]
)

observation = Observation("pyspark plotting")

rand_col_name = "__pyspark_plotting_sampled_plot_base_rand__"
Expand Down

0 comments on commit 362d1c7

Please sign in to comment.