[DOP-22427] Do not keep open JDBC connection on Spark driver
dolfinus committed Feb 5, 2025
1 parent 0bc6eac commit 5a0ffce
Showing 4 changed files with 34 additions and 52 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/next_release/334.feature.rst
@@ -0,0 +1,6 @@
+Now all JDBC connections opened by ``connection.fetch(...)``, ``connection.execute(...)`` or ``connection.check()``
+are closed immediately after the statement is executed.
+
+Previously, a Spark session with ``master=local[1]`` opened 3 connections to the target DB: one for ``.check()``,
+another for the Spark driver's interaction with the DB (e.g. creating tables), and the last one for the Spark executor. Now only 2 connections are opened.
+This is important for RDBMS like Postgres or Greenplum, where the number of connections is strictly limited and the limit is usually quite low.
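To make the new behavior concrete, here is a minimal sketch of the affected calls. The host, credentials, and table names are placeholders, and the example assumes a Spark session with the Postgres JDBC driver on its classpath; it is illustrative, not part of the commit:

```python
from pyspark.sql import SparkSession
from onetl.connection import Postgres

spark = SparkSession.builder.getOrCreate()  # assumes JDBC driver is on the classpath

postgres = Postgres(
    host="postgres.example.com",  # placeholder host and credentials
    user="reader",
    password="***",
    database="mydb",
    spark=spark,
)

postgres.check()                          # opens a JDBC connection, runs "SELECT 1", closes it
df = postgres.fetch("SELECT version()")   # opens a connection, fetches one DataFrame, closes it
postgres.execute("TRUNCATE TABLE stage")  # opens a connection, executes the statement, closes it
```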
8 changes: 3 additions & 5 deletions onetl/connection/db_connection/greenplum/connection.py
@@ -5,19 +5,18 @@
 import logging
 import os
 import textwrap
-import threading
 import warnings
-from typing import TYPE_CHECKING, Any, ClassVar, Optional
+from typing import TYPE_CHECKING, Any, ClassVar
 from urllib.parse import quote, urlencode, urlparse, urlunparse

 from etl_entities.instance import Host

 from onetl.connection.db_connection.jdbc_connection.options import JDBCReadOptions

 try:
-    from pydantic.v1 import PrivateAttr, SecretStr, validator
+    from pydantic.v1 import SecretStr, validator
 except (ImportError, AttributeError):
-    from pydantic import validator, SecretStr, PrivateAttr  # type: ignore[no-redef, assignment]
+    from pydantic import validator, SecretStr  # type: ignore[no-redef, assignment]

 from onetl._util.classproperty import classproperty
 from onetl._util.java import try_import_java_class
@@ -182,7 +181,6 @@ class Greenplum(JDBCMixin, DBConnection):  # noqa: WPS338
     CONNECTIONS_EXCEPTION_LIMIT: ClassVar[int] = 100

     _CHECK_QUERY: ClassVar[str] = "SELECT 1"
-    _last_connection_and_options: Optional[threading.local] = PrivateAttr(default=None)

     @slot
     @classmethod
8 changes: 3 additions & 5 deletions onetl/connection/db_connection/jdbc_connection/connection.py
@@ -4,14 +4,13 @@

 import logging
 import secrets
-import threading
 import warnings
-from typing import TYPE_CHECKING, Any, ClassVar, Optional
+from typing import TYPE_CHECKING, Any, ClassVar

 try:
-    from pydantic.v1 import PrivateAttr, SecretStr, validator
+    from pydantic.v1 import SecretStr, validator
 except (ImportError, AttributeError):
-    from pydantic import PrivateAttr, SecretStr, validator  # type: ignore[no-redef, assignment]
+    from pydantic import SecretStr, validator  # type: ignore[no-redef, assignment]

 from onetl._util.java import try_import_java_class
 from onetl._util.spark import override_job_description
@@ -65,7 +64,6 @@ class JDBCConnection(JDBCMixin, DBConnection):  # noqa: WPS338

     DRIVER: ClassVar[str]
     _CHECK_QUERY: ClassVar[str] = "SELECT 1"
-    _last_connection_and_options: Optional[threading.local] = PrivateAttr(default=None)

     JDBCOptions = JDBCMixinOptions
     FetchOptions = JDBCFetchOptions
64 changes: 22 additions & 42 deletions onetl/connection/db_connection/jdbc_mixin/connection.py
@@ -3,16 +3,16 @@
 from __future__ import annotations

 import logging
-import threading
 import warnings
 from abc import abstractmethod
-from contextlib import closing, suppress
+from contextlib import closing
 from enum import Enum, auto
-from typing import TYPE_CHECKING, Callable, ClassVar, Optional, TypeVar
+from typing import TYPE_CHECKING, Callable, ClassVar, TypeVar

 try:
-    from pydantic.v1 import Field, PrivateAttr, SecretStr
+    from pydantic.v1 import Field, SecretStr
 except (ImportError, AttributeError):
-    from pydantic import Field, PrivateAttr, SecretStr  # type: ignore[no-redef, assignment]
+    from pydantic import Field, SecretStr  # type: ignore[no-redef, assignment]

 from onetl._metrics.command import SparkCommandMetrics
 from onetl._util.java import get_java_gateway
@@ -79,9 +79,6 @@ class JDBCMixin:
     DRIVER: ClassVar[str]
     _CHECK_QUERY: ClassVar[str] = "SELECT 1"

-    # cached JDBC connection (Java object), plus corresponding GenericOptions (Python object)
-    _last_connection_and_options: Optional[threading.local] = PrivateAttr(default=None)
-
     @property
     @abstractmethod
     def jdbc_url(self) -> str:
@@ -102,6 +99,9 @@ def close(self):
"""
Close all connections, opened by ``.fetch()``, ``.execute()`` or ``.check()`` methods. |support_hooks|
.. deprecated:: 0.13.0
Connections are now closed immediately. Method is now no-op.
.. note::
Connection can be used again after it was closed.
@@ -128,7 +128,11 @@ def close(self):
"""

self._close_connections()
warnings.warn(
"Connections are now closed immediately. Method is no-op since 0.13.0",
UserWarning,
stacklevel=2,
)
return self

def __enter__(self):
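For downstream code, the deprecation is backwards-compatible: explicit ``close()`` and the context-manager form keep working, they just no longer manage any cached connection. A hedged sketch, reusing the ``postgres`` placeholder from the changelog example above (not part of this commit, and assuming ``__exit__`` still delegates to ``close()``):

```python
# close() now only emits a UserWarning and returns self,
# so existing cleanup code keeps working unchanged:
with postgres:
    postgres.fetch("SELECT 1")  # opens and closes its own JDBC connection internally
# exiting the block calls close(), which is now a no-op
```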
@@ -382,26 +386,9 @@ def _options_to_connection_properties(self, options: JDBCFetchOptions | JDBCExecuteOptions):
         return jdbc_options.asConnectionProperties()

     def _get_jdbc_connection(self, options: JDBCFetchOptions | JDBCExecuteOptions):
-        if not self._last_connection_and_options:
-            # connection class can be used in multiple threads.
-            # each Python thread creates its own thread in JVM
-            # so we need local variable to create per-thread persistent connection
-            self._last_connection_and_options = threading.local()
-
-        with suppress(Exception):  # nothing cached, or JVM failed
-            last_connection, last_options = self._last_connection_and_options.data
-            if options == last_options and not last_connection.isClosed():
-                return last_connection
-
-            # only one connection can be opened in one moment of time
-            last_connection.close()
-
         connection_properties = self._options_to_connection_properties(options)
         driver_manager = self.spark._jvm.java.sql.DriverManager  # type: ignore
-        new_connection = driver_manager.getConnection(self.jdbc_url, connection_properties)
-
-        self._last_connection_and_options.data = (new_connection, options)
-        return new_connection
+        return driver_manager.getConnection(self.jdbc_url, connection_properties)

     def _get_spark_dialect_name(self) -> str:
         """
@@ -411,19 +398,9 @@ def _get_spark_dialect_name(self) -> str:
         return dialect.split("$")[0] if "$" in dialect else dialect

     def _get_spark_dialect(self):
-        jdbc_dialects_package = self.spark._jvm.org.apache.spark.sql.jdbc
+        jdbc_dialects_package = self.spark._jvm.org.apache.spark.sql.jdbc  # type: ignore
         return jdbc_dialects_package.JdbcDialects.get(self.jdbc_url)

-    def _close_connections(self):
-        with suppress(Exception):
-            # connection maybe not opened yet
-            last_connection, _ = self._last_connection_and_options.data
-            last_connection.close()
-
-        with suppress(Exception):
-            # connection maybe not opened yet
-            del self._last_connection_and_options.data
-
     def _get_statement_args(self) -> tuple[int, ...]:
         resultset = self.spark._jvm.java.sql.ResultSet  # type: ignore

@@ -442,15 +419,18 @@ def _execute_on_driver(
         Almost like ``org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD`` is fetching data:
         * https://github.com/apache/spark/blob/v2.3.0/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala#L297-L306

+        Each time a new connection is opened to execute the statement, and then closed.
         """

         jdbc_connection = self._get_jdbc_connection(options)
-        jdbc_connection.setReadOnly(read_only)  # type: ignore
+        with closing(jdbc_connection):
+            jdbc_connection.setReadOnly(read_only)  # type: ignore

-        statement_args = self._get_statement_args()
-        jdbc_statement = self._build_statement(statement, statement_type, jdbc_connection, statement_args)
+            statement_args = self._get_statement_args()
+            jdbc_statement = self._build_statement(statement, statement_type, jdbc_connection, statement_args)

-        return self._execute_statement(jdbc_statement, statement, options, callback, read_only)
+            return self._execute_statement(jdbc_statement, statement, options, callback, read_only)

     def _execute_statement(
         self,
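Why ``contextlib.closing`` here rather than a plain ``with``: a py4j-wrapped ``java.sql.Connection`` exposes a ``close()`` method but is not a Python context manager. A small self-contained sketch of the guarantee the new ``_execute_on_driver`` relies on (``FakeConnection`` is purely illustrative):

```python
from contextlib import closing

class FakeConnection:
    """Stand-in for a py4j-wrapped java.sql.Connection."""
    def __init__(self):
        self.closed = False
    def close(self):
        self.closed = True

connection = FakeConnection()
try:
    with closing(connection):
        raise RuntimeError("statement failed")  # simulate a failing JDBC statement
except RuntimeError:
    pass

assert connection.closed  # close() ran even though an exception was raised
```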
