diff --git a/Cargo.lock b/Cargo.lock index 8595ee0ba61f6..cafb9036b9f24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10358,6 +10358,7 @@ name = "risingwave_batch" version = "2.3.0-alpha" dependencies = [ "anyhow", + "assert_matches", "async-recursion", "async-trait", "either", diff --git a/e2e_test/webhook/check_1.slt.part b/e2e_test/webhook/check_1.slt.part index c207ba53afb00..b5b2888ff5ee4 100644 --- a/e2e_test/webhook/check_1.slt.part +++ b/e2e_test/webhook/check_1.slt.part @@ -19,6 +19,6 @@ select data ->> 'source', data->> 'auth_algo' from segment_hmac_sha1; segment hmac_sha1 query TT -select data ->> 'source', data->> 'auth_algo' from hubspot_sha256_v2; +select data ->> 'source', data->> 'auth_algo' from test_primary_key; ---- -hubspot sha256_v2 \ No newline at end of file +github hmac_sha1 diff --git a/e2e_test/webhook/check_2.slt.part b/e2e_test/webhook/check_2.slt.part index 0b6305701d882..7fbd8516ce560 100644 --- a/e2e_test/webhook/check_2.slt.part +++ b/e2e_test/webhook/check_2.slt.part @@ -26,4 +26,9 @@ query TT select data ->> 'source', data->> 'auth_algo' from hubspot_sha256_v2; ---- hubspot sha256_v2 -hubspot sha256_v2 \ No newline at end of file +hubspot sha256_v2 + +query TT +select data ->> 'source', data->> 'auth_algo' from test_primary_key; +---- +github hmac_sha1 diff --git a/e2e_test/webhook/check_3.slt.part b/e2e_test/webhook/check_3.slt.part index 0ad97c19f0880..4c75ef84f49c4 100644 --- a/e2e_test/webhook/check_3.slt.part +++ b/e2e_test/webhook/check_3.slt.part @@ -31,4 +31,9 @@ select data ->> 'source', data->> 'auth_algo' from hubspot_sha256_v2; ---- hubspot sha256_v2 hubspot sha256_v2 -hubspot sha256_v2 \ No newline at end of file +hubspot sha256_v2 + +query TT +select data ->> 'source', data->> 'auth_algo' from test_primary_key; +---- +github hmac_sha1 diff --git a/e2e_test/webhook/create_table.slt.part b/e2e_test/webhook/create_table.slt.part index a7b0eb19a48f5..fe806f5938f08 100644 --- a/e2e_test/webhook/create_table.slt.part +++ b/e2e_test/webhook/create_table.slt.part @@ -53,3 +53,16 @@ create table hubspot_sha256_v2 ( , 'UTF8') ), 'hex') ); + +statement ok +create table test_primary_key ( + data JSONB PRIMARY KEY +) WITH ( + connector = 'webhook', +) VALIDATE SECRET test_secret AS secure_compare( + headers->>'x-hub-signature', + 'sha1=' || encode(hmac(test_secret, data, 'sha1'), 'hex') +); + +statement error Adding/dropping a column of a table with webhook has not been implemented. 
+ALTER TABLE github_hmac_sha1 ADD COLUMN new_col int; diff --git a/e2e_test/webhook/drop_table.slt.part b/e2e_test/webhook/drop_table.slt.part index 32a2a40800b87..fd9cced253e09 100644 --- a/e2e_test/webhook/drop_table.slt.part +++ b/e2e_test/webhook/drop_table.slt.part @@ -1,3 +1,6 @@ +statement ok +DROP TABLE test_primary_key; + statement ok DROP TABLE hubspot_sha256_v2; diff --git a/e2e_test/webhook/sender.py b/e2e_test/webhook/sender.py index 446f674348444..43ce35bf9e1e1 100644 --- a/e2e_test/webhook/sender.py +++ b/e2e_test/webhook/sender.py @@ -62,6 +62,22 @@ def send_github_hmac_sha1(secret): send_webhook(url, headers, payload_json) +def send_test_primary_key(secret): + payload = message + payload['source'] = "github" + payload['auth_algo'] = "hmac_sha1" + url = SERVER_URL + "test_primary_key" + + payload_json = json.dumps(payload) + signature = generate_signature_hmac(secret, payload_json, 'sha1', "sha1=") + # Webhook message headers + headers = { + "Content-Type": "application/json", + "X-Hub-Signature": signature # Custom signature header + } + send_webhook(url, headers, payload_json) + + def send_github_hmac_sha256(secret): payload = message payload['source'] = "github" @@ -143,3 +159,6 @@ def send_hubspot_sha256_v2(secret): send_segment_hmac_sha1(secret) # hubspot send_hubspot_sha256_v2(secret) + + # ensure a table whose only column is the primary key still works as normal + send_test_primary_key(secret) diff --git a/proto/catalog.proto b/proto/catalog.proto index 341b1362b3d6c..f5d47d3fc9da2 100644 --- a/proto/catalog.proto +++ b/proto/catalog.proto @@ -100,6 +100,8 @@ message StreamSourceInfo { message WebhookSourceInfo { secret.SecretRef secret_ref = 1; expr.ExprNode signature_expr = 2; + // Whether to wait until the data is persisted in the storage layer before returning. Defaults to true. + bool wait_for_persistence = 3; } message Source { diff --git a/proto/task_service.proto b/proto/task_service.proto index cb14ee809d943..a9176202e966f 100644 --- a/proto/task_service.proto +++ b/proto/task_service.proto @@ -60,6 +60,35 @@ message GetDataResponse { data.DataChunk record_batch = 2; } +message FastInsertRequest { + // Id of the table to insert into. + uint32 table_id = 1; + // Version of the table. + uint64 table_version_id = 2; + repeated uint32 column_indices = 3; + data.DataChunk data_chunk = 4; + + // An optional field; it is `None` for tables with a user-defined pk. + // Otherwise, the `FastInsertExecutor` should add a hidden row-id column filled with + // NULL values, which will be populated by the streaming DML executor. + optional uint32 row_id_index = 5; + + // Used to assign the insert request to different worker nodes and DML channels. + uint32 request_id = 6; + bool wait_for_persistence = 7; + // TODO(kexiang): add support for default columns. plan_common.ExprContext expr_context is needed for it. +} + +message FastInsertResponse { + enum Status { + UNSPECIFIED = 0; + SUCCEEDED = 1; + DML_FAILED = 2; + } + Status status = 1; + string error_message = 2; +} + message ExecuteRequest { batch_plan.TaskId task_id = 1; batch_plan.PlanFragment plan = 2; @@ -73,6 +102,8 @@ service TaskService { // Cancel an already-died (self execution-failure, previous aborted, completed) task will still succeed. rpc CancelTask(CancelTaskRequest) returns (CancelTaskResponse); rpc Execute(ExecuteRequest) returns (stream GetDataResponse); + // A lightweight version of insert, only used for non-pgwire inserts such as those from webhooks and websockets.
+ rpc FastInsert(FastInsertRequest) returns (FastInsertResponse); } message GetDataRequest { diff --git a/src/batch/Cargo.toml b/src/batch/Cargo.toml index 56491055cfacd..b83adaa2104fe 100644 --- a/src/batch/Cargo.toml +++ b/src/batch/Cargo.toml @@ -9,6 +9,7 @@ repository = { workspace = true } [dependencies] anyhow = "1" +assert_matches = "1" async-recursion = "1" async-trait = "0.1" either = "1" diff --git a/src/batch/executors/src/executor/insert.rs b/src/batch/executors/src/executor/insert.rs index 3ff8cede9dc0f..9b0b0839e9964 100644 --- a/src/batch/executors/src/executor/insert.rs +++ b/src/batch/executors/src/executor/insert.rs @@ -387,7 +387,7 @@ mod tests { assert_eq!(*chunk.columns()[2], array); }); - assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(_)); + assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(..)); let epoch = u64::MAX; let full_range = (Bound::Unbounded, Bound::Unbounded); let store_content = store diff --git a/src/batch/src/executor/fast_insert.rs b/src/batch/src/executor/fast_insert.rs new file mode 100644 index 0000000000000..898fe370a6eba --- /dev/null +++ b/src/batch/src/executor/fast_insert.rs @@ -0,0 +1,252 @@ +// Copyright 2025 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::iter::repeat; +use std::sync::Arc; + +use itertools::Itertools; +use risingwave_common::array::{DataChunk, Op, SerialArray, StreamChunk}; +use risingwave_common::catalog::{Field, Schema, TableId, TableVersionId}; +use risingwave_common::transaction::transaction_id::TxnId; +use risingwave_common::types::DataType; +use risingwave_common::util::epoch::{Epoch, INVALID_EPOCH}; +use risingwave_dml::dml_manager::DmlManagerRef; +use risingwave_pb::task_service::FastInsertRequest; + +use crate::error::Result; + +/// A fast insert executor specially designed for non-pgwire inserts such as those from websockets and webhooks. +pub struct FastInsertExecutor { + /// Target table id.
+ table_id: TableId, + table_version_id: TableVersionId, + dml_manager: DmlManagerRef, + column_indices: Vec, + + row_id_index: Option, + txn_id: TxnId, + request_id: u32, +} + +impl FastInsertExecutor { + pub fn build( + dml_manager: DmlManagerRef, + insert_req: FastInsertRequest, + ) -> Result<(FastInsertExecutor, DataChunk)> { + let table_id = TableId::new(insert_req.table_id); + let column_indices = insert_req + .column_indices + .iter() + .map(|&i| i as usize) + .collect(); + let mut schema = Schema::new(vec![Field::unnamed(DataType::Jsonb)]); + schema.fields.push(Field::unnamed(DataType::Serial)); // row_id column + let data_chunk_pb = insert_req + .data_chunk + .expect("no data_chunk found in fast insert node"); + + Ok(( + FastInsertExecutor::new( + table_id, + insert_req.table_version_id, + dml_manager, + column_indices, + insert_req.row_id_index.as_ref().map(|index| *index as _), + insert_req.request_id, + ), + DataChunk::from_protobuf(&data_chunk_pb)?, + )) + } + + #[allow(clippy::too_many_arguments)] + pub fn new( + table_id: TableId, + table_version_id: TableVersionId, + dml_manager: DmlManagerRef, + column_indices: Vec, + row_id_index: Option, + request_id: u32, + ) -> Self { + let txn_id = dml_manager.gen_txn_id(); + Self { + table_id, + table_version_id, + dml_manager, + column_indices, + row_id_index, + txn_id, + request_id, + } + } +} + +impl FastInsertExecutor { + pub async fn do_execute( + self, + data_chunk_to_insert: DataChunk, + returning_epoch: bool, + ) -> Result { + let table_dml_handle = self + .dml_manager + .table_dml_handle(self.table_id, self.table_version_id)?; + // instead of session id, we use request id here to select a write handle. + let mut write_handle = table_dml_handle.write_handle(self.request_id, self.txn_id)?; + + write_handle.begin()?; + + // Transform the data chunk to a stream chunk, then write to the source. + // Return the returning chunk. + let write_txn_data = |chunk: DataChunk| async { + let cap = chunk.capacity(); + let (mut columns, vis) = chunk.into_parts(); + + let mut ordered_columns = self + .column_indices + .iter() + .enumerate() + .map(|(i, idx)| (*idx, columns[i].clone())) + .collect_vec(); + + ordered_columns.sort_unstable_by_key(|(idx, _)| *idx); + columns = ordered_columns + .into_iter() + .map(|(_, column)| column) + .collect_vec(); + + // If the user does not specify the primary key, then we need to add a column as the + // primary key. + if let Some(row_id_index) = self.row_id_index { + let row_id_col = SerialArray::from_iter(repeat(None).take(cap)); + columns.insert(row_id_index, Arc::new(row_id_col.into())) + } + + let stream_chunk = StreamChunk::with_visibility(vec![Op::Insert; cap], columns, vis); + + #[cfg(debug_assertions)] + table_dml_handle.check_chunk_schema(&stream_chunk); + + write_handle.write_chunk(stream_chunk).await?; + + Result::Ok(()) + }; + write_txn_data(data_chunk_to_insert).await?; + if returning_epoch { + write_handle.end_returning_epoch().await.map_err(Into::into) + } else { + write_handle.end().await?; + // the returned epoch is invalid and should not be used. 
+ Ok(Epoch(INVALID_EPOCH)) + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::ops::Bound; + + use assert_matches::assert_matches; + use futures::StreamExt; + use risingwave_common::array::{Array, JsonbArrayBuilder}; + use risingwave_common::catalog::{ColumnDesc, ColumnId, INITIAL_TABLE_VERSION_ID}; + use risingwave_common::transaction::transaction_message::TxnMsg; + use risingwave_common::types::JsonbVal; + use risingwave_dml::dml_manager::DmlManager; + use risingwave_storage::memory::MemoryStateStore; + use risingwave_storage::store::{ReadOptions, StateStoreReadExt}; + use serde_json::json; + + use super::*; + use crate::risingwave_common::array::ArrayBuilder; + use crate::risingwave_common::types::Scalar; + use crate::*; + + #[tokio::test] + async fn test_fast_insert() -> Result<()> { + let epoch = Epoch::now(); + let dml_manager = Arc::new(DmlManager::for_test()); + let store = MemoryStateStore::new(); + // Schema of the table + let mut schema = Schema::new(vec![Field::unnamed(DataType::Jsonb)]); + schema.fields.push(Field::unnamed(DataType::Serial)); // row_id column + + let row_id_index = Some(1); + + let mut builder = JsonbArrayBuilder::with_type(1, DataType::Jsonb); + + let mut header_map = HashMap::new(); + header_map.insert("data".to_owned(), "value1".to_owned()); + + let json_value = json!(header_map); + let jsonb_val = JsonbVal::from(json_value); + builder.append(Some(jsonb_val.as_scalar_ref())); + + // Use builder to obtain a single (List) column DataChunk + let data_chunk = DataChunk::new(vec![builder.finish().into_ref()], 1); + + // Create the table. + let table_id = TableId::new(0); + + // Create reader + let column_descs = schema + .fields + .iter() + .enumerate() + .map(|(i, field)| ColumnDesc::unnamed(ColumnId::new(i as _), field.data_type.clone())) + .collect_vec(); + // We must create a variable to hold this `Arc` here, or it will be dropped + // due to the `Weak` reference in `DmlManager`. + let reader = dml_manager + .register_reader(table_id, INITIAL_TABLE_VERSION_ID, &column_descs) + .unwrap(); + let mut reader = reader.stream_reader().into_stream(); + + // Insert + let insert_executor = Box::new(FastInsertExecutor::new( + table_id, + INITIAL_TABLE_VERSION_ID, + dml_manager, + vec![0], // Ignoring insertion order + row_id_index, + 0, + )); + let handle = tokio::spawn(async move { + let epoch_received = insert_executor.do_execute(data_chunk, true).await.unwrap(); + assert_eq!(epoch, epoch_received); + }); + + // Read + assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_)); + + assert_matches!(reader.next().await.unwrap()?, TxnMsg::Data(_, chunk) => { + assert_eq!(chunk.columns().len(),2); + let array = chunk.columns()[0].as_jsonb().iter().collect::>(); + assert_eq!(JsonbVal::from(array[0].unwrap()), jsonb_val); + }); + + assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(_, Some(epoch_notifier)) => { + epoch_notifier.send(epoch).unwrap(); + }); + let epoch = u64::MAX; + let full_range = (Bound::Unbounded, Bound::Unbounded); + let store_content = store + .scan(full_range, epoch, None, ReadOptions::default()) + .await?; + assert!(store_content.is_empty()); + + handle.await.unwrap(); + + Ok(()) + } +} diff --git a/src/batch/src/executor/mod.rs b/src/batch/src/executor/mod.rs index e66c6c6c08bd3..2c08961185319 100644 --- a/src/batch/src/executor/mod.rs +++ b/src/batch/src/executor/mod.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
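For orientation, here is a minimal sketch of how the new `FastInsertExecutor` is meant to be driven on the compute node, mirroring the `do_fast_insert` handler added to the task service further below. The `run_fast_insert` helper and its `anyhow` error handling are illustrative only and not part of this patch; the `build`/`do_execute` signatures are the ones defined in the new file above.

```rust
use risingwave_batch::executor::FastInsertExecutor;
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_pb::task_service::FastInsertRequest;

// Illustrative helper (not in this patch): decode the request into an executor plus
// the chunk to insert, then run the DML write.
async fn run_fast_insert(
    dml_manager: DmlManagerRef,
    request: FastInsertRequest,
) -> anyhow::Result<()> {
    let wait_for_persistence = request.wait_for_persistence;
    // `build` splits the request into the executor and the protobuf-decoded data chunk.
    let (executor, chunk) = FastInsertExecutor::build(dml_manager, request)?;
    // With `returning_epoch = true`, the returned epoch is the barrier epoch in which
    // the chunk was collected, so the caller can later wait for it to be committed.
    let _epoch = executor.do_execute(chunk, wait_for_persistence).await?;
    Ok(())
}
```

When `returning_epoch` is false, `do_execute` returns `Epoch(INVALID_EPOCH)` and the caller must not use it.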
+mod fast_insert; mod managed; pub mod test_utils; @@ -20,6 +21,7 @@ use std::sync::Arc; use anyhow::Context; use async_recursion::async_recursion; +pub use fast_insert::*; use futures::future::BoxFuture; use futures::stream::BoxStream; pub use managed::*; diff --git a/src/batch/src/rpc/service/task_service.rs b/src/batch/src/rpc/service/task_service.rs index 7c601cf1cb031..4a08e56ea4fe1 100644 --- a/src/batch/src/rpc/service/task_service.rs +++ b/src/batch/src/rpc/service/task_service.rs @@ -18,13 +18,16 @@ use risingwave_common::util::tracing::TracingContext; use risingwave_pb::batch_plan::TaskOutputId; use risingwave_pb::task_service::task_service_server::TaskService; use risingwave_pb::task_service::{ - CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, GetDataResponse, - TaskInfoResponse, + fast_insert_response, CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, + FastInsertRequest, FastInsertResponse, GetDataResponse, TaskInfoResponse, }; +use risingwave_storage::dispatch_state_store; use thiserror_ext::AsReport; use tokio_stream::wrappers::ReceiverStream; use tonic::{Request, Response, Status}; +use crate::error::BatchError; +use crate::executor::FastInsertExecutor; use crate::rpc::service::exchange::GrpcExchangeWriter; use crate::task::{ BatchEnvironment, BatchManager, BatchTaskExecution, ComputeNodeContext, StateReporter, @@ -118,6 +121,31 @@ impl TaskService for BatchServiceImpl { let mgr = self.mgr.clone(); BatchServiceImpl::get_execute_stream(env, mgr, req).await } + + #[cfg_attr(coverage, coverage(off))] + async fn fast_insert( + &self, + request: Request, + ) -> Result, Status> { + let req = request.into_inner(); + let res = self.do_fast_insert(req).await; + match res { + Ok(_) => Ok(Response::new(FastInsertResponse { + status: fast_insert_response::Status::Succeeded.into(), + error_message: "".to_owned(), + })), + Err(e) => match e { + BatchError::Dml(e) => Ok(Response::new(FastInsertResponse { + status: fast_insert_response::Status::DmlFailed.into(), + error_message: format!("{}", e.as_report()), + })), + _ => { + error!(error = %e.as_report(), "failed to fast insert"); + Err(e.into()) + } + }, + } + } } impl BatchServiceImpl { @@ -185,4 +213,33 @@ impl BatchServiceImpl { }); Ok(Response::new(ReceiverStream::new(rx))) } + + async fn do_fast_insert(&self, insert_req: FastInsertRequest) -> Result<(), BatchError> { + let table_id = insert_req.table_id; + let wait_for_persistence = insert_req.wait_for_persistence; + let (executor, data_chunk) = + FastInsertExecutor::build(self.env.dml_manager_ref(), insert_req)?; + let epoch = executor + .do_execute(data_chunk, wait_for_persistence) + .await?; + if wait_for_persistence { + dispatch_state_store!(self.env.state_store(), store, { + use risingwave_common::catalog::TableId; + use risingwave_hummock_sdk::HummockReadEpoch; + use risingwave_storage::store::TryWaitEpochOptions; + use risingwave_storage::StateStore; + + store + .try_wait_epoch( + HummockReadEpoch::Committed(epoch.0), + TryWaitEpochOptions { + table_id: TableId::new(table_id), + }, + ) + .await + .map_err(BatchError::from)?; + }); + } + Ok(()) + } } diff --git a/src/common/src/transaction/transaction_message.rs b/src/common/src/transaction/transaction_message.rs index 540a25f56ffc1..d4f92fddfb52a 100644 --- a/src/common/src/transaction/transaction_message.rs +++ b/src/common/src/transaction/transaction_message.rs @@ -13,16 +13,18 @@ // limitations under the License. 
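The handler above maps DML failures to a `DML_FAILED` status rather than a gRPC error, so callers are expected to inspect the response status themselves. A hedged sketch of the calling side (`call_fast_insert` is made up for illustration; `ComputeClient::fast_insert` is the wrapper added to the RPC client later in this patch):

```rust
use anyhow::{anyhow, Result};
use risingwave_pb::task_service::{fast_insert_response, FastInsertRequest, FastInsertResponse};
use risingwave_rpc_client::ComputeClient;

// Illustrative only: send the prepared request to the chosen compute node and treat
// anything other than SUCCEEDED as an error carrying the server-side message.
async fn call_fast_insert(client: &ComputeClient, request: FastInsertRequest) -> Result<()> {
    let response: FastInsertResponse = client.fast_insert(request).await?;
    if response.status == fast_insert_response::Status::Succeeded as i32 {
        Ok(())
    } else {
        Err(anyhow!("fast insert failed: {}", response.error_message))
    }
}
```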
use enum_as_inner::EnumAsInner; +use tokio::sync::oneshot; use crate::array::StreamChunk; use crate::transaction::transaction_id::TxnId; use crate::transaction::transaction_message::TxnMsg::{Begin, Data, End, Rollback}; +use crate::util::epoch::Epoch; #[derive(Debug, EnumAsInner)] pub enum TxnMsg { Begin(TxnId), Data(TxnId, StreamChunk), - End(TxnId), + End(TxnId, Option>), Rollback(TxnId), } @@ -31,14 +33,14 @@ impl TxnMsg { match self { Begin(txn_id) => *txn_id, Data(txn_id, _) => *txn_id, - End(txn_id) => *txn_id, + End(txn_id, _) => *txn_id, Rollback(txn_id) => *txn_id, } } pub fn as_stream_chunk(&self) -> Option<&StreamChunk> { match self { - Begin(_) | End(_) | Rollback(_) => None, + Begin(_) | End(..) | Rollback(_) => None, Data(_, chunk) => Some(chunk), } } diff --git a/src/dml/src/table.rs b/src/dml/src/table.rs index 0b9c9d398296c..d8b8548b95e79 100644 --- a/src/dml/src/table.rs +++ b/src/dml/src/table.rs @@ -20,6 +20,7 @@ use risingwave_common::array::StreamChunk; use risingwave_common::catalog::ColumnDesc; use risingwave_common::transaction::transaction_id::TxnId; use risingwave_common::transaction::transaction_message::TxnMsg; +use risingwave_common::util::epoch::Epoch; use tokio::sync::oneshot; use crate::error::{DmlError, Result}; @@ -189,11 +190,27 @@ impl WriteHandle { assert_eq!(self.txn_state, TxnState::Begin); self.txn_state = TxnState::Committed; // Await the notifier. - let notifier = self.write_txn_control_msg(TxnMsg::End(self.txn_id))?; + let notifier = self.write_txn_control_msg(TxnMsg::End(self.txn_id, None))?; notifier.await.map_err(|_| DmlError::ReaderClosed)?; Ok(()) } + pub async fn end_returning_epoch(mut self) -> Result { + assert_eq!(self.txn_state, TxnState::Begin); + self.txn_state = TxnState::Committed; + // Await the notifier. + let (epoch_notifier_tx, epoch_notifier_rx) = oneshot::channel(); + let notifier = self.write_txn_control_msg_returning_epoch(TxnMsg::End( + self.txn_id, + Some(epoch_notifier_tx), + ))?; + notifier.await.map_err(|_| DmlError::ReaderClosed)?; + let epoch = epoch_notifier_rx + .await + .map_err(|_| DmlError::ReaderClosed)?; + Ok(epoch) + } + pub fn rollback(mut self) -> Result> { self.rollback_inner() } @@ -234,6 +251,21 @@ impl WriteHandle { Err(_) => Err(DmlError::ReaderClosed), } } + + fn write_txn_control_msg_returning_epoch( + &self, + txn_msg: TxnMsg, + ) -> Result> { + assert_eq!(self.txn_id, txn_msg.txn_id()); + let (notifier_tx, notifier_rx) = oneshot::channel(); + match self.tx.send_immediate(txn_msg, notifier_tx) { + Ok(_) => Ok(notifier_rx), + + // It's possible that the source executor is scaled in or migrated, so the channel + // is closed. To guarantee the transactional atomicity, bail out. + Err(_) => Err(DmlError::ReaderClosed), + } + } } /// [`TableStreamReader`] reads changes from a certain table continuously. @@ -252,7 +284,7 @@ impl TableStreamReader { while let Some((txn_msg, notifier)) = self.rx.recv().await { // Notify about that we've taken the chunk. match txn_msg { - TxnMsg::Begin(_) | TxnMsg::End(_) | TxnMsg::Rollback(_) => { + TxnMsg::Begin(_) | TxnMsg::End(..) | TxnMsg::Rollback(_) => { _ = notifier.send(0); } TxnMsg::Data(_, chunk) => { @@ -268,7 +300,7 @@ impl TableStreamReader { while let Some((txn_msg, notifier)) = self.rx.recv().await { // Notify about that we've taken the chunk. match &txn_msg { - TxnMsg::Begin(_) | TxnMsg::End(_) | TxnMsg::Rollback(_) => { + TxnMsg::Begin(_) | TxnMsg::End(..) 
| TxnMsg::Rollback(_) => { _ = notifier.send(0); yield txn_msg; } @@ -343,7 +375,7 @@ mod tests { write_handle.end().await.unwrap(); }); - assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(_)); + assert_matches!(reader.next().await.unwrap()?, TxnMsg::End(..)); Ok(()) } diff --git a/src/dml/src/txn_channel.rs b/src/dml/src/txn_channel.rs index 35b8abf73e170..30eb457c3eeb1 100644 --- a/src/dml/src/txn_channel.rs +++ b/src/dml/src/txn_channel.rs @@ -99,7 +99,7 @@ impl Sender { .forget(); Some(PermitValue(card as _)) } - TxnMsg::Begin(_) | TxnMsg::Rollback(_) | TxnMsg::End(_) => None, + TxnMsg::Begin(_) | TxnMsg::Rollback(_) | TxnMsg::End(..) => None, }; self.tx diff --git a/src/frontend/src/handler/alter_table_column.rs b/src/frontend/src/handler/alter_table_column.rs index ceaac06e77803..1f408f987b395 100644 --- a/src/frontend/src/handler/alter_table_column.rs +++ b/src/frontend/src/handler/alter_table_column.rs @@ -258,6 +258,12 @@ pub async fn handle_alter_table_column( ))); } + if original_catalog.webhook_info.is_some() { + return Err(RwError::from(ErrorCode::BindError( + "Adding/dropping a column of a table with webhook has not been implemented.".to_owned(), + ))); + } + // Retrieve the original table definition and parse it to AST. let mut definition = original_catalog.create_sql_ast_purified()?; let Statement::CreateTable { columns, .. } = &mut definition else { diff --git a/src/frontend/src/handler/create_table.rs b/src/frontend/src/handler/create_table.rs index 4c08583deb39b..e1d7f79b2855e 100644 --- a/src/frontend/src/handler/create_table.rs +++ b/src/frontend/src/handler/create_table.rs @@ -2021,6 +2021,7 @@ fn bind_webhook_info( let WebhookSourceInfo { secret_ref, signature_expr, + wait_for_persistence, } = webhook_info; // validate secret_ref @@ -2058,6 +2059,7 @@ fn bind_webhook_info( let pb_webhook_info = PbWebhookSourceInfo { secret_ref: Some(pb_secret_ref), signature_expr: Some(expr.to_expr_proto()), + wait_for_persistence, }; Ok(pb_webhook_info) diff --git a/src/frontend/src/scheduler/fast_insert.rs b/src/frontend/src/scheduler/fast_insert.rs new file mode 100644 index 0000000000000..d2e0080b75ab0 --- /dev/null +++ b/src/frontend/src/scheduler/fast_insert.rs @@ -0,0 +1,81 @@ +// Copyright 2025 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
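The `TxnMsg::End` and `end_returning_epoch` changes above boil down to a small oneshot handshake: the writer attaches an epoch sender to the `End` message, and the DML executor (changed further below in this patch) answers with the epoch of the barrier under which the transaction was collected. A self-contained sketch of the pattern, using simplified stand-in types (`End` and plain `u64` epochs) rather than the real `TxnMsg`/`Epoch`:

```rust
use tokio::sync::oneshot;

// Simplified stand-in for `TxnMsg::End(TxnId, Option<oneshot::Sender<Epoch>>)`.
struct End {
    txn_id: u64,
    epoch_tx: Option<oneshot::Sender<u64>>,
}

#[tokio::main]
async fn main() {
    let (epoch_tx, epoch_rx) = oneshot::channel();

    // Writer side (`end_returning_epoch`): attach the notifier to the End message.
    let msg = End { txn_id: 1, epoch_tx: Some(epoch_tx) };

    // Reader side (the DML executor): on End, reply with the current barrier epoch.
    let current_barrier_epoch = 42_u64;
    let End { txn_id: _, epoch_tx } = msg;
    if let Some(tx) = epoch_tx {
        let _ = tx.send(current_barrier_epoch);
    }

    // Writer side resumes once the epoch arrives and can then wait for that epoch
    // to be committed in storage (`try_wait_epoch` on the compute node).
    assert_eq!(epoch_rx.await.unwrap(), 42);
}
```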
+ +use anyhow::anyhow; +use itertools::Itertools; +use risingwave_batch::error::BatchError; +use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector; +use risingwave_common::hash::WorkerSlotMapping; +use risingwave_pb::common::WorkerNode; +use risingwave_rpc_client::ComputeClient; + +use crate::catalog::TableId; +use crate::scheduler::{SchedulerError, SchedulerResult}; +use crate::session::FrontendEnv; + +pub async fn choose_fast_insert_client( + table_id: &TableId, + frontend_env: &FrontendEnv, + request_id: u32, +) -> SchedulerResult { + let worker = choose_worker(table_id, frontend_env, request_id)?; + let client = frontend_env.client_pool().get(&worker).await?; + Ok(client) +} + +fn get_table_dml_vnode_mapping( + table_id: &TableId, + frontend_env: &FrontendEnv, + worker_node_manager: &WorkerNodeSelector, +) -> SchedulerResult { + let guard = frontend_env.catalog_reader().read_guard(); + + let table = guard + .get_any_table_by_id(table_id) + .map_err(|e| SchedulerError::Internal(anyhow!(e)))?; + + let fragment_id = match table.dml_fragment_id.as_ref() { + Some(dml_fragment_id) => dml_fragment_id, + // Backward compatibility for those table without `dml_fragment_id`. + None => &table.fragment_id, + }; + + worker_node_manager + .manager + .get_streaming_fragment_mapping(fragment_id) + .map_err(|e| e.into()) +} + +fn choose_worker( + table_id: &TableId, + frontend_env: &FrontendEnv, + request_id: u32, +) -> SchedulerResult { + let worker_node_manager = + WorkerNodeSelector::new(frontend_env.worker_node_manager_ref(), false); + + // dml should use streaming vnode mapping + let vnode_mapping = get_table_dml_vnode_mapping(table_id, frontend_env, &worker_node_manager)?; + let worker_node = { + let worker_ids = vnode_mapping.iter_unique().collect_vec(); + let candidates = worker_node_manager + .manager + .get_workers_by_worker_slot_ids(&worker_ids)?; + if candidates.is_empty() { + return Err(BatchError::EmptyWorkerNodes.into()); + } + candidates[request_id as usize % candidates.len()].clone() + }; + Ok(worker_node) +} diff --git a/src/frontend/src/scheduler/mod.rs b/src/frontend/src/scheduler/mod.rs index d81510e8d82e9..f3364e78183f9 100644 --- a/src/frontend/src/scheduler/mod.rs +++ b/src/frontend/src/scheduler/mod.rs @@ -32,6 +32,8 @@ mod snapshot; pub use snapshot::*; mod local; pub use local::*; +mod fast_insert; +pub use fast_insert::*; use crate::scheduler::task_context::FrontendBatchTaskContext; diff --git a/src/frontend/src/webhook/mod.rs b/src/frontend/src/webhook/mod.rs index 17bd803cd6bbd..3b12e7e4fa0ec 100644 --- a/src/frontend/src/webhook/mod.rs +++ b/src/frontend/src/webhook/mod.rs @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
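`choose_worker` above spreads fast inserts across the compute nodes hosting the table's DML fragment by taking `request_id` modulo the candidate list, and the webhook service below feeds it from a per-process `AtomicU32` counter. A toy sketch of that round-robin behaviour (the `pick` helper and worker names are made up):

```rust
use std::sync::atomic::{AtomicU32, Ordering};

// Illustrative stand-in: `candidates` plays the role of the workers returned by
// `get_workers_by_worker_slot_ids`, and `counter` the webhook service's request counter.
fn pick<'a>(candidates: &[&'a str], counter: &AtomicU32) -> &'a str {
    let request_id = counter.fetch_add(1, Ordering::Relaxed);
    candidates[request_id as usize % candidates.len()]
}

fn main() {
    let counter = AtomicU32::new(0);
    let workers = ["cn-1", "cn-2", "cn-3"];
    // Consecutive webhook requests round-robin over the candidate compute nodes,
    // so no single DML channel becomes a hot spot.
    assert_eq!(pick(&workers, &counter), "cn-1");
    assert_eq!(pick(&workers, &counter), "cn-2");
    assert_eq!(pick(&workers, &counter), "cn-3");
    assert_eq!(pick(&workers, &counter), "cn-1");
}
```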
-use std::net::{IpAddr, SocketAddr}; +use std::net::SocketAddr; +use std::sync::atomic::AtomicU32; use std::sync::Arc; use anyhow::{anyhow, Context}; @@ -21,19 +22,20 @@ use axum::extract::{Extension, Path}; use axum::http::{HeaderMap, Method, StatusCode}; use axum::routing::post; use axum::Router; -use pgwire::net::Address; -use pgwire::pg_server::SessionManager; +use risingwave_common::array::{Array, ArrayBuilder, DataChunk}; use risingwave_common::secret::LocalSecretManager; -use risingwave_sqlparser::ast::{Expr, ObjectName}; +use risingwave_common::types::{DataType, JsonbVal, Scalar}; +use risingwave_pb::catalog::WebhookSourceInfo; +use risingwave_pb::task_service::{FastInsertRequest, FastInsertResponse}; use tokio::net::TcpListener; use tower::ServiceBuilder; use tower_http::add_extension::AddExtensionLayer; use tower_http::compression::CompressionLayer; use tower_http::cors::{self, CorsLayer}; -use crate::handler::handle; use crate::webhook::utils::{err, Result}; mod utils; +use risingwave_rpc_client::ComputeClient; pub type Service = Arc; @@ -41,71 +43,50 @@ pub type Service = Arc; const USER: &str = "root"; #[derive(Clone)] +pub struct FastInsertContext { + pub webhook_source_info: WebhookSourceInfo, + pub fast_insert_request: FastInsertRequest, + pub compute_client: ComputeClient, +} + pub struct WebhookService { webhook_addr: SocketAddr, + counter: AtomicU32, } pub(super) mod handlers { - use std::net::Ipv4Addr; - + use jsonbb::Value; + use risingwave_common::array::JsonbArrayBuilder; + use risingwave_common::session_config::SearchPath; use risingwave_pb::catalog::WebhookSourceInfo; - use risingwave_sqlparser::ast::{Query, SetExpr, Statement, Value, Values}; + use risingwave_pb::task_service::fast_insert_response; use utils::{header_map_to_json, verify_signature}; use super::*; use crate::catalog::root_catalog::SchemaPath; + use crate::scheduler::choose_fast_insert_client; use crate::session::SESSION_MANAGER; pub async fn handle_post_request( - Extension(_srv): Extension, + Extension(srv): Extension, headers: HeaderMap, Path((database, schema, table)): Path<(String, String, String)>, body: Bytes, ) -> Result<()> { - let session_mgr = SESSION_MANAGER - .get() - .expect("session manager has been initialized"); - - // Can be any address, we use the port of meta to indicate that it's a internal request. 
- let dummy_addr = Address::Tcp(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 5691)); - - // TODO(kexiang): optimize this - // get a session object for the corresponding database - let session = session_mgr - .connect(database.as_str(), USER, Arc::new(dummy_addr)) - .map_err(|e| { - err( - anyhow!(e).context(format!( - "Failed to create session for database `{}` with user `{}`", - database, USER - )), - StatusCode::UNAUTHORIZED, - ) - })?; + let request_id = srv + .counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let FastInsertContext { + webhook_source_info, + mut fast_insert_request, + compute_client, + } = acquire_table_info(request_id, &database, &schema, &table).await?; let WebhookSourceInfo { - secret_ref, signature_expr, - } = { - let search_path = session.config().search_path(); - let schema_path = SchemaPath::new(Some(schema.as_str()), &search_path, USER); - - let reader = session.env().catalog_reader().read_guard(); - let (table_catalog, _schema) = reader - .get_any_table_by_name(database.as_str(), schema_path, &table) - .map_err(|e| err(e, StatusCode::NOT_FOUND))?; - - table_catalog - .webhook_info - .as_ref() - .ok_or_else(|| { - err( - anyhow!("Table `{}` is not with webhook source", table), - StatusCode::FORBIDDEN, - ) - })? - .clone() - }; + secret_ref, + wait_for_persistence: _, + } = webhook_source_info; let secret_string = LocalSecretManager::global() .fill_secret(secret_ref.unwrap()) @@ -114,6 +95,7 @@ pub(super) mod handlers { // Once limitation here is that the key is no longer case-insensitive, users must user the lowercase key when defining the webhook source table. let headers_jsonb = header_map_to_json(&headers); + // verify the signature let is_valid = verify_signature( headers_jsonb, secret_string.as_str(), @@ -129,45 +111,114 @@ pub(super) mod handlers { )); } - let payload = String::from_utf8(body.to_vec()).map_err(|e| { + // Use builder to obtain a single column & single row DataChunk + let mut builder = JsonbArrayBuilder::with_type(1, DataType::Jsonb); + let json_value = Value::from_text(&body).map_err(|e| { err( anyhow!(e).context("Failed to parse body"), StatusCode::UNPROCESSABLE_ENTITY, ) })?; + let jsonb_val = JsonbVal::from(json_value); + builder.append(Some(jsonb_val.as_scalar_ref())); + let data_chunk = DataChunk::new(vec![builder.finish().into_ref()], 1); + + // fill the data_chunk + fast_insert_request.data_chunk = Some(data_chunk.to_protobuf()); + // execute on the compute node + let res = execute(fast_insert_request, compute_client).await?; + + if res.status == fast_insert_response::Status::Succeeded as i32 { + Ok(()) + } else { + Err(err( + anyhow!("Failed to fast insert: {}", res.error_message), + StatusCode::INTERNAL_SERVER_ERROR, + )) + } + } - let insert_stmt = Statement::Insert { - table_name: ObjectName::from(vec![table.as_str().into()]), - columns: vec![], - source: Box::new(Query { - with: None, - body: SetExpr::Values(Values(vec![vec![Expr::Value(Value::SingleQuotedString( - payload, - ))]])), - order_by: vec![], - limit: None, - offset: None, - fetch: None, - }), - returning: vec![], + async fn acquire_table_info( + request_id: u32, + database: &String, + schema: &String, + table: &String, + ) -> Result { + let session_mgr = SESSION_MANAGER + .get() + .expect("session manager has been initialized"); + + let frontend_env = session_mgr.env(); + + let search_path = SearchPath::default(); + let schema_path = SchemaPath::new(Some(schema.as_str()), &search_path, USER); + + let (webhook_source_info, table_id, 
version_id, row_id_index) = { + let reader = frontend_env.catalog_reader().read_guard(); + let (table_catalog, _schema) = reader + .get_any_table_by_name(database.as_str(), schema_path, table) + .map_err(|e| err(e, StatusCode::NOT_FOUND))?; + + let webhook_source_info = table_catalog + .webhook_info + .as_ref() + .ok_or_else(|| { + err( + anyhow!("Table `{}` is not with webhook source", table), + StatusCode::FORBIDDEN, + ) + })? + .clone(); + ( + webhook_source_info, + table_catalog.id(), + table_catalog.version_id().expect("table must be versioned"), + table_catalog.row_id_index.map(|idx| idx as u32), + ) }; - let _rsp = handle(session, insert_stmt, Arc::from(""), vec![]) + let fast_insert_request = FastInsertRequest { + table_id: table_id.table_id, + table_version_id: version_id, + column_indices: vec![0], + // leave the data_chunk empty for now + data_chunk: None, + row_id_index, + request_id, + wait_for_persistence: webhook_source_info.wait_for_persistence, + }; + + let compute_client = choose_fast_insert_client(&table_id, frontend_env, request_id) .await - .map_err(|e| { - err( - anyhow!(e).context("Failed to insert into target table"), - StatusCode::INTERNAL_SERVER_ERROR, - ) - })?; + .unwrap(); - Ok(()) + Ok(FastInsertContext { + webhook_source_info, + fast_insert_request, + compute_client, + }) + } + + async fn execute( + request: FastInsertRequest, + client: ComputeClient, + ) -> Result { + let response = client.fast_insert(request).await.map_err(|e| { + err( + anyhow!(e).context("Failed to execute on compute node"), + StatusCode::INTERNAL_SERVER_ERROR, + ) + })?; + Ok(response) } } impl WebhookService { pub fn new(webhook_addr: SocketAddr) -> Self { - Self { webhook_addr } + Self { + webhook_addr, + counter: AtomicU32::new(0), + } } pub async fn serve(self) -> anyhow::Result<()> { diff --git a/src/rpc_client/src/compute_client.rs b/src/rpc_client/src/compute_client.rs index 38f9ae83270b3..e8eaf1dac2ec1 100644 --- a/src/rpc_client/src/compute_client.rs +++ b/src/rpc_client/src/compute_client.rs @@ -40,8 +40,8 @@ use risingwave_pb::task_service::exchange_service_client::ExchangeServiceClient; use risingwave_pb::task_service::task_service_client::TaskServiceClient; use risingwave_pb::task_service::{ permits, CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, - GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits, - TaskInfoResponse, + FastInsertRequest, FastInsertResponse, GetDataRequest, GetDataResponse, GetStreamRequest, + GetStreamResponse, PbPermits, TaskInfoResponse, }; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -210,6 +210,16 @@ impl ComputeClient { .into_inner()) } + pub async fn fast_insert(&self, req: FastInsertRequest) -> Result { + Ok(self + .task_client + .to_owned() + .fast_insert(req) + .await + .map_err(RpcError::from_compute_status)? 
+ .into_inner()) + } + pub async fn stack_trace(&self) -> Result { Ok(self .monitor_client diff --git a/src/sqlparser/src/ast/ddl.rs b/src/sqlparser/src/ast/ddl.rs index 5efbc2003cfb1..eaecb276f1218 100644 --- a/src/sqlparser/src/ast/ddl.rs +++ b/src/sqlparser/src/ast/ddl.rs @@ -874,4 +874,5 @@ impl fmt::Display for ReferentialAction { pub struct WebhookSourceInfo { pub secret_ref: SecretRefValue, pub signature_expr: Expr, + pub wait_for_persistence: bool, } diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 843dd566212bf..e3d52c1c73521 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -40,6 +40,8 @@ use crate::{impl_parse_to, parser_v2}; pub(crate) const UPSTREAM_SOURCE_KEY: &str = "connector"; pub(crate) const WEBHOOK_CONNECTOR: &str = "webhook"; +const WEBHOOK_WAIT_FOR_PERSISTENCE: &str = "webhook.wait_for_persistence"; + #[derive(Debug, Clone, PartialEq)] pub enum ParserError { TokenizerError(String), @@ -2603,6 +2605,12 @@ impl Parser<'_> { parser_err!("VALIDATE is only supported for tables created with webhook source"); } + let wait_for_persistence = with_options + .iter() + .find(|&opt| opt.name.real_value() == WEBHOOK_WAIT_FOR_PERSISTENCE) + .map(|opt| opt.value.to_string().eq_ignore_ascii_case("true")) + .unwrap_or(true); + self.expect_keyword(Keyword::SECRET)?; let secret_ref = self.parse_secret_ref()?; if secret_ref.ref_as == SecretRefAsType::File { @@ -2615,6 +2623,7 @@ impl Parser<'_> { Some(WebhookSourceInfo { secret_ref, signature_expr, + wait_for_persistence, }) } else { None diff --git a/src/stream/src/executor/dml.rs b/src/stream/src/executor/dml.rs index 8dddf41aa59f5..ab31c7dd3a2eb 100644 --- a/src/stream/src/executor/dml.rs +++ b/src/stream/src/executor/dml.rs @@ -130,6 +130,8 @@ impl DmlExecutor { stream.pause_stream(); } + let mut epoch = barrier.get_curr_epoch(); + yield Message::Barrier(barrier); // Active transactions: txn_id -> TxnBuffer with transaction chunks. @@ -150,6 +152,7 @@ impl DmlExecutor { Either::Left(msg) => { // Stream messages. if let Message::Barrier(barrier) = &msg { + epoch = barrier.get_curr_epoch(); // We should handle barrier messages here to pause or resume the data from // DML. if let Some(mutation) = barrier.mutation.as_deref() { @@ -205,7 +208,10 @@ impl DmlExecutor { panic!("Transaction id collision txn_id = {}.", txn_id) }); } - TxnMsg::End(txn_id) => { + TxnMsg::End(txn_id, epoch_notifier) => { + if let Some(sender) = epoch_notifier { + let _ = sender.send(epoch); + } let mut txn_buffer = active_txn_map.remove(&txn_id) .unwrap_or_else(|| panic!("Receive an unexpected transaction end message. Active transaction map doesn't contain this transaction txn_id = {}.", txn_id)); @@ -312,13 +318,12 @@ async fn apply_dml_rate_limit( ) { #[for_await] for txn_msg in stream { - let txn_msg = txn_msg?; - match txn_msg { + match txn_msg? { TxnMsg::Begin(txn_id) => { yield TxnMsg::Begin(txn_id); } - TxnMsg::End(txn_id) => { - yield TxnMsg::End(txn_id); + TxnMsg::End(txn_id, epoch_notifier) => { + yield TxnMsg::End(txn_id, epoch_notifier); } TxnMsg::Rollback(txn_id) => { yield TxnMsg::Rollback(txn_id); @@ -330,7 +335,6 @@ async fn apply_dml_rate_limit( yield TxnMsg::Data(txn_id, chunk); continue; } - let rate_limit = loop { match rate_limiter.rate_limit() { RateLimit::Pause => rate_limiter.wait(0).await,
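Finally, for context on the new `webhook.wait_for_persistence` option parsed above: it defaults to `true` (the webhook request returns only once the row's epoch is committed), and setting it to `false` turns the endpoint into fire-and-forget. A hypothetical DDL example in the style of the e2e tests (table and secret names are made up; since the parser compares the option value's text against `"true"`, an unquoted boolean is probably the safest spelling):

```sql
create table wh_fire_and_forget (
  data JSONB
) WITH (
  connector = 'webhook',
  webhook.wait_for_persistence = false
) VALIDATE SECRET test_secret AS secure_compare(
  headers->>'x-hub-signature',
  'sha1=' || encode(hmac(test_secret, data, 'sha1'), 'hex')
);
```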