diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index 70164d7822836..9e67b6bac8b51 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -19,11 +19,7 @@ from typing import Any, TYPE_CHECKING, List, Optional, Union, Sequence from types import ModuleType -from pyspark.errors import ( - PySparkRuntimeError, - PySparkTypeError, - PySparkValueError, -) +from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.sql import Column, functions as F from pyspark.sql.internal import InternalFunction as SF from pyspark.sql.pandas.utils import require_minimum_pandas_version @@ -38,14 +34,8 @@ class PySparkTopNPlotBase: def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame": - from pyspark.sql import SparkSession - - session = SparkSession.getActiveSession() - if session is None: - raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict()) - max_rows = int( - session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type] + sdf._session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type] ) pdf = sdf.limit(max_rows + 1).toPandas() @@ -59,16 +49,11 @@ def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame": class PySparkSampledPlotBase: def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame": - from pyspark.sql import SparkSession, Observation, functions as F - - session = SparkSession.getActiveSession() - if session is None: - raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict()) + from pyspark.sql import Observation, functions as F max_rows = int( - session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type] + sdf._session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type] ) - observation = Observation("pyspark plotting") rand_col_name = "__pyspark_plotting_sampled_plot_base_rand__"