From a47b75a78f62ef0e2a0d1361c119508ae8716cf2 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 28 Nov 2024 17:10:15 +0100 Subject: [PATCH 1/5] Add method to asynchronously prepare CQL statements --- cassandra/cluster.py | 116 ++++++++++++------ cassandra/query.py | 2 +- .../standard/test_prepared_statements.py | 78 ++++++++++++ tests/integration/standard/test_query.py | 14 ++- 4 files changed, 172 insertions(+), 38 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index d5f80290a9..ea965fce2e 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2717,7 +2717,7 @@ def execute_async(self, query, parameters=None, trace=False, custom_payload=None if execute_as: custom_payload[_proxy_execute_key] = execute_as.encode() - future = self._create_response_future( + future = self._create_execute_response_future( query, parameters, trace, custom_payload, timeout, execution_profile, paging_state, host) future._protocol_handler = self.client_protocol_handler @@ -2782,8 +2782,8 @@ def execute_graph_async(self, query, parameters=None, trace=False, execution_pro custom_payload[_proxy_execute_key] = execute_as.encode() custom_payload[_request_timeout_key] = int64_pack(int(execution_profile.request_timeout * 1000)) - future = self._create_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload, - timeout=_NOT_SET, execution_profile=execution_profile) + future = self._create_execute_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload, + timeout=_NOT_SET, execution_profile=execution_profile) future.message.query_params = graph_parameters future._protocol_handler = self.client_protocol_handler @@ -2885,9 +2885,9 @@ def _transform_params(self, parameters, graph_options): def _target_analytics_master(self, future): future._start_timer() - master_query_future = self._create_response_future("CALL DseClientTool.getAnalyticsGraphServer()", - parameters=None, trace=False, - custom_payload=None, timeout=future.timeout) + master_query_future = self._create_execute_response_future("CALL DseClientTool.getAnalyticsGraphServer()", + parameters=None, trace=False, + custom_payload=None, timeout=future.timeout) master_query_future.row_factory = tuple_factory master_query_future.send_request() @@ -2910,9 +2910,37 @@ def _on_analytics_master_result(self, response, master_future, query_future): self.submit(query_future.send_request) - def _create_response_future(self, query, parameters, trace, custom_payload, - timeout, execution_profile=EXEC_PROFILE_DEFAULT, - paging_state=None, host=None): + def prepare_async(self, query, custom_payload=None, keyspace=None): + """ + Prepare the given query and return a :class:`~.PrepareFuture` + object. You may also call :meth:`~.PrepareFuture.result()` + on the :class:`.PrepareFuture` to synchronously block for + prepared statement object at any time. + + See :meth:`Session.prepare` for parameter definitions. + + Example usage:: + + >>> future = session.prepare_async("SELECT * FROM mycf") + >>> # do other stuff... + + >>> try: + ... prepared_statement = future.result() + ... except Exception: + ... log.exception("Operation failed:") + """ + future = self._create_prepare_response_future(query, keyspace, custom_payload) + future._protocol_handler = self.client_protocol_handler + self._on_request(future) + future.send_request() + return future + + def _create_prepare_response_future(self, query, keyspace, custom_payload): + return PrepareFuture(self, query, keyspace, custom_payload, self.default_timeout) + + def _create_execute_response_future(self, query, parameters, trace, custom_payload, + timeout, execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, host=None): """ Returns the ResponseFuture before calling send_request() on it """ prepared_statement = None @@ -3121,33 +3149,9 @@ def prepare(self, query, custom_payload=None, keyspace=None): `custom_payload` is a key value map to be passed along with the prepare message. See :ref:`custom_payload`. """ - message = PrepareMessage(query=query, keyspace=keyspace) - future = ResponseFuture(self, message, query=None, timeout=self.default_timeout) - try: - future.send_request() - response = future.result().one() - except Exception: - log.exception("Error preparing query:") - raise + return self.prepare_async(query, custom_payload, keyspace).result() - prepared_keyspace = keyspace if keyspace else None - prepared_statement = PreparedStatement.from_message( - response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace, - self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy) - prepared_statement.custom_payload = future.custom_payload - - self.cluster.add_prepared(response.query_id, prepared_statement) - - if self.cluster.prepare_on_all_hosts: - host = future._current_host - try: - self.prepare_on_all_hosts(prepared_statement.query_string, host, prepared_keyspace) - except Exception: - log.exception("Error preparing query on all hosts:") - - return prepared_statement - - def prepare_on_all_hosts(self, query, excluded_host, keyspace=None): + def prepare_on_all_nodes(self, query, excluded_host, keyspace=None): """ Prepare the given query on all hosts, excluding ``excluded_host``. Intended for internal use only. @@ -5105,6 +5109,48 @@ def __str__(self): __repr__ = __str__ +class PrepareFuture(ResponseFuture): + _final_prepare_result = _NOT_SET + + def __init__(self, session, query, keyspace, custom_payload, timeout): + super().__init__(session, PrepareMessage(query=query, keyspace=keyspace), None, timeout) + self.query_string = query + self._keyspace = keyspace + self._custom_payload = custom_payload + + def _set_final_result(self, response): + session = self.session + cluster = session.cluster + prepared_statement = PreparedStatement.from_message( + response.query_id, response.bind_metadata, response.pk_indexes, cluster.metadata, self.query_string, + self._keyspace, session._protocol_version, response.column_metadata, response.result_metadata_id, + cluster.column_encryption_policy) + prepared_statement.custom_payload = response.custom_payload + cluster.add_prepared(response.query_id, prepared_statement) + self._final_prepare_result = prepared_statement + + if cluster.prepare_on_all_hosts: + # trigger asynchronous preparation of query on other C* nodes, + # we are on event loop thread, so do not execute those synchronously + session.submit( + session.prepare_on_all_nodes, + self.query_string, self._current_host, self._keyspace) + + super()._set_final_result(response) + + def result(self): + self._event.wait() + if self._final_prepare_result is not _NOT_SET: + return self._final_prepare_result + else: + raise self._final_exception + + def __str__(self): + result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result + return "" \ + % (self.query_string, self._req_id, result, self._final_exception, self.coordinator_host) + __repr__ = __str__ + class QueryExhausted(Exception): """ Raised when :meth:`.ResponseFuture.start_fetching_next_page()` is called and diff --git a/cassandra/query.py b/cassandra/query.py index e29c2a3113..be123c3135 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -1035,7 +1035,7 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): def _execute(self, query, parameters, time_spent, max_wait): timeout = (max_wait - time_spent) if max_wait is not None else None - future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) + future = self._session._create_execute_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) # in case the user switched the row factory, set it to namedtuple for this query future.row_factory = named_tuple_factory future.send_request() diff --git a/tests/integration/standard/test_prepared_statements.py b/tests/integration/standard/test_prepared_statements.py index a643b19c07..615d8cf0f8 100644 --- a/tests/integration/standard/test_prepared_statements.py +++ b/tests/integration/standard/test_prepared_statements.py @@ -22,6 +22,7 @@ from cassandra import InvalidRequest, DriverException from cassandra import ConsistencyLevel, ProtocolVersion +from cassandra.cluster import PrepareFuture from cassandra.query import PreparedStatement, UNSET_VALUE from tests.integration import (get_server_versions, greaterthanorequalcass40, greaterthanorequaldse50, requirecassandra, BasicSharedKeyspaceUnitTestCase) @@ -121,6 +122,83 @@ def test_basic(self): results = self.session.execute(bound) self.assertEqual(results, [('x', 'y', 'z')]) + def test_basic_async(self): + """ + Test basic asynchronous PreparedStatement usage + """ + self.session.execute( + """ + DROP KEYSPACE IF EXISTS preparedtests + """ + ) + self.session.execute( + """ + CREATE KEYSPACE preparedtests + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + """) + + self.session.set_keyspace("preparedtests") + self.session.execute( + """ + CREATE TABLE cf0 ( + a text, + b text, + c text, + PRIMARY KEY (a, b) + ) + """) + + prepared_future = self.session.prepare_async( + """ + INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?) + """) + self.assertIsInstance(prepared_future, PrepareFuture) + prepared = prepared_future.result() + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind(('a', 'b', 'c')) + self.session.execute(bound) + + prepared_future = self.session.prepare_async( + """ + SELECT * FROM cf0 WHERE a=? + """) + self.assertIsInstance(prepared_future, PrepareFuture) + prepared = prepared_future.result() + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind(('a')) + results = self.session.execute(bound) + self.assertEqual(results, [('a', 'b', 'c')]) + + # test with new dict binding + prepared_future = self.session.prepare_async( + """ + INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?) + """) + self.assertIsInstance(prepared_future, PrepareFuture) + prepared = prepared_future.result() + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind({ + 'a': 'x', + 'b': 'y', + 'c': 'z' + }) + self.session.execute(bound) + + prepared_future = self.session.prepare_async( + """ + SELECT * FROM cf0 WHERE a=? + """) + self.assertIsInstance(prepared_future, PrepareFuture) + prepared = prepared_future.result() + self.assertIsInstance(prepared, PreparedStatement) + + bound = prepared.bind({'a': 'x'}) + results = self.session.execute(bound) + self.assertEqual(results, [('x', 'y', 'z')]) + def test_missing_primary_key(self): """ Ensure an InvalidRequest is thrown diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index 89486802b4..4ddbd0ab0a 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -647,7 +647,6 @@ def test_prepared_statement(self): prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)') prepared.consistency_level = ConsistencyLevel.ONE - self.assertEqual(str(prepared), '') @@ -717,6 +716,17 @@ def test_prepared_statements(self): self.session.execute_async(batch).result() self.confirm_results() + def test_prepare_async(self): + prepared = self.session.prepare_async("INSERT INTO test3rf.test (k, v) VALUES (?, ?)").result() + + batch = BatchStatement(BatchType.LOGGED) + for i in range(10): + batch.add(prepared, (i, i)) + + self.session.execute(batch) + self.session.execute_async(batch).result() + self.confirm_results() + def test_bound_statements(self): prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)") @@ -942,7 +952,7 @@ def test_no_connection_refused_on_timeout(self): exception_type = type(result).__name__ if exception_type == "NoHostAvailable": self.fail("PYTHON-91: Disconnected from Cassandra: %s" % result.message) - if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessageSub"]: + if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessageSub", "ErrorMessage"]: if type(result).__name__ in ["WriteTimeout", "WriteFailure"]: received_timeout = True continue From 0e2903e5371f924012a65a24053ca36eb410c07e Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Tue, 3 Dec 2024 15:14:00 +0100 Subject: [PATCH 2/5] Preserve synchronous prepare logic when preparing statement on all nodes --- cassandra/cluster.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index ea965fce2e..b25e85db7f 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2910,7 +2910,7 @@ def _on_analytics_master_result(self, response, master_future, query_future): self.submit(query_future.send_request) - def prepare_async(self, query, custom_payload=None, keyspace=None): + def prepare_async(self, query, custom_payload=None, keyspace=None, prepare_on_all_hosts=None): """ Prepare the given query and return a :class:`~.PrepareFuture` object. You may also call :meth:`~.PrepareFuture.result()` @@ -2929,14 +2929,16 @@ def prepare_async(self, query, custom_payload=None, keyspace=None): ... except Exception: ... log.exception("Operation failed:") """ - future = self._create_prepare_response_future(query, keyspace, custom_payload) + if prepare_on_all_hosts is None: + prepare_on_all_hosts = self.cluster.prepare_on_all_hosts + future = self._create_prepare_response_future(query, keyspace, custom_payload, prepare_on_all_hosts) future._protocol_handler = self.client_protocol_handler self._on_request(future) future.send_request() return future - def _create_prepare_response_future(self, query, keyspace, custom_payload): - return PrepareFuture(self, query, keyspace, custom_payload, self.default_timeout) + def _create_prepare_response_future(self, query, keyspace, custom_payload, prepare_on_all_hosts): + return PrepareFuture(self, query, keyspace, custom_payload, self.default_timeout, prepare_on_all_hosts) def _create_execute_response_future(self, query, parameters, trace, custom_payload, timeout, execution_profile=EXEC_PROFILE_DEFAULT, @@ -3149,7 +3151,17 @@ def prepare(self, query, custom_payload=None, keyspace=None): `custom_payload` is a key value map to be passed along with the prepare message. See :ref:`custom_payload`. """ - return self.prepare_async(query, custom_payload, keyspace).result() + future = self.prepare_async(query, custom_payload, keyspace, prepare_on_all_hosts=False) + response = future.result() + if self.cluster.prepare_on_all_hosts: + # prepare on all hosts in a synchronous way, not asynchronously + # as internally in prepare_async() (PrepareFuture) + host = future._current_host + try: + self.prepare_on_all_nodes(response.query_string, host, response.keyspace) + except Exception: + log.exception("Error preparing query on all hosts:") + return response def prepare_on_all_nodes(self, query, excluded_host, keyspace=None): """ @@ -5112,9 +5124,10 @@ def __str__(self): class PrepareFuture(ResponseFuture): _final_prepare_result = _NOT_SET - def __init__(self, session, query, keyspace, custom_payload, timeout): + def __init__(self, session, query, keyspace, custom_payload, timeout, prepare_on_all_hosts): super().__init__(session, PrepareMessage(query=query, keyspace=keyspace), None, timeout) self.query_string = query + self._prepare_on_all_hosts = prepare_on_all_hosts self._keyspace = keyspace self._custom_payload = custom_payload @@ -5129,7 +5142,7 @@ def _set_final_result(self, response): cluster.add_prepared(response.query_id, prepared_statement) self._final_prepare_result = prepared_statement - if cluster.prepare_on_all_hosts: + if self._prepare_on_all_hosts: # trigger asynchronous preparation of query on other C* nodes, # we are on event loop thread, so do not execute those synchronously session.submit( From f3b58c79f7b13e4bdddd1dc921456838bfbbc1e6 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 4 Dec 2024 11:24:59 +0100 Subject: [PATCH 3/5] Document difference in prepare_on_all_hosts handling --- cassandra/cluster.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index b25e85db7f..6808a0c790 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2928,6 +2928,10 @@ def prepare_async(self, query, custom_payload=None, keyspace=None, prepare_on_al ... prepared_statement = future.result() ... except Exception: ... log.exception("Operation failed:") + + When :meth:`~.Cluster.prepare_on_all_hosts` is enabled, method + attempts to prepare given query on all hosts, but does not wait + for their response. """ if prepare_on_all_hosts is None: prepare_on_all_hosts = self.cluster.prepare_on_all_hosts @@ -3148,6 +3152,11 @@ def prepare(self, query, custom_payload=None, keyspace=None): **Important**: PreparedStatements should be prepared only once. Preparing the same query more than once will likely affect performance. + When :meth:`~.Cluster.prepare_on_all_hosts` is enabled, method + attempts to prepare given query on all hosts and waits for each node to respond. + Preparing CQL query on other nodes may fail, but error is not propagated + to the caller. + `custom_payload` is a key value map to be passed along with the prepare message. See :ref:`custom_payload`. """ From 3453cb19f820202d65bc4ed4a82f1aea24a31b11 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Fri, 6 Dec 2024 10:47:13 +0100 Subject: [PATCH 4/5] Test prepare_on_all_hosts with prepare_async function --- tests/integration/standard/test_query.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index 4ddbd0ab0a..c6b4d8ebc7 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -528,6 +528,20 @@ def test_prepare_on_all_hosts(self): session.execute(select_statement, (1, ), host=host) self.assertEqual(2, self.mock_handler.get_message_count('debug', "Re-preparing")) + def test_prepare_async_on_all_hosts(self): + """ + Test to validate prepare_on_all_hosts flag is honored during prepare_async execution. + """ + clus = TestCluster(prepare_on_all_hosts=True) + self.addCleanup(clus.shutdown) + + session = clus.connect(wait_for_all_pools=True) + select_statement = session.prepare_async("SELECT k FROM test3rf.test WHERE k = ?").result() + time.sleep(1) # we have no way to know when prepared statements are asynchronously completed + for host in clus.metadata.all_hosts(): + session.execute(select_statement, (1, ), host=host) + self.assertEqual(0, self.mock_handler.get_message_count('debug', "Re-preparing")) + def test_prepare_batch_statement(self): """ Test to validate a prepared statement used inside a batch statement is correctly handled From 5ac7d2012255e9edb63544e65640d2aac7b84a81 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 11 Dec 2024 19:59:38 +0100 Subject: [PATCH 5/5] Refactor prepare_async implementation --- cassandra/cluster.py | 131 ++++++++---------- cassandra/query.py | 2 +- .../standard/test_prepared_statements.py | 10 +- tests/integration/standard/test_query.py | 5 +- 4 files changed, 62 insertions(+), 86 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 6808a0c790..5a0c355200 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2717,7 +2717,7 @@ def execute_async(self, query, parameters=None, trace=False, custom_payload=None if execute_as: custom_payload[_proxy_execute_key] = execute_as.encode() - future = self._create_execute_response_future( + future = self._create_response_future( query, parameters, trace, custom_payload, timeout, execution_profile, paging_state, host) future._protocol_handler = self.client_protocol_handler @@ -2782,8 +2782,8 @@ def execute_graph_async(self, query, parameters=None, trace=False, execution_pro custom_payload[_proxy_execute_key] = execute_as.encode() custom_payload[_request_timeout_key] = int64_pack(int(execution_profile.request_timeout * 1000)) - future = self._create_execute_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload, - timeout=_NOT_SET, execution_profile=execution_profile) + future = self._create_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload, + timeout=_NOT_SET, execution_profile=execution_profile) future.message.query_params = graph_parameters future._protocol_handler = self.client_protocol_handler @@ -2885,9 +2885,9 @@ def _transform_params(self, parameters, graph_options): def _target_analytics_master(self, future): future._start_timer() - master_query_future = self._create_execute_response_future("CALL DseClientTool.getAnalyticsGraphServer()", - parameters=None, trace=False, - custom_payload=None, timeout=future.timeout) + master_query_future = self._create_response_future("CALL DseClientTool.getAnalyticsGraphServer()", + parameters=None, trace=False, + custom_payload=None, timeout=future.timeout) master_query_future.row_factory = tuple_factory master_query_future.send_request() @@ -2910,11 +2910,11 @@ def _on_analytics_master_result(self, response, master_future, query_future): self.submit(query_future.send_request) - def prepare_async(self, query, custom_payload=None, keyspace=None, prepare_on_all_hosts=None): + def prepare_async(self, query, custom_payload=None, keyspace=None): """ - Prepare the given query and return a :class:`~.PrepareFuture` - object. You may also call :meth:`~.PrepareFuture.result()` - on the :class:`.PrepareFuture` to synchronously block for + Prepare the given query and return a :class:`~.ResponseFuture` + object. You may also call :meth:`~.ResponseFuture.result()` + on the :class:`.ResponseFuture` to synchronously block for prepared statement object at any time. See :meth:`Session.prepare` for parameter definitions. @@ -2928,25 +2928,43 @@ def prepare_async(self, query, custom_payload=None, keyspace=None, prepare_on_al ... prepared_statement = future.result() ... except Exception: ... log.exception("Operation failed:") - - When :meth:`~.Cluster.prepare_on_all_hosts` is enabled, method - attempts to prepare given query on all hosts, but does not wait - for their response. """ - if prepare_on_all_hosts is None: - prepare_on_all_hosts = self.cluster.prepare_on_all_hosts - future = self._create_prepare_response_future(query, keyspace, custom_payload, prepare_on_all_hosts) + future = self._create_prepare_response_future(query, keyspace, custom_payload) future._protocol_handler = self.client_protocol_handler self._on_request(future) future.send_request() return future - def _create_prepare_response_future(self, query, keyspace, custom_payload, prepare_on_all_hosts): - return PrepareFuture(self, query, keyspace, custom_payload, self.default_timeout, prepare_on_all_hosts) + def _create_prepare_response_future(self, query, keyspace, custom_payload): + message = PrepareMessage(query=query, keyspace=keyspace) + future = ResponseFuture(self, message, query=None, timeout=self.default_timeout) + + def _prepare_result_processor(future, response): + prepared_keyspace = keyspace if keyspace else None + prepared_statement = PreparedStatement.from_message( + response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, + prepared_keyspace, + self._protocol_version, response.column_metadata, response.result_metadata_id, + self.cluster.column_encryption_policy) + prepared_statement.custom_payload = custom_payload + self.cluster.add_prepared(response.query_id, prepared_statement) + if self.cluster.prepare_on_all_hosts: + # prepare statement on all hosts + host = future._current_host + try: + self.prepare_on_all_nodes(future.message.query, host, future.message.keyspace) + except Exception: + log.exception("Error preparing query on all hosts:") + + return prepared_statement + + future._set_result_processor(_prepare_result_processor) + return future + - def _create_execute_response_future(self, query, parameters, trace, custom_payload, - timeout, execution_profile=EXEC_PROFILE_DEFAULT, - paging_state=None, host=None): + def _create_response_future(self, query, parameters, trace, custom_payload, + timeout, execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, host=None): """ Returns the ResponseFuture before calling send_request() on it """ prepared_statement = None @@ -3160,17 +3178,8 @@ def prepare(self, query, custom_payload=None, keyspace=None): `custom_payload` is a key value map to be passed along with the prepare message. See :ref:`custom_payload`. """ - future = self.prepare_async(query, custom_payload, keyspace, prepare_on_all_hosts=False) - response = future.result() - if self.cluster.prepare_on_all_hosts: - # prepare on all hosts in a synchronous way, not asynchronously - # as internally in prepare_async() (PrepareFuture) - host = future._current_host - try: - self.prepare_on_all_nodes(response.query_string, host, response.keyspace) - except Exception: - log.exception("Error preparing query on all hosts:") - return response + future = self.prepare_async(query, custom_payload, keyspace) + return future.result() def prepare_on_all_nodes(self, query, excluded_host, keyspace=None): """ @@ -4345,6 +4354,7 @@ class ResponseFuture(object): _col_types = None _final_exception = None _query_traces = None + _result_processor = None _callbacks = None _errbacks = None _current_host = None @@ -4976,10 +4986,20 @@ def result(self): """ self._event.wait() if self._final_result is not _NOT_SET: - return ResultSet(self, self._final_result) + if self._result_processor is not None: + return self._result_processor(self, self._final_result) + else: + return ResultSet(self, self._final_result) else: raise self._final_exception + def _set_result_processor(self, result_processor): + """ + Sets internal result processor which allows to control object + returned by :meth:`ResponseFuture.result()` method. + """ + self._result_processor = result_processor + def get_query_trace_ids(self): """ Returns the trace session ids for this future, if tracing was enabled (does not fetch trace data). @@ -5130,49 +5150,6 @@ def __str__(self): __repr__ = __str__ -class PrepareFuture(ResponseFuture): - _final_prepare_result = _NOT_SET - - def __init__(self, session, query, keyspace, custom_payload, timeout, prepare_on_all_hosts): - super().__init__(session, PrepareMessage(query=query, keyspace=keyspace), None, timeout) - self.query_string = query - self._prepare_on_all_hosts = prepare_on_all_hosts - self._keyspace = keyspace - self._custom_payload = custom_payload - - def _set_final_result(self, response): - session = self.session - cluster = session.cluster - prepared_statement = PreparedStatement.from_message( - response.query_id, response.bind_metadata, response.pk_indexes, cluster.metadata, self.query_string, - self._keyspace, session._protocol_version, response.column_metadata, response.result_metadata_id, - cluster.column_encryption_policy) - prepared_statement.custom_payload = response.custom_payload - cluster.add_prepared(response.query_id, prepared_statement) - self._final_prepare_result = prepared_statement - - if self._prepare_on_all_hosts: - # trigger asynchronous preparation of query on other C* nodes, - # we are on event loop thread, so do not execute those synchronously - session.submit( - session.prepare_on_all_nodes, - self.query_string, self._current_host, self._keyspace) - - super()._set_final_result(response) - - def result(self): - self._event.wait() - if self._final_prepare_result is not _NOT_SET: - return self._final_prepare_result - else: - raise self._final_exception - - def __str__(self): - result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result - return "" \ - % (self.query_string, self._req_id, result, self._final_exception, self.coordinator_host) - __repr__ = __str__ - class QueryExhausted(Exception): """ Raised when :meth:`.ResponseFuture.start_fetching_next_page()` is called and diff --git a/cassandra/query.py b/cassandra/query.py index be123c3135..e29c2a3113 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -1035,7 +1035,7 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): def _execute(self, query, parameters, time_spent, max_wait): timeout = (max_wait - time_spent) if max_wait is not None else None - future = self._session._create_execute_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) + future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) # in case the user switched the row factory, set it to namedtuple for this query future.row_factory = named_tuple_factory future.send_request() diff --git a/tests/integration/standard/test_prepared_statements.py b/tests/integration/standard/test_prepared_statements.py index 615d8cf0f8..76abc49edc 100644 --- a/tests/integration/standard/test_prepared_statements.py +++ b/tests/integration/standard/test_prepared_statements.py @@ -22,7 +22,7 @@ from cassandra import InvalidRequest, DriverException from cassandra import ConsistencyLevel, ProtocolVersion -from cassandra.cluster import PrepareFuture +from cassandra.cluster import ResponseFuture from cassandra.query import PreparedStatement, UNSET_VALUE from tests.integration import (get_server_versions, greaterthanorequalcass40, greaterthanorequaldse50, requirecassandra, BasicSharedKeyspaceUnitTestCase) @@ -152,7 +152,7 @@ def test_basic_async(self): """ INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?) """) - self.assertIsInstance(prepared_future, PrepareFuture) + self.assertIsInstance(prepared_future, ResponseFuture) prepared = prepared_future.result() self.assertIsInstance(prepared, PreparedStatement) @@ -163,7 +163,7 @@ def test_basic_async(self): """ SELECT * FROM cf0 WHERE a=? """) - self.assertIsInstance(prepared_future, PrepareFuture) + self.assertIsInstance(prepared_future, ResponseFuture) prepared = prepared_future.result() self.assertIsInstance(prepared, PreparedStatement) @@ -176,7 +176,7 @@ def test_basic_async(self): """ INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?) """) - self.assertIsInstance(prepared_future, PrepareFuture) + self.assertIsInstance(prepared_future, ResponseFuture) prepared = prepared_future.result() self.assertIsInstance(prepared, PreparedStatement) @@ -191,7 +191,7 @@ def test_basic_async(self): """ SELECT * FROM cf0 WHERE a=? """) - self.assertIsInstance(prepared_future, PrepareFuture) + self.assertIsInstance(prepared_future, ResponseFuture) prepared = prepared_future.result() self.assertIsInstance(prepared, PreparedStatement) diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index c6b4d8ebc7..44c605a48d 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -536,10 +536,9 @@ def test_prepare_async_on_all_hosts(self): self.addCleanup(clus.shutdown) session = clus.connect(wait_for_all_pools=True) - select_statement = session.prepare_async("SELECT k FROM test3rf.test WHERE k = ?").result() - time.sleep(1) # we have no way to know when prepared statements are asynchronously completed + select_statement = session.prepare_async("SELECT k FROM test3rf.test WHERE k = ? AND v = ? ALLOW FILTERING").result() for host in clus.metadata.all_hosts(): - session.execute(select_statement, (1, ), host=host) + session.execute(select_statement, (1, 1), host=host) self.assertEqual(0, self.mock_handler.get_message_count('debug', "Re-preparing")) def test_prepare_batch_statement(self):