Skip to content

Commit

Permalink
Add method to asynchronously prepare CQL statements
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasz-antoniak committed Dec 2, 2024
1 parent 6e2ffd4 commit 814dbf2
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 38 deletions.
116 changes: 81 additions & 35 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2717,7 +2717,7 @@ def execute_async(self, query, parameters=None, trace=False, custom_payload=None
if execute_as:
custom_payload[_proxy_execute_key] = execute_as.encode()

future = self._create_response_future(
future = self._create_execute_response_future(
query, parameters, trace, custom_payload, timeout,
execution_profile, paging_state, host)
future._protocol_handler = self.client_protocol_handler
Expand Down Expand Up @@ -2782,8 +2782,8 @@ def execute_graph_async(self, query, parameters=None, trace=False, execution_pro
custom_payload[_proxy_execute_key] = execute_as.encode()
custom_payload[_request_timeout_key] = int64_pack(int(execution_profile.request_timeout * 1000))

future = self._create_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload,
timeout=_NOT_SET, execution_profile=execution_profile)
future = self._create_execute_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload,
timeout=_NOT_SET, execution_profile=execution_profile)

future.message.query_params = graph_parameters
future._protocol_handler = self.client_protocol_handler
Expand Down Expand Up @@ -2885,9 +2885,9 @@ def _transform_params(self, parameters, graph_options):

def _target_analytics_master(self, future):
future._start_timer()
master_query_future = self._create_response_future("CALL DseClientTool.getAnalyticsGraphServer()",
parameters=None, trace=False,
custom_payload=None, timeout=future.timeout)
master_query_future = self._create_execute_response_future("CALL DseClientTool.getAnalyticsGraphServer()",
parameters=None, trace=False,
custom_payload=None, timeout=future.timeout)
master_query_future.row_factory = tuple_factory
master_query_future.send_request()

Expand All @@ -2910,9 +2910,37 @@ def _on_analytics_master_result(self, response, master_future, query_future):

self.submit(query_future.send_request)

def _create_response_future(self, query, parameters, trace, custom_payload,
timeout, execution_profile=EXEC_PROFILE_DEFAULT,
paging_state=None, host=None):
def prepare_async(self, query, custom_payload=None, keyspace=None):
"""
Prepare the given query and return a :class:`~.PrepareFuture`
object. You may also call :meth:`~.PrepareFuture.result()`
on the :class:`.PrepareFuture` to synchronously block for
prepared statement object at any time.
See :meth:`Session.prepare` for parameter definitions.
Example usage::
>>> future = session.prepare_async("SELECT * FROM mycf")
>>> # do other stuff...
>>> try:
... prepared_statement = future.result()
... except Exception:
... log.exception("Operation failed:")
"""
future = self._create_prepare_response_future(query, keyspace, custom_payload)
future._protocol_handler = self.client_protocol_handler
self._on_request(future)
future.send_request()
return future

def _create_prepare_response_future(self, query, keyspace, custom_payload):
return PrepareFuture(self, query, keyspace, custom_payload, self.default_timeout)

def _create_execute_response_future(self, query, parameters, trace, custom_payload,
timeout, execution_profile=EXEC_PROFILE_DEFAULT,
paging_state=None, host=None):
""" Returns the ResponseFuture before calling send_request() on it """

prepared_statement = None
Expand Down Expand Up @@ -3121,33 +3149,9 @@ def prepare(self, query, custom_payload=None, keyspace=None):
`custom_payload` is a key value map to be passed along with the prepare
message. See :ref:`custom_payload`.
"""
message = PrepareMessage(query=query, keyspace=keyspace)
future = ResponseFuture(self, message, query=None, timeout=self.default_timeout)
try:
future.send_request()
response = future.result().one()
except Exception:
log.exception("Error preparing query:")
raise
return self.prepare_async(query, custom_payload, keyspace).result()

prepared_keyspace = keyspace if keyspace else None
prepared_statement = PreparedStatement.from_message(
response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace,
self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy)
prepared_statement.custom_payload = future.custom_payload

self.cluster.add_prepared(response.query_id, prepared_statement)

if self.cluster.prepare_on_all_hosts:
host = future._current_host
try:
self.prepare_on_all_hosts(prepared_statement.query_string, host, prepared_keyspace)
except Exception:
log.exception("Error preparing query on all hosts:")

return prepared_statement

def prepare_on_all_hosts(self, query, excluded_host, keyspace=None):
def prepare_on_all_nodes(self, query, excluded_host, keyspace=None):
"""
Prepare the given query on all hosts, excluding ``excluded_host``.
Intended for internal use only.
Expand Down Expand Up @@ -5105,6 +5109,48 @@ def __str__(self):
__repr__ = __str__


class PrepareFuture(ResponseFuture):
_final_prepare_result = _NOT_SET

def __init__(self, session, query, keyspace, custom_payload, timeout):
super().__init__(session, PrepareMessage(query=query, keyspace=keyspace), None, timeout)
self.query_string = query
self._keyspace = keyspace
self._custom_payload = custom_payload

def _set_final_result(self, response):
session = self.session
cluster = session.cluster
prepared_statement = PreparedStatement.from_message(
response.query_id, response.bind_metadata, response.pk_indexes, cluster.metadata, self.query_string,
self._keyspace, session._protocol_version, response.column_metadata, response.result_metadata_id,
cluster.column_encryption_policy)
prepared_statement.custom_payload = response.custom_payload
cluster.add_prepared(response.query_id, prepared_statement)
self._final_prepare_result = prepared_statement

if cluster.prepare_on_all_hosts:
# trigger asynchronous preparation of query on other C* nodes,
# we are on event loop thread, so do not execute those synchronously
session.submit(
session.prepare_on_all_nodes,
self.query_string, self._current_host, self._keyspace)

super()._set_final_result(response)

def result(self):
self._event.wait()
if self._final_prepare_result is not _NOT_SET:
return self._final_prepare_result
else:
raise self._final_exception

def __str__(self):
result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result
return "<PrepareFuture: query='%s' request_id=%s result=%s exception=%s coordinator_host=%s>" \
% (self.query_string, self._req_id, result, self._final_exception, self.coordinator_host)
__repr__ = __str__

class QueryExhausted(Exception):
"""
Raised when :meth:`.ResponseFuture.start_fetching_next_page()` is called and
Expand Down
2 changes: 1 addition & 1 deletion cassandra/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,7 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None):

def _execute(self, query, parameters, time_spent, max_wait):
timeout = (max_wait - time_spent) if max_wait is not None else None
future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout)
future = self._session._create_execute_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout)
# in case the user switched the row factory, set it to namedtuple for this query
future.row_factory = named_tuple_factory
future.send_request()
Expand Down
78 changes: 78 additions & 0 deletions tests/integration/standard/test_prepared_statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from cassandra import InvalidRequest, DriverException

from cassandra import ConsistencyLevel, ProtocolVersion
from cassandra.cluster import PrepareFuture
from cassandra.query import PreparedStatement, UNSET_VALUE
from tests.integration import (get_server_versions, greaterthanorequalcass40, greaterthanorequaldse50,
requirecassandra, BasicSharedKeyspaceUnitTestCase)
Expand Down Expand Up @@ -121,6 +122,83 @@ def test_basic(self):
results = self.session.execute(bound)
self.assertEqual(results, [('x', 'y', 'z')])

def test_basic_async(self):
"""
Test basic asynchronous PreparedStatement usage
"""
self.session.execute(
"""
DROP KEYSPACE IF EXISTS preparedtests
"""
)
self.session.execute(
"""
CREATE KEYSPACE preparedtests
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
""")

self.session.set_keyspace("preparedtests")
self.session.execute(
"""
CREATE TABLE cf0 (
a text,
b text,
c text,
PRIMARY KEY (a, b)
)
""")

prepared_future = self.session.prepare_async(
"""
INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?)
""")
self.assertIsInstance(prepared_future, PrepareFuture)
prepared = prepared_future.result()
self.assertIsInstance(prepared, PreparedStatement)

bound = prepared.bind(('a', 'b', 'c'))
self.session.execute(bound)

prepared_future = self.session.prepare_async(
"""
SELECT * FROM cf0 WHERE a=?
""")
self.assertIsInstance(prepared_future, PrepareFuture)
prepared = prepared_future.result()
self.assertIsInstance(prepared, PreparedStatement)

bound = prepared.bind(('a'))
results = self.session.execute(bound)
self.assertEqual(results, [('a', 'b', 'c')])

# test with new dict binding
prepared_future = self.session.prepare_async(
"""
INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?)
""")
self.assertIsInstance(prepared_future, PrepareFuture)
prepared = prepared_future.result()
self.assertIsInstance(prepared, PreparedStatement)

bound = prepared.bind({
'a': 'x',
'b': 'y',
'c': 'z'
})
self.session.execute(bound)

prepared_future = self.session.prepare_async(
"""
SELECT * FROM cf0 WHERE a=?
""")
self.assertIsInstance(prepared_future, PrepareFuture)
prepared = prepared_future.result()
self.assertIsInstance(prepared, PreparedStatement)

bound = prepared.bind({'a': 'x'})
results = self.session.execute(bound)
self.assertEqual(results, [('x', 'y', 'z')])

def test_missing_primary_key(self):
"""
Ensure an InvalidRequest is thrown
Expand Down
14 changes: 12 additions & 2 deletions tests/integration/standard/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,6 @@ def test_prepared_statement(self):

prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)')
prepared.consistency_level = ConsistencyLevel.ONE

self.assertEqual(str(prepared),
'<PreparedStatement query="INSERT INTO test3rf.test (k, v) VALUES (?, ?)", consistency=ONE>')

Expand Down Expand Up @@ -717,6 +716,17 @@ def test_prepared_statements(self):
self.session.execute_async(batch).result()
self.confirm_results()

def test_prepare_async(self):
prepared = self.session.prepare_async("INSERT INTO test3rf.test (k, v) VALUES (?, ?)").result()

batch = BatchStatement(BatchType.LOGGED)
for i in range(10):
batch.add(prepared, (i, i))

self.session.execute(batch)
self.session.execute_async(batch).result()
self.confirm_results()

def test_bound_statements(self):
prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

Expand Down Expand Up @@ -942,7 +952,7 @@ def test_no_connection_refused_on_timeout(self):
exception_type = type(result).__name__
if exception_type == "NoHostAvailable":
self.fail("PYTHON-91: Disconnected from Cassandra: %s" % result.message)
if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessageSub"]:
if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessage", "ErrorMessageSub"]:
if type(result).__name__ in ["WriteTimeout", "WriteFailure"]:
received_timeout = True
continue
Expand Down

0 comments on commit 814dbf2

Please sign in to comment.