From ea7c1ada7f85608c4a973e0fea25048355492b23 Mon Sep 17 00:00:00 2001 From: Gio Gutierrez Date: Fri, 1 Mar 2024 21:34:07 -0500 Subject: [PATCH 01/12] feat(stream): Add mqtt connector --- Cargo.lock | 48 ++++++- ci/scripts/gen-integration-test-yaml.py | 3 +- .../mqtt-source/create_source.sql | 11 ++ integration_tests/mqtt-source/data_check | 1 + .../mqtt-source/docker-compose.yml | 45 +++++++ integration_tests/mqtt-source/query.sql | 8 ++ src/connector/Cargo.toml | 4 + src/connector/src/common.rs | 126 ++++++++++++++++++ src/connector/src/error.rs | 2 + src/connector/src/macros.rs | 1 + src/connector/src/source/mod.rs | 2 + .../src/source/mqtt/enumerator/mod.rs | 102 ++++++++++++++ src/connector/src/source/mqtt/mod.rs | 55 ++++++++ .../src/source/mqtt/source/message.rs | 48 +++++++ src/connector/src/source/mqtt/source/mod.rs | 20 +++ .../src/source/mqtt/source/reader.rs | 105 +++++++++++++++ src/connector/src/source/mqtt/split.rs | 50 +++++++ src/connector/with_options_sink.yaml | 33 +++++ src/connector/with_options_source.yaml | 36 +++++ src/frontend/src/handler/create_source.rs | 7 +- src/workspace-hack/Cargo.toml | 3 + 21 files changed, 705 insertions(+), 5 deletions(-) create mode 100644 integration_tests/mqtt-source/create_source.sql create mode 100644 integration_tests/mqtt-source/data_check create mode 100644 integration_tests/mqtt-source/docker-compose.yml create mode 100644 integration_tests/mqtt-source/query.sql create mode 100644 src/connector/src/source/mqtt/enumerator/mod.rs create mode 100644 src/connector/src/source/mqtt/mod.rs create mode 100644 src/connector/src/source/mqtt/source/message.rs create mode 100644 src/connector/src/source/mqtt/source/mod.rs create mode 100644 src/connector/src/source/mqtt/source/reader.rs create mode 100644 src/connector/src/source/mqtt/split.rs diff --git a/Cargo.lock b/Cargo.lock index b407640880163..0414976469cc1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -797,7 +797,7 @@ dependencies = [ "rustls", 
"rustls-native-certs", "rustls-pemfile", - "rustls-webpki", + "rustls-webpki 0.101.7", "serde", "serde_json", "serde_nanos", @@ -4052,6 +4052,7 @@ checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" dependencies = [ "futures-core", "futures-sink", + "nanorand", "pin-project", "spin 0.9.8", ] @@ -6348,6 +6349,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "034a0ad7deebf0c2abcf2435950a6666c3c15ea9d8fad0c0f48efa8a7f843fed" +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom", +] + [[package]] name = "native-tls" version = "0.2.11" @@ -9172,7 +9182,10 @@ dependencies = [ "risingwave_jni_core", "risingwave_pb", "risingwave_rpc_client", + "rumqttc", "rust_decimal", + "rustls-native-certs", + "rustls-pemfile", "rw_futures_util", "serde", "serde_derive", @@ -9189,6 +9202,7 @@ dependencies = [ "time", "tokio-postgres", "tokio-retry", + "tokio-rustls", "tokio-stream", "tokio-util", "tracing", @@ -10319,6 +10333,24 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rumqttc" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2433b134712bc17a6f85a35e06b901e6e8d0bb20b5367e1121e6fedc140c0ac" +dependencies = [ + "bytes", + "flume", + "futures", + "log", + "rustls-native-certs", + "rustls-pemfile", + "rustls-webpki 0.100.3", + "thiserror", + "tokio", + "tokio-rustls", +] + [[package]] name = "rust-embed" version = "8.1.0" @@ -10453,7 +10485,7 @@ checksum = "446e14c5cda4f3f30fe71863c34ec70f5ac79d6087097ad0bb433e1be5edf04c" dependencies = [ "log", "ring 0.17.5", - "rustls-webpki", + "rustls-webpki 0.101.7", "sct", ] @@ -10478,6 +10510,16 @@ dependencies = [ "base64 0.21.7", ] +[[package]] +name = "rustls-webpki" +version = "0.100.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6a5fc258f1c1276dfe3016516945546e2d5383911efc0fc4f1cdc5df3a4ae3" +dependencies = [ + "ring 0.16.20", + "untrusted 0.7.1", +] + [[package]] name = "rustls-webpki" version = "0.101.7" @@ -14022,6 +14064,7 @@ dependencies = [ "either", "fail", "flate2", + "flume", "frunk_core", "futures", "futures-channel", @@ -14032,6 +14075,7 @@ dependencies = [ "futures-task", "futures-util", "generic-array", + "getrandom", "governor", "hashbrown 0.13.2", "hashbrown 0.14.0", diff --git a/ci/scripts/gen-integration-test-yaml.py b/ci/scripts/gen-integration-test-yaml.py index fab33cbf6944d..90aa83b54ad98 100644 --- a/ci/scripts/gen-integration-test-yaml.py +++ b/ci/scripts/gen-integration-test-yaml.py @@ -34,7 +34,8 @@ 'big-query-sink': ['json'], 'mindsdb': ['json'], 'vector': ['json'], - 'nats': ['json', 'protobuf'], + 'nats': ['json'], + 'mqtt-source': ['json'], 'doris-sink': ['json'], 'starrocks-sink': ['json'], 'deltalake-sink': ['json'], diff --git a/integration_tests/mqtt-source/create_source.sql b/integration_tests/mqtt-source/create_source.sql new file mode 100644 index 0000000000000..6c6344c20cda7 --- /dev/null +++ b/integration_tests/mqtt-source/create_source.sql @@ -0,0 +1,11 @@ + +CREATE TABLE mqtt_source_table +( + id integer, + name varchar, +) +WITH ( + connector='mqtt', + host='mqtt-server', + topic= 'test' +) FORMAT PLAIN ENCODE JSON; diff --git a/integration_tests/mqtt-source/data_check b/integration_tests/mqtt-source/data_check new file mode 100644 index 0000000000000..fcaf3aca97ed0 --- /dev/null +++ b/integration_tests/mqtt-source/data_check @@ -0,0 +1 @@ +mqtt_source_table \ No newline at end of file diff --git a/integration_tests/mqtt-source/docker-compose.yml b/integration_tests/mqtt-source/docker-compose.yml new file mode 100644 index 0000000000000..87969f8ad9044 --- /dev/null +++ b/integration_tests/mqtt-source/docker-compose.yml @@ -0,0 +1,45 @@ +--- +version: "3" +services: + 
risingwave-standalone: + extends: + file: ../../docker/docker-compose.yml + service: risingwave-standalone + mqtt-server: + image: emqx/emqx:5.2.1 + ports: + - 1883:1883 + etcd-0: + extends: + file: ../../docker/docker-compose.yml + service: etcd-0 + grafana-0: + extends: + file: ../../docker/docker-compose.yml + service: grafana-0 + minio-0: + extends: + file: ../../docker/docker-compose.yml + service: minio-0 + prometheus-0: + extends: + file: ../../docker/docker-compose.yml + service: prometheus-0 + message_queue: + extends: + file: ../../docker/docker-compose.yml + service: message_queue +volumes: + compute-node-0: + external: false + etcd-0: + external: false + grafana-0: + external: false + minio-0: + external: false + prometheus-0: + external: false + message_queue: + external: false +name: risingwave-compose diff --git a/integration_tests/mqtt-source/query.sql b/integration_tests/mqtt-source/query.sql new file mode 100644 index 0000000000000..5a3abc3b555ce --- /dev/null +++ b/integration_tests/mqtt-source/query.sql @@ -0,0 +1,8 @@ +select + * +from + mqtt_source_table +order by + id +LIMIT + 10; \ No newline at end of file diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 77f2c1374dc92..e997e71201242 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -117,7 +117,10 @@ risingwave_common = { workspace = true } risingwave_jni_core = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } +rumqttc = "0.22.0" rust_decimal = "1" +rustls-native-certs = "0.6" +rustls-pemfile = "1" rw_futures_util = { workspace = true } serde = { version = "1", features = ["derive", "rc"] } serde_derive = "1" @@ -141,6 +144,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [ ] } tokio-postgres = { version = "0.7", features = ["with-uuid-1"] } tokio-retry = "0.3" +tokio-rustls = "0.24" tokio-stream = "0.1" tokio-util = { version = "0.7", features = ["codec", "io"] } tonic = { 
workspace = true } diff --git a/src/connector/src/common.rs b/src/connector/src/common.rs index 66afaf55f0cc1..fc1cd6792b668 100644 --- a/src/connector/src/common.rs +++ b/src/connector/src/common.rs @@ -684,3 +684,129 @@ impl NatsCommon { Ok(creds) } } + +#[serde_as] +#[derive(Deserialize, Debug, Clone, WithOptions)] +pub struct MqttCommon { + /// Protocol used for RisingWave to communicate with Kafka brokers. Could be tcp or ssl + #[serde(rename = "protocol")] + pub protocol: Option, + #[serde(rename = "host")] + pub host: String, + #[serde(rename = "port")] + pub port: Option, + #[serde(rename = "topic")] + pub topic: String, + #[serde(rename = "username")] + pub user: Option, + #[serde(rename = "password")] + pub password: Option, + #[serde(rename = "client_prefix")] + pub client_prefix: Option, + #[serde(rename = "tls.ca")] + pub ca: Option, + #[serde(rename = "tls.client_cert")] + pub client_cert: Option, + #[serde(rename = "tls.client_key")] + pub client_key: Option, +} + +impl MqttCommon { + pub(crate) fn build_client( + &self, + id: u32, + ) -> ConnectorResult<(rumqttc::v5::AsyncClient, rumqttc::v5::EventLoop)> { + let ssl = self + .protocol + .as_ref() + .map(|p| p == "ssl") + .unwrap_or_default(); + + let client_id = format!( + "{}_{}{}", + self.client_prefix.as_deref().unwrap_or("risingwave"), + id, + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() + % 100000, + ); + + let port = self.port.unwrap_or(if ssl { 8883 } else { 1883 }) as u16; + + let mut options = rumqttc::v5::MqttOptions::new(client_id, &self.host, port); + if ssl { + let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); + if let Some(ca) = &self.ca { + let certificates = load_certs(ca)?; + for cert in certificates { + root_cert_store.add(&cert).unwrap(); + } + } else { + for cert in + rustls_native_certs::load_native_certs().expect("could not load platform certs") + { + root_cert_store + 
.add(&tokio_rustls::rustls::Certificate(cert.0)) + .unwrap(); + } + } + + let builder = tokio_rustls::rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store); + + let tls_config = if let (Some(client_cert), Some(client_key)) = + (self.client_cert.as_ref(), self.client_key.as_ref()) + { + let certs = load_certs(client_cert)?; + let key = load_private_key(client_key)?; + + builder.with_client_auth_cert(certs, key)? + } else { + builder.with_no_client_auth() + }; + + options.set_transport(rumqttc::Transport::tls_with_config( + rumqttc::TlsConfiguration::Rustls(std::sync::Arc::new(tls_config)), + )); + } + + if let Some(user) = &self.user { + options.set_credentials(user, self.password.as_deref().unwrap_or_default()); + } + + Ok(rumqttc::v5::AsyncClient::new(options, 10)) + } +} + +fn load_certs(certificates: &str) -> ConnectorResult> { + let cert_bytes = if let Some(path) = certificates.strip_prefix("fs://") { + std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())? + } else { + certificates.as_bytes().to_owned() + }; + + let certs = rustls_pemfile::certs(&mut cert_bytes.as_slice())?; + + Ok(certs + .into_iter() + .map(tokio_rustls::rustls::Certificate) + .collect()) +} + +fn load_private_key(certificate: &str) -> ConnectorResult { + let cert_bytes = if let Some(path) = certificate.strip_prefix("fs://") { + std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())? + } else { + certificate.as_bytes().to_owned() + }; + + let certs = rustls_pemfile::pkcs8_private_keys(&mut cert_bytes.as_slice())?; + let cert = certs + .into_iter() + .next() + .ok_or_else(|| anyhow!("No private key found"))?; + Ok(tokio_rustls::rustls::PrivateKey(cert)) +} diff --git a/src/connector/src/error.rs b/src/connector/src/error.rs index 3dc10af3d8e7a..1317981f88919 100644 --- a/src/connector/src/error.rs +++ b/src/connector/src/error.rs @@ -58,6 +58,8 @@ def_anyhow_newtype! 
{ redis::RedisError => "Redis error", arrow_schema::ArrowError => "Arrow error", google_cloud_pubsub::client::google_cloud_auth::error::Error => "Google Cloud error", + tokio_rustls::rustls::Error => "TLS error", + rumqttc::v5::ClientError => "MQTT error", } pub type ConnectorResult = std::result::Result; diff --git a/src/connector/src/macros.rs b/src/connector/src/macros.rs index 4b375254c5ad1..b369e6d8a11e3 100644 --- a/src/connector/src/macros.rs +++ b/src/connector/src/macros.rs @@ -32,6 +32,7 @@ macro_rules! for_all_classified_sources { { Nexmark, $crate::source::nexmark::NexmarkProperties, $crate::source::nexmark::NexmarkSplit }, { Datagen, $crate::source::datagen::DatagenProperties, $crate::source::datagen::DatagenSplit }, { GooglePubsub, $crate::source::google_pubsub::PubsubProperties, $crate::source::google_pubsub::PubsubSplit }, + { Mqtt, $crate::source::mqtt::MqttProperties, $crate::source::mqtt::split::MqttSplit }, { Nats, $crate::source::nats::NatsProperties, $crate::source::nats::split::NatsSplit }, { S3, $crate::source::filesystem::S3Properties, $crate::source::filesystem::FsSplit }, { Gcs, $crate::source::filesystem::opendal_source::GcsProperties , $crate::source::filesystem::OpendalFsSplit<$crate::source::filesystem::opendal_source::OpendalGcs> }, diff --git a/src/connector/src/source/mod.rs b/src/connector/src/source/mod.rs index 3656820ed95b0..f965d373d9306 100644 --- a/src/connector/src/source/mod.rs +++ b/src/connector/src/source/mod.rs @@ -21,6 +21,7 @@ pub mod google_pubsub; pub mod kafka; pub mod kinesis; pub mod monitor; +pub mod mqtt; pub mod nats; pub mod nexmark; pub mod pulsar; @@ -29,6 +30,7 @@ pub(crate) use common::*; pub use google_pubsub::GOOGLE_PUBSUB_CONNECTOR; pub use kafka::KAFKA_CONNECTOR; pub use kinesis::KINESIS_CONNECTOR; +pub use mqtt::MQTT_CONNECTOR; pub use nats::NATS_CONNECTOR; mod common; pub mod iceberg; diff --git a/src/connector/src/source/mqtt/enumerator/mod.rs b/src/connector/src/source/mqtt/enumerator/mod.rs new 
file mode 100644 index 0000000000000..8e012b1c41036 --- /dev/null +++ b/src/connector/src/source/mqtt/enumerator/mod.rs @@ -0,0 +1,102 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashSet; + +use async_trait::async_trait; +use risingwave_common::bail; +use rumqttc::v5::{Event, Incoming}; +use rumqttc::Outgoing; + +use super::source::MqttSplit; +use super::MqttProperties; +use crate::error::ConnectorResult; +use crate::source::{SourceEnumeratorContextRef, SplitEnumerator}; + +pub struct MqttSplitEnumerator { + topic: String, + client: rumqttc::v5::AsyncClient, + eventloop: rumqttc::v5::EventLoop, +} + +#[async_trait] +impl SplitEnumerator for MqttSplitEnumerator { + type Properties = MqttProperties; + type Split = MqttSplit; + + async fn new( + properties: Self::Properties, + context: SourceEnumeratorContextRef, + ) -> ConnectorResult { + let (client, eventloop) = properties.common.build_client(context.info.source_id)?; + + Ok(Self { + topic: properties.common.topic, + client, + eventloop, + }) + } + + async fn list_splits(&mut self) -> ConnectorResult> { + if !self.topic.contains('#') && !self.topic.contains('+') { + self.client + .subscribe(self.topic.clone(), rumqttc::v5::mqttbytes::QoS::AtLeastOnce) + .await?; + + let start = std::time::Instant::now(); + loop { + match self.eventloop.poll().await { + Ok(Event::Outgoing(Outgoing::Subscribe(_))) => { + break; + } + _ => { + if 
start.elapsed().as_secs() > 5 { + bail!("Failed to subscribe to topic {}", self.topic); + } + } + } + } + self.client.unsubscribe(self.topic.clone()).await?; + + return Ok(vec![MqttSplit::new(self.topic.clone())]); + } + + self.client + .subscribe(self.topic.clone(), rumqttc::v5::mqttbytes::QoS::AtLeastOnce) + .await?; + + let start = std::time::Instant::now(); + let mut topics = HashSet::new(); + loop { + match self.eventloop.poll().await { + Ok(Event::Incoming(Incoming::Publish(p))) => { + topics.insert(String::from_utf8_lossy(&p.topic).to_string()); + } + _ => { + if start.elapsed().as_secs() > 15 { + self.client.unsubscribe(self.topic.clone()).await?; + if topics.is_empty() { + tracing::warn!( + "Failed to find any topics for pattern {}, using a single split", + self.topic + ); + return Ok(vec![MqttSplit::new(self.topic.clone())]); + } + return Ok(topics.into_iter().map(MqttSplit::new).collect()); + } + } + } + } + } +} diff --git a/src/connector/src/source/mqtt/mod.rs b/src/connector/src/source/mqtt/mod.rs new file mode 100644 index 0000000000000..c1085849255b4 --- /dev/null +++ b/src/connector/src/source/mqtt/mod.rs @@ -0,0 +1,55 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +pub mod enumerator; +pub mod source; +pub mod split; + +use std::collections::HashMap; + +use serde::Deserialize; +use with_options::WithOptions; + +use crate::common::MqttCommon; +use crate::source::mqtt::enumerator::MqttSplitEnumerator; +use crate::source::mqtt::source::{MqttSplit, MqttSplitReader}; +use crate::source::SourceProperties; + +pub const MQTT_CONNECTOR: &str = "mqtt"; + +#[derive(Clone, Debug, Deserialize, WithOptions)] +pub struct MqttProperties { + #[serde(flatten)] + pub common: MqttCommon, + + // 0 - AtMostOnce, 1 - AtLeastOnce, 2 - ExactlyOnce + pub qos: Option, + + #[serde(flatten)] + pub unknown_fields: HashMap, +} + +impl SourceProperties for MqttProperties { + type Split = MqttSplit; + type SplitEnumerator = MqttSplitEnumerator; + type SplitReader = MqttSplitReader; + + const SOURCE_NAME: &'static str = MQTT_CONNECTOR; +} + +impl crate::source::UnknownFields for MqttProperties { + fn unknown_fields(&self) -> HashMap { + self.unknown_fields.clone() + } +} diff --git a/src/connector/src/source/mqtt/source/message.rs b/src/connector/src/source/mqtt/source/message.rs new file mode 100644 index 0000000000000..16914a3dabcdb --- /dev/null +++ b/src/connector/src/source/mqtt/source/message.rs @@ -0,0 +1,48 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +use rumqttc::v5::mqttbytes::v5::Publish; + +use crate::source::base::SourceMessage; +use crate::source::SourceMeta; + +#[derive(Clone, Debug)] +pub struct MqttMessage { + pub topic: String, + pub sequence_number: String, + pub payload: Vec, +} + +impl From for SourceMessage { + fn from(message: MqttMessage) -> Self { + SourceMessage { + key: None, + payload: Some(message.payload), + // The MQTT packet identifier (pkid) of the publish is used as the offset + offset: message.sequence_number, + split_id: message.topic.into(), + meta: SourceMeta::Empty, + } + } +} + +impl MqttMessage { + pub fn new(message: Publish) -> Self { + MqttMessage { + topic: String::from_utf8_lossy(&message.topic).to_string(), + sequence_number: message.pkid.to_string(), + payload: message.payload.to_vec(), + } + } +} diff --git a/src/connector/src/source/mqtt/source/mod.rs b/src/connector/src/source/mqtt/source/mod.rs new file mode 100644 index 0000000000000..2cf5350a0247c --- /dev/null +++ b/src/connector/src/source/mqtt/source/mod.rs @@ -0,0 +1,20 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +mod message; +mod reader; + +pub use reader::*; + +pub use crate::source::mqtt::split::*; diff --git a/src/connector/src/source/mqtt/source/reader.rs b/src/connector/src/source/mqtt/source/reader.rs new file mode 100644 index 0000000000000..a2fb1478ba393 --- /dev/null +++ b/src/connector/src/source/mqtt/source/reader.rs @@ -0,0 +1,105 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::Context; +use async_trait::async_trait; +use futures_async_stream::try_stream; +use risingwave_common::bail; +use rumqttc::v5::mqttbytes::v5::Filter; +use rumqttc::v5::mqttbytes::QoS; +use rumqttc::v5::{Event, Incoming}; + +use super::message::MqttMessage; +use super::MqttSplit; +use crate::error::ConnectorResult as Result; +use crate::parser::ParserConfig; +use crate::source::common::{into_chunk_stream, CommonSplitReader}; +use crate::source::mqtt::MqttProperties; +use crate::source::{ + self, BoxChunkSourceStream, Column, SourceContextRef, SourceMessage, SplitReader, +}; + +pub struct MqttSplitReader { + eventloop: rumqttc::v5::EventLoop, + properties: MqttProperties, + parser_config: ParserConfig, + source_ctx: SourceContextRef, +} + +#[async_trait] +impl SplitReader for MqttSplitReader { + type Properties = MqttProperties; + type Split = MqttSplit; + + async fn new( + properties: MqttProperties, + splits: Vec, + parser_config: ParserConfig, + source_ctx: SourceContextRef, + _columns: Option>, + ) -> Result { + let 
(client, eventloop) = properties + .common + .build_client(source_ctx.actor_id, source_ctx.fragment_id)?; + + let qos = if let Some(qos) = properties.qos { + match qos { + 0 => QoS::AtMostOnce, + 1 => QoS::AtLeastOnce, + 2 => QoS::ExactlyOnce, + _ => bail!("Invalid QoS level: {}", qos), + } + } else { + QoS::AtLeastOnce + }; + + client + .subscribe_many( + splits + .into_iter() + .map(|split| Filter::new(split.topic, qos)), + ) + .await?; + + Ok(Self { + eventloop, + properties, + parser_config, + source_ctx, + }) + } + + fn into_stream(self) -> BoxChunkSourceStream { + let parser_config = self.parser_config.clone(); + let source_context = self.source_ctx.clone(); + into_chunk_stream(self, parser_config, source_context) + } +} + +impl CommonSplitReader for MqttSplitReader { + #[try_stream(ok = Vec, error = crate::error::ConnectorError)] + async fn into_data_stream(self) { + let mut eventloop = self.eventloop; + loop { + match eventloop.poll().await { + Ok(Event::Incoming(Incoming::Publish(p))) => { + let msg = MqttMessage::new(p); + yield vec![SourceMessage::from(msg)]; + } + Ok(_) => (), + Err(e) => Err(e).context("Error getting data from the event loop")?, + } + } + } +} diff --git a/src/connector/src/source/mqtt/split.rs b/src/connector/src/source/mqtt/split.rs new file mode 100644 index 0000000000000..b86bde6097ae9 --- /dev/null +++ b/src/connector/src/source/mqtt/split.rs @@ -0,0 +1,50 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +use risingwave_common::types::JsonbVal; +use serde::{Deserialize, Serialize}; + +use crate::error::ConnectorResult; +use crate::source::{SplitId, SplitMetaData}; + +/// The state of an MQTT split, which will be persisted to checkpoint. +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)] +pub struct MqttSplit { + pub(crate) topic: String, +} + +impl SplitMetaData for MqttSplit { + fn id(&self) -> SplitId { + // TODO: should avoid constructing a string every time + self.topic.clone().into() + } + + fn restore_from_json(value: JsonbVal) -> ConnectorResult { + serde_json::from_value(value.take()).map_err(Into::into) + } + + fn encode_to_json(&self) -> JsonbVal { + serde_json::to_value(self.clone()).unwrap().into() + } + + fn update_with_offset(&mut self, _start_sequence: String) -> ConnectorResult<()> { + Ok(()) + } +} + +impl MqttSplit { + pub fn new(topic: String) -> Self { + Self { topic } + } +} diff --git a/src/connector/with_options_sink.yaml b/src/connector/with_options_sink.yaml index 28b492f0c80ea..427efcf23ef85 100644 --- a/src/connector/with_options_sink.yaml +++ b/src/connector/with_options_sink.yaml @@ -350,6 +350,39 @@ KinesisSinkConfig: field_type: String required: false alias: kinesis.assumerole.external_id +MqttCommon: + fields: + - name: protocol + field_type: String + comments: Protocol used for RisingWave to communicate with Kafka brokers.
Could be tcp or ssl + required: false + - name: host + field_type: String + required: true + - name: port + field_type: i32 + required: false + - name: topic + field_type: String + required: true + - name: username + field_type: String + required: false + - name: password + field_type: String + required: false + - name: client_prefix + field_type: String + required: false + - name: tls.ca + field_type: String + required: false + - name: tls.client_cert + field_type: String + required: false + - name: tls.client_key + field_type: String + required: false NatsConfig: fields: - name: server_url diff --git a/src/connector/with_options_source.yaml b/src/connector/with_options_source.yaml index 29055d068294c..fbca14d8699ff 100644 --- a/src/connector/with_options_source.yaml +++ b/src/connector/with_options_source.yaml @@ -255,6 +255,42 @@ KinesisProperties: field_type: String required: false alias: kinesis.assumerole.external_id +MqttProperties: + fields: + - name: protocol + field_type: String + comments: Protocol used for RisingWave to communicate with Kafka brokers. 
Could be tcp or ssl + required: false + - name: host + field_type: String + required: true + - name: port + field_type: i32 + required: false + - name: topic + field_type: String + required: true + - name: username + field_type: String + required: false + - name: password + field_type: String + required: false + - name: client_prefix + field_type: String + required: false + - name: tls.ca + field_type: String + required: false + - name: tls.client_cert + field_type: String + required: false + - name: tls.client_key + field_type: String + required: false + - name: qos + field_type: i32 + required: false NatsProperties: fields: - name: server_url diff --git a/src/frontend/src/handler/create_source.rs b/src/frontend/src/handler/create_source.rs index cba138268f06d..f0e6f075b261e 100644 --- a/src/frontend/src/handler/create_source.rs +++ b/src/frontend/src/handler/create_source.rs @@ -49,8 +49,8 @@ use risingwave_connector::source::nexmark::source::{get_event_data_types_with_na use risingwave_connector::source::test_source::TEST_CONNECTOR; use risingwave_connector::source::{ ConnectorProperties, GCS_CONNECTOR, GOOGLE_PUBSUB_CONNECTOR, KAFKA_CONNECTOR, - KINESIS_CONNECTOR, NATS_CONNECTOR, NEXMARK_CONNECTOR, OPENDAL_S3_CONNECTOR, POSIX_FS_CONNECTOR, - PULSAR_CONNECTOR, S3_CONNECTOR, + KINESIS_CONNECTOR, MQTT_CONNECTOR, NATS_CONNECTOR, NEXMARK_CONNECTOR, OPENDAL_S3_CONNECTOR, + POSIX_FS_CONNECTOR, PULSAR_CONNECTOR, S3_CONNECTOR, }; use risingwave_pb::catalog::{ PbSchemaRegistryNameStrategy, PbSource, StreamSourceInfo, WatermarkDesc, @@ -1011,6 +1011,9 @@ static CONNECTORS_COMPATIBLE_FORMATS: LazyLock hashmap!( Format::Plain => vec![Encode::Json, Encode::Protobuf], ), + MQTT_CONNECTOR => hashmap!( + Format::Plain => vec![Encode::Json, Encode::Bytes], + ), TEST_CONNECTOR => hashmap!( Format::Plain => vec![Encode::Json], ), diff --git a/src/workspace-hack/Cargo.toml b/src/workspace-hack/Cargo.toml index 75b1555da7ae3..b58b8c04cfc7f 100644 --- a/src/workspace-hack/Cargo.toml 
+++ b/src/workspace-hack/Cargo.toml @@ -48,6 +48,7 @@ digest = { version = "0.10", features = ["mac", "oid", "std"] } either = { version = "1", features = ["serde"] } fail = { version = "0.5", default-features = false, features = ["failpoints"] } flate2 = { version = "1", features = ["zlib"] } +flume = { version = "0.10" } frunk_core = { version = "0.4", default-features = false, features = ["std"] } futures = { version = "0.3" } futures-channel = { version = "0.3", features = ["sink"] } @@ -58,6 +59,7 @@ futures-sink = { version = "0.3" } futures-task = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] } +getrandom = { git = "https://github.com/madsim-rs/getrandom.git", rev = "e79a7ae", default-features = false, features = ["js", "rdrand", "std"] } governor = { version = "0.6", default-features = false, features = ["dashmap", "jitter", "std"] } hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["nightly", "raw"] } hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] } @@ -170,6 +172,7 @@ digest = { version = "0.10", features = ["mac", "oid", "std"] } either = { version = "1", features = ["serde"] } frunk_core = { version = "0.4", default-features = false, features = ["std"] } generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] } +getrandom = { git = "https://github.com/madsim-rs/getrandom.git", rev = "e79a7ae", default-features = false, features = ["js", "rdrand", "std"] } hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["nightly", "raw"] } indexmap-f595c2ba2a3f28df = { package = "indexmap", version = "2", features = ["serde"] } itertools = { version = "0.11" } From 15dc7d59bc8548bbc960ce212e8c8fc20a8aeca7 Mon Sep 17 00:00:00 2001 From: Gio Gutierrez Date: Sun, 3 Mar 
2024 11:13:11 -0500 Subject: [PATCH 02/12] feat: Add mqtt sink --- ci/scripts/gen-integration-test-yaml.py | 2 +- .../mqtt-source/create_source.sql | 11 - integration_tests/mqtt-source/data_check | 1 - integration_tests/mqtt/create_source.sql | 43 +++ integration_tests/mqtt/data_check | 1 + .../{mqtt-source => mqtt}/docker-compose.yml | 0 .../{mqtt-source => mqtt}/query.sql | 0 src/connector/src/common.rs | 1 + src/connector/src/sink/mod.rs | 8 + src/connector/src/sink/mqtt.rs | 245 ++++++++++++++++++ .../src/source/mqtt/enumerator/mod.rs | 133 ++++++---- src/connector/src/source/mqtt/mod.rs | 4 +- .../src/source/mqtt/source/reader.rs | 44 +++- 13 files changed, 420 insertions(+), 73 deletions(-) delete mode 100644 integration_tests/mqtt-source/create_source.sql delete mode 100644 integration_tests/mqtt-source/data_check create mode 100644 integration_tests/mqtt/create_source.sql create mode 100644 integration_tests/mqtt/data_check rename integration_tests/{mqtt-source => mqtt}/docker-compose.yml (100%) rename integration_tests/{mqtt-source => mqtt}/query.sql (100%) create mode 100644 src/connector/src/sink/mqtt.rs diff --git a/ci/scripts/gen-integration-test-yaml.py b/ci/scripts/gen-integration-test-yaml.py index 90aa83b54ad98..93eadbdf401ff 100644 --- a/ci/scripts/gen-integration-test-yaml.py +++ b/ci/scripts/gen-integration-test-yaml.py @@ -35,7 +35,7 @@ 'mindsdb': ['json'], 'vector': ['json'], 'nats': ['json'], - 'mqtt-source': ['json'], + 'mqtt': ['json'], 'doris-sink': ['json'], 'starrocks-sink': ['json'], 'deltalake-sink': ['json'], diff --git a/integration_tests/mqtt-source/create_source.sql b/integration_tests/mqtt-source/create_source.sql deleted file mode 100644 index 6c6344c20cda7..0000000000000 --- a/integration_tests/mqtt-source/create_source.sql +++ /dev/null @@ -1,11 +0,0 @@ - -CREATE TABLE mqtt_source_table -( - id integer, - name varchar, -) -WITH ( - connector='mqtt', - host='mqtt-server', - topic= 'test' -) FORMAT PLAIN ENCODE JSON; diff --git 
a/integration_tests/mqtt-source/data_check b/integration_tests/mqtt-source/data_check deleted file mode 100644 index fcaf3aca97ed0..0000000000000 --- a/integration_tests/mqtt-source/data_check +++ /dev/null @@ -1 +0,0 @@ -mqtt_source_table \ No newline at end of file diff --git a/integration_tests/mqtt/create_source.sql b/integration_tests/mqtt/create_source.sql new file mode 100644 index 0000000000000..a586ee0966860 --- /dev/null +++ b/integration_tests/mqtt/create_source.sql @@ -0,0 +1,43 @@ +CREATE TABLE + personnel (id integer, name varchar); + +CREATE TABLE mqtt_source_table +( + id integer, + name varchar, +) +WITH ( + connector='mqtt', + host='mqtt-server', + topic= 'test' +) FORMAT PLAIN ENCODE JSON; + + +CREATE SINK mqtt_sink +FROM + personnel +WITH + ( + connector='mqtt', + host='mqtt-server', + topic= 'test', + type = 'append-only', + force_append_only='true', + retain = 'true', + qos = '1' + ); + +INSERT INTO + personnel +VALUES + (1, 'Alice'), + (2, 'Bob'), + (3, 'Tom'), + (4, 'Jerry'), + (5, 'Araminta'), + (6, 'Clover'), + (7, 'Posey'), + (8, 'Waverly'); + + +FLUSH; \ No newline at end of file diff --git a/integration_tests/mqtt/data_check b/integration_tests/mqtt/data_check new file mode 100644 index 0000000000000..8d6dc41bf6691 --- /dev/null +++ b/integration_tests/mqtt/data_check @@ -0,0 +1 @@ +personnel,mqtt_source_table \ No newline at end of file diff --git a/integration_tests/mqtt-source/docker-compose.yml b/integration_tests/mqtt/docker-compose.yml similarity index 100% rename from integration_tests/mqtt-source/docker-compose.yml rename to integration_tests/mqtt/docker-compose.yml diff --git a/integration_tests/mqtt-source/query.sql b/integration_tests/mqtt/query.sql similarity index 100% rename from integration_tests/mqtt-source/query.sql rename to integration_tests/mqtt/query.sql diff --git a/src/connector/src/common.rs b/src/connector/src/common.rs index fc1cd6792b668..68f757da07452 100644 --- a/src/connector/src/common.rs +++ 
b/src/connector/src/common.rs @@ -736,6 +736,7 @@ impl MqttCommon { let port = self.port.unwrap_or(if ssl { 8883 } else { 1883 }) as u16; let mut options = rumqttc::v5::MqttOptions::new(client_id, &self.host, port); + options.set_keep_alive(std::time::Duration::from_secs(10)); if ssl { let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); if let Some(ca) = &self.ca { diff --git a/src/connector/src/sink/mod.rs b/src/connector/src/sink/mod.rs index 7daf0883ac7e8..6df3027267dfd 100644 --- a/src/connector/src/sink/mod.rs +++ b/src/connector/src/sink/mod.rs @@ -28,6 +28,7 @@ pub mod kafka; pub mod kinesis; pub mod log_store; pub mod mock_coordination_client; +pub mod mqtt; pub mod nats; pub mod pulsar; pub mod redis; @@ -81,6 +82,7 @@ macro_rules! for_all_sinks { { Kinesis, $crate::sink::kinesis::KinesisSink }, { ClickHouse, $crate::sink::clickhouse::ClickHouseSink }, { Iceberg, $crate::sink::iceberg::IcebergSink }, + { Mqtt, $crate::sink::mqtt::MqttSink }, { Nats, $crate::sink::nats::NatsSink }, { Jdbc, $crate::sink::remote::JdbcSink }, { ElasticSearch, $crate::sink::remote::ElasticSearchSink }, @@ -503,6 +505,12 @@ pub enum SinkError { ClickHouse(String), #[error("Redis error: {0}")] Redis(String), + #[error("Mqtt error: {0}")] + Mqtt( + #[source] + #[backtrace] + anyhow::Error, + ), #[error("Nats error: {0}")] Nats( #[source] diff --git a/src/connector/src/sink/mqtt.rs b/src/connector/src/sink/mqtt.rs new file mode 100644 index 0000000000000..031e3c77d3819 --- /dev/null +++ b/src/connector/src/sink/mqtt.rs @@ -0,0 +1,245 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +use core::fmt::Debug; +use std::collections::HashMap; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; + +use anyhow::{anyhow, Context as _}; +use risingwave_common::array::StreamChunk; +use risingwave_common::catalog::Schema; +use rumqttc::v5::mqttbytes::QoS; +use rumqttc::v5::ConnectionError; +use serde_derive::Deserialize; +use serde_with::serde_as; +use tokio_retry::strategy::{jitter, ExponentialBackoff}; +use tokio_retry::Retry; +use with_options::WithOptions; + +use super::encoder::{DateHandlingMode, TimeHandlingMode, TimestamptzHandlingMode}; +use super::utils::chunk_to_json; +use super::{DummySinkCommitCoordinator, SinkWriterParam}; +use crate::common::MqttCommon; +use crate::sink::catalog::desc::SinkDesc; +use crate::sink::encoder::{JsonEncoder, TimestampHandlingMode}; +use crate::sink::log_store::DeliveryFutureManagerAddFuture; +use crate::sink::writer::{ + AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter, AsyncTruncateSinkWriterExt, +}; +use crate::sink::{Result, Sink, SinkError, SinkParam, SINK_TYPE_APPEND_ONLY}; +use crate::{deserialize_bool_from_string, deserialize_u32_from_string}; + +pub const MQTT_SINK: &str = "mqtt"; + +#[serde_as] +#[derive(Clone, Debug, Deserialize, WithOptions)] +pub struct MqttConfig { + #[serde(flatten)] + pub common: MqttCommon, + + // 0 - AtLeastOnce, 1 - AtMostOnce, 2 - ExactlyOnce + #[serde(default, deserialize_with = "deserialize_u32_from_string")] + pub qos: u32, + + #[serde(default, deserialize_with = "deserialize_bool_from_string")] + pub retain: bool, + + // accept "append-only" + pub 
r#type: String, +} + +#[derive(Clone, Debug)] +pub struct MqttSink { + pub config: MqttConfig, + schema: Schema, + is_append_only: bool, +} + +// sink write +pub struct MqttSinkWriter { + pub config: MqttConfig, + client: rumqttc::v5::AsyncClient, + qos: QoS, + retain: bool, + schema: Schema, + json_encoder: JsonEncoder, + stopped: Arc, +} + +/// Basic data types for use with the mqtt interface +impl MqttConfig { + pub fn from_hashmap(values: HashMap) -> Result { + let config = serde_json::from_value::(serde_json::to_value(values).unwrap()) + .map_err(|e| SinkError::Config(anyhow!(e)))?; + if config.r#type != SINK_TYPE_APPEND_ONLY { + Err(SinkError::Config(anyhow!( + "Mqtt sink only support append-only mode" + ))) + } else { + Ok(config) + } + } +} + +impl TryFrom for MqttSink { + type Error = SinkError; + + fn try_from(param: SinkParam) -> std::result::Result { + let schema = param.schema(); + let config = MqttConfig::from_hashmap(param.properties)?; + Ok(Self { + config, + schema, + is_append_only: param.sink_type.is_append_only(), + }) + } +} + +impl Sink for MqttSink { + type Coordinator = DummySinkCommitCoordinator; + type LogSinker = AsyncTruncateLogSinkerOf; + + const SINK_NAME: &'static str = MQTT_SINK; + + fn default_sink_decouple(desc: &SinkDesc) -> bool { + desc.sink_type.is_append_only() + } + + async fn validate(&self) -> Result<()> { + if !self.is_append_only { + return Err(SinkError::Mqtt(anyhow!( + "Nats sink only support append-only mode" + ))); + } + let _client = (self.config.common.build_client(0)) + .context("validate mqtt sink error") + .map_err(SinkError::Mqtt)?; + Ok(()) + } + + async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result { + Ok(MqttSinkWriter::new( + self.config.clone(), + self.schema.clone(), + writer_param.executor_id, + )? 
+ .into_log_sinker(usize::MAX)) + } +} + +impl MqttSinkWriter { + pub fn new(config: MqttConfig, schema: Schema, id: u64) -> Result { + let qos = match config.qos { + 0 => QoS::AtMostOnce, + 1 => QoS::AtLeastOnce, + 2 => QoS::ExactlyOnce, + _ => { + return Err(SinkError::Mqtt(anyhow!( + "Invalid QoS level: {}", + config.qos + ))) + } + }; + + let (client, mut eventloop) = config + .common + .build_client(id as u32) + .map_err(|e| SinkError::Mqtt(anyhow!(e)))?; + + let stopped = Arc::new(AtomicBool::new(false)); + let stopped_clone = stopped.clone(); + + tokio::spawn(async move { + while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) { + match eventloop.poll().await { + Ok(_) => (), + Err(err) => { + if let ConnectionError::Timeout(_) = err { + continue; + } + + if let ConnectionError::MqttState(rumqttc::v5::StateError::Io(err)) = err { + if err.kind() != std::io::ErrorKind::ConnectionAborted { + tracing::error!("[Sink] Failed to poll mqtt eventloop: {}", err); + std::thread::sleep(std::time::Duration::from_secs(1)); + } + } else { + tracing::error!("[Sink] Failed to poll mqtt eventloop: {}", err); + std::thread::sleep(std::time::Duration::from_secs(1)); + } + } + } + } + }); + + Ok::<_, SinkError>(Self { + config: config.clone(), + client, + qos, + retain: config.retain, + schema: schema.clone(), + stopped, + json_encoder: JsonEncoder::new( + schema, + None, + DateHandlingMode::FromCe, + TimestampHandlingMode::Milli, + TimestamptzHandlingMode::UtcWithoutSuffix, + TimeHandlingMode::Milli, + ), + }) + } + + async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> { + Retry::spawn( + ExponentialBackoff::from_millis(100).map(jitter).take(3), + || async { + let data = chunk_to_json(chunk.clone(), &self.json_encoder).unwrap(); + for item in data { + self.client + .publish( + &self.config.common.topic, + self.qos, + self.retain, + item.into_bytes(), + ) + .await + .context("mqtt sink error") + .map_err(SinkError::Mqtt)?; + } + Ok::<_, SinkError>(()) + 
}, + ) + .await + .context("mqtts sink error") + .map_err(SinkError::Mqtt) + } +} + +impl AsyncTruncateSinkWriter for MqttSinkWriter { + async fn write_chunk<'a>( + &'a mut self, + chunk: StreamChunk, + _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>, + ) -> Result<()> { + self.append_only(chunk).await + } +} + +impl Drop for MqttSinkWriter { + fn drop(&mut self) { + self.stopped + .store(true, std::sync::atomic::Ordering::Relaxed); + } +} diff --git a/src/connector/src/source/mqtt/enumerator/mod.rs b/src/connector/src/source/mqtt/enumerator/mod.rs index 8e012b1c41036..98f3fcec498bb 100644 --- a/src/connector/src/source/mqtt/enumerator/mod.rs +++ b/src/connector/src/source/mqtt/enumerator/mod.rs @@ -13,11 +13,14 @@ // limitations under the License. use std::collections::HashSet; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; use async_trait::async_trait; use risingwave_common::bail; -use rumqttc::v5::{Event, Incoming}; +use rumqttc::v5::{ConnectionError, Event, Incoming}; use rumqttc::Outgoing; +use tokio::sync::RwLock; use super::source::MqttSplit; use super::MqttProperties; @@ -27,7 +30,9 @@ use crate::source::{SourceEnumeratorContextRef, SplitEnumerator}; pub struct MqttSplitEnumerator { topic: String, client: rumqttc::v5::AsyncClient, - eventloop: rumqttc::v5::EventLoop, + topics: Arc>>, + connected: Arc, + stopped: Arc, } #[async_trait] @@ -39,64 +44,100 @@ impl SplitEnumerator for MqttSplitEnumerator { properties: Self::Properties, context: SourceEnumeratorContextRef, ) -> ConnectorResult { - let (client, eventloop) = properties.common.build_client(context.info.source_id)?; + let (client, mut eventloop) = properties.common.build_client(context.info.source_id)?; - Ok(Self { - topic: properties.common.topic, - client, - eventloop, - }) - } + let topic = properties.common.topic.clone(); + let mut topics = HashSet::new(); + if !topic.contains('#') && !topic.contains('+') { + topics.insert(topic.clone()); + } - async fn 
list_splits(&mut self) -> ConnectorResult> { - if !self.topic.contains('#') && !self.topic.contains('+') { - self.client - .subscribe(self.topic.clone(), rumqttc::v5::mqttbytes::QoS::AtLeastOnce) - .await?; + client + .subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtLeastOnce) + .await?; - let start = std::time::Instant::now(); - loop { - match self.eventloop.poll().await { + let cloned_client = client.clone(); + + let topics = Arc::new(RwLock::new(topics)); + + let connected = Arc::new(AtomicBool::new(false)); + let connected_clone = connected.clone(); + + let stopped = Arc::new(AtomicBool::new(false)); + let stopped_clone = stopped.clone(); + + let topics_clone = topics.clone(); + tokio::spawn(async move { + while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) { + match eventloop.poll().await { Ok(Event::Outgoing(Outgoing::Subscribe(_))) => { - break; + connected_clone.store(true, std::sync::atomic::Ordering::Relaxed); } - _ => { - if start.elapsed().as_secs() > 5 { - bail!("Failed to subscribe to topic {}", self.topic); + Ok(Event::Incoming(Incoming::Publish(p))) => { + let topic = String::from_utf8_lossy(&p.topic).to_string(); + let exist = { + let topics = topics_clone.read().await; + topics.contains(&topic) + }; + + if !exist { + let mut topics = topics_clone.write().await; + topics.insert(topic); + } + } + Ok(_) => {} + Err(err) => { + if let ConnectionError::Timeout(_) = err { + continue; } + tracing::error!( + "[Enumerator] Failed to subscribe to topic {}: {}", + topic, + err + ); + connected_clone.store(false, std::sync::atomic::Ordering::Relaxed); + cloned_client + .subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtLeastOnce) + .await + .unwrap(); } } } - self.client.unsubscribe(self.topic.clone()).await?; - - return Ok(vec![MqttSplit::new(self.topic.clone())]); - } + }); - self.client - .subscribe(self.topic.clone(), rumqttc::v5::mqttbytes::QoS::AtLeastOnce) - .await?; + Ok(Self { + client, + topics, + topic: 
properties.common.topic, + connected, + stopped, + }) + } - let start = std::time::Instant::now(); - let mut topics = HashSet::new(); - loop { - match self.eventloop.poll().await { - Ok(Event::Incoming(Incoming::Publish(p))) => { - topics.insert(String::from_utf8_lossy(&p.topic).to_string()); + async fn list_splits(&mut self) -> ConnectorResult> { + if !self.connected.load(std::sync::atomic::Ordering::Relaxed) { + let start = std::time::Instant::now(); + loop { + if self.connected.load(std::sync::atomic::Ordering::Relaxed) { + break; } - _ => { - if start.elapsed().as_secs() > 15 { - self.client.unsubscribe(self.topic.clone()).await?; - if topics.is_empty() { - tracing::warn!( - "Failed to find any topics for pattern {}, using a single split", - self.topic - ); - return Ok(vec![MqttSplit::new(self.topic.clone())]); - } - return Ok(topics.into_iter().map(MqttSplit::new).collect()); - } + + if start.elapsed().as_secs() > 10 { + bail!("Failed to connect to mqtt broker"); } + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; } } + + let topics = self.topics.read().await; + Ok(topics.iter().cloned().map(MqttSplit::new).collect()) + } +} + +impl Drop for MqttSplitEnumerator { + fn drop(&mut self) { + self.stopped + .store(true, std::sync::atomic::Ordering::Relaxed); } } diff --git a/src/connector/src/source/mqtt/mod.rs b/src/connector/src/source/mqtt/mod.rs index c1085849255b4..0aeef04e58ee7 100644 --- a/src/connector/src/source/mqtt/mod.rs +++ b/src/connector/src/source/mqtt/mod.rs @@ -22,6 +22,7 @@ use serde::Deserialize; use with_options::WithOptions; use crate::common::MqttCommon; +use crate::deserialize_u32_from_string; use crate::source::mqtt::enumerator::MqttSplitEnumerator; use crate::source::mqtt::source::{MqttSplit, MqttSplitReader}; use crate::source::SourceProperties; @@ -34,7 +35,8 @@ pub struct MqttProperties { pub common: MqttCommon, // 0 - AtLeastOnce, 1 - AtMostOnce, 2 - ExactlyOnce - pub qos: Option, + #[serde(default, deserialize_with 
= "deserialize_u32_from_string")] + pub qos: u32, #[serde(flatten)] pub unknown_fields: HashMap, diff --git a/src/connector/src/source/mqtt/source/reader.rs b/src/connector/src/source/mqtt/source/reader.rs index a2fb1478ba393..c07d54bee2946 100644 --- a/src/connector/src/source/mqtt/source/reader.rs +++ b/src/connector/src/source/mqtt/source/reader.rs @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use anyhow::Context; use async_trait::async_trait; use futures_async_stream::try_stream; use risingwave_common::bail; use rumqttc::v5::mqttbytes::v5::Filter; use rumqttc::v5::mqttbytes::QoS; -use rumqttc::v5::{Event, Incoming}; +use rumqttc::v5::{ConnectionError, Event, Incoming}; use super::message::MqttMessage; use super::MqttSplit; @@ -32,6 +31,9 @@ use crate::source::{ pub struct MqttSplitReader { eventloop: rumqttc::v5::EventLoop, + client: rumqttc::v5::AsyncClient, + qos: QoS, + splits: Vec, properties: MqttProperties, parser_config: ParserConfig, source_ctx: SourceContextRef, @@ -53,27 +55,27 @@ impl SplitReader for MqttSplitReader { .common .build_client(source_ctx.actor_id, source_ctx.fragment_id)?; - let qos = if let Some(qos) = properties.qos { - match qos { - 0 => QoS::AtMostOnce, - 1 => QoS::AtLeastOnce, - 2 => QoS::ExactlyOnce, - _ => bail!("Invalid QoS level: {}", qos), - } - } else { - QoS::AtLeastOnce + let qos = match properties.qos { + 0 => QoS::AtMostOnce, + 1 => QoS::AtLeastOnce, + 2 => QoS::ExactlyOnce, + _ => bail!("Invalid QoS level: {}", properties.qos), }; client .subscribe_many( splits - .into_iter() + .iter() + .cloned() .map(|split| Filter::new(split.topic, qos)), ) .await?; Ok(Self { eventloop, + client, + qos, + splits, properties, parser_config, source_ctx, @@ -91,6 +93,9 @@ impl CommonSplitReader for MqttSplitReader { #[try_stream(ok = Vec, error = crate::error::ConnectorError)] async fn into_data_stream(self) { let mut eventloop = self.eventloop; + let client = 
self.client; + let qos = self.qos; + let splits = self.splits; loop { match eventloop.poll().await { Ok(Event::Incoming(Incoming::Publish(p))) => { @@ -98,7 +103,20 @@ impl CommonSplitReader for MqttSplitReader { yield vec![SourceMessage::from(msg)]; } Ok(_) => (), - Err(e) => Err(e).context("Error getting data from the event loop")?, + Err(e) => { + if let ConnectionError::Timeout(_) = e { + continue; + } + tracing::error!("[Reader] Failed to poll mqtt eventloop: {}", e); + client + .subscribe_many( + splits + .iter() + .cloned() + .map(|split| Filter::new(split.topic, qos)), + ) + .await?; + } } } } From 4f8f84d061edfdfb89d6c588e6ebd4522320c46d Mon Sep 17 00:00:00 2001 From: Gio Gutierrez Date: Sun, 3 Mar 2024 11:30:44 -0500 Subject: [PATCH 03/12] chore: Fix dylint errors --- integration_tests/mqtt/create_source.sql | 8 ++++---- src/connector/src/common.rs | 2 +- src/connector/src/sink/mqtt.rs | 13 +++++++++++-- src/connector/src/source/mqtt/enumerator/mod.rs | 3 ++- src/connector/src/source/mqtt/source/reader.rs | 3 ++- src/connector/with_options_sink.yaml | 13 ++++++++++++- src/connector/with_options_source.yaml | 3 ++- 7 files changed, 34 insertions(+), 11 deletions(-) diff --git a/integration_tests/mqtt/create_source.sql b/integration_tests/mqtt/create_source.sql index a586ee0966860..925082841b3e5 100644 --- a/integration_tests/mqtt/create_source.sql +++ b/integration_tests/mqtt/create_source.sql @@ -9,7 +9,8 @@ CREATE TABLE mqtt_source_table WITH ( connector='mqtt', host='mqtt-server', - topic= 'test' + topic= 'test', + qos = '1' ) FORMAT PLAIN ENCODE JSON; @@ -23,7 +24,7 @@ WITH topic= 'test', type = 'append-only', force_append_only='true', - retain = 'true', + retain = 'false', qos = '1' ); @@ -39,5 +40,4 @@ VALUES (7, 'Posey'), (8, 'Waverly'); - -FLUSH; \ No newline at end of file +FLUSH; diff --git a/src/connector/src/common.rs b/src/connector/src/common.rs index 68f757da07452..6638537028729 100644 --- a/src/connector/src/common.rs +++ 
b/src/connector/src/common.rs @@ -778,7 +778,7 @@ impl MqttCommon { options.set_credentials(user, self.password.as_deref().unwrap_or_default()); } - Ok(rumqttc::v5::AsyncClient::new(options, 10)) + Ok(rumqttc::v5::AsyncClient::new(options, 100)) } } diff --git a/src/connector/src/sink/mqtt.rs b/src/connector/src/sink/mqtt.rs index 031e3c77d3819..6deb07b5ea2af 100644 --- a/src/connector/src/sink/mqtt.rs +++ b/src/connector/src/sink/mqtt.rs @@ -23,6 +23,7 @@ use rumqttc::v5::mqttbytes::QoS; use rumqttc::v5::ConnectionError; use serde_derive::Deserialize; use serde_with::serde_as; +use thiserror_ext::AsReport; use tokio_retry::strategy::{jitter, ExponentialBackoff}; use tokio_retry::Retry; use with_options::WithOptions; @@ -122,9 +123,11 @@ impl Sink for MqttSink { "Nats sink only support append-only mode" ))); } + let _client = (self.config.common.build_client(0)) .context("validate mqtt sink error") .map_err(SinkError::Mqtt)?; + Ok(()) } @@ -171,11 +174,17 @@ impl MqttSinkWriter { if let ConnectionError::MqttState(rumqttc::v5::StateError::Io(err)) = err { if err.kind() != std::io::ErrorKind::ConnectionAborted { - tracing::error!("[Sink] Failed to poll mqtt eventloop: {}", err); + tracing::error!( + "[Sink] Failed to poll mqtt eventloop: {}", + err.as_report() + ); std::thread::sleep(std::time::Duration::from_secs(1)); } } else { - tracing::error!("[Sink] Failed to poll mqtt eventloop: {}", err); + tracing::error!( + "[Sink] Failed to poll mqtt eventloop: {}", + err.as_report() + ); std::thread::sleep(std::time::Duration::from_secs(1)); } } diff --git a/src/connector/src/source/mqtt/enumerator/mod.rs b/src/connector/src/source/mqtt/enumerator/mod.rs index 98f3fcec498bb..1a88603cedde2 100644 --- a/src/connector/src/source/mqtt/enumerator/mod.rs +++ b/src/connector/src/source/mqtt/enumerator/mod.rs @@ -20,6 +20,7 @@ use async_trait::async_trait; use risingwave_common::bail; use rumqttc::v5::{ConnectionError, Event, Incoming}; use rumqttc::Outgoing; +use 
thiserror_ext::AsReport; use tokio::sync::RwLock; use super::source::MqttSplit; @@ -93,7 +94,7 @@ impl SplitEnumerator for MqttSplitEnumerator { tracing::error!( "[Enumerator] Failed to subscribe to topic {}: {}", topic, - err + err.as_report(), ); connected_clone.store(false, std::sync::atomic::Ordering::Relaxed); cloned_client diff --git a/src/connector/src/source/mqtt/source/reader.rs b/src/connector/src/source/mqtt/source/reader.rs index c07d54bee2946..af84373e07251 100644 --- a/src/connector/src/source/mqtt/source/reader.rs +++ b/src/connector/src/source/mqtt/source/reader.rs @@ -18,6 +18,7 @@ use risingwave_common::bail; use rumqttc::v5::mqttbytes::v5::Filter; use rumqttc::v5::mqttbytes::QoS; use rumqttc::v5::{ConnectionError, Event, Incoming}; +use thiserror_ext::AsReport; use super::message::MqttMessage; use super::MqttSplit; @@ -107,7 +108,7 @@ impl CommonSplitReader for MqttSplitReader { if let ConnectionError::Timeout(_) = e { continue; } - tracing::error!("[Reader] Failed to poll mqtt eventloop: {}", e); + tracing::error!("[Reader] Failed to poll mqtt eventloop: {}", e.as_report()); client .subscribe_many( splits diff --git a/src/connector/with_options_sink.yaml b/src/connector/with_options_sink.yaml index 427efcf23ef85..96aba3f1d641b 100644 --- a/src/connector/with_options_sink.yaml +++ b/src/connector/with_options_sink.yaml @@ -350,7 +350,7 @@ KinesisSinkConfig: field_type: String required: false alias: kinesis.assumerole.external_id -MqttCommon: +MqttConfig: fields: - name: protocol field_type: String @@ -383,6 +383,17 @@ MqttCommon: - name: tls.client_key field_type: String required: false + - name: qos + field_type: u32 + required: false + default: Default::default + - name: retain + field_type: bool + required: false + default: Default::default + - name: r#type + field_type: String + required: true NatsConfig: fields: - name: server_url diff --git a/src/connector/with_options_source.yaml b/src/connector/with_options_source.yaml index 
fbca14d8699ff..74c8098cc4dcd 100644 --- a/src/connector/with_options_source.yaml +++ b/src/connector/with_options_source.yaml @@ -289,8 +289,9 @@ MqttProperties: field_type: String required: false - name: qos - field_type: i32 + field_type: u32 required: false + default: Default::default NatsProperties: fields: - name: server_url From 7b1b5947ed6e05414192ee9723c1803ed958f654 Mon Sep 17 00:00:00 2001 From: Gio Gutierrez Date: Tue, 5 Mar 2024 09:43:44 -0500 Subject: [PATCH 04/12] feat: Address PR comments --- integration_tests/mqtt/create_source.sql | 7 +- src/connector/src/common.rs | 69 +++++++- src/connector/src/sink/mqtt.rs | 164 ++++++++++-------- src/connector/src/source/mqtt/mod.rs | 14 +- .../src/source/mqtt/source/reader.rs | 19 +- src/connector/src/with_options.rs | 2 + src/connector/with_options_sink.yaml | 26 ++- src/connector/with_options_source.yaml | 25 ++- src/frontend/src/handler/create_sink.rs | 6 + 9 files changed, 233 insertions(+), 99 deletions(-) diff --git a/integration_tests/mqtt/create_source.sql b/integration_tests/mqtt/create_source.sql index 925082841b3e5..04be605700c76 100644 --- a/integration_tests/mqtt/create_source.sql +++ b/integration_tests/mqtt/create_source.sql @@ -10,7 +10,7 @@ WITH ( connector='mqtt', host='mqtt-server', topic= 'test', - qos = '1' + qos = 'at_least_once', ) FORMAT PLAIN ENCODE JSON; @@ -23,9 +23,10 @@ WITH host='mqtt-server', topic= 'test', type = 'append-only', - force_append_only='true', retain = 'false', - qos = '1' + qos = 'at_least_once', + ) FORMAT PLAIN ENCODE JSON ( + force_append_only='true', ); INSERT INTO diff --git a/src/connector/src/common.rs b/src/connector/src/common.rs index 6638537028729..0aaeb3f06bf42 100644 --- a/src/connector/src/common.rs +++ b/src/connector/src/common.rs @@ -28,16 +28,17 @@ use risingwave_common::bail; use serde_derive::Deserialize; use serde_with::json::JsonString; use serde_with::{serde_as, DisplayFromStr}; +use strum_macros::{Display, EnumString}; use 
tempfile::NamedTempFile; use time::OffsetDateTime; use url::Url; use with_options::WithOptions; use crate::aws_utils::load_file_descriptor_from_s3; -use crate::deserialize_duration_from_string; use crate::error::ConnectorResult; use crate::sink::SinkError; use crate::source::nats::source::NatsOffset; +use crate::{deserialize_bool_from_string, deserialize_duration_from_string}; // The file describes the common abstractions for each connector and can be used in both source and // sink. @@ -685,29 +686,81 @@ impl NatsCommon { } } +#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)] +#[strum(serialize_all = "snake_case")] +#[allow(clippy::enum_variant_names)] +pub enum QualityOfService { + AtLeastOnce, + AtMostOnce, + ExactlyOnce, +} + +#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)] +#[strum(serialize_all = "snake_case")] +pub enum Protocol { + Tls, + Ssl, +} + #[serde_as] #[derive(Deserialize, Debug, Clone, WithOptions)] pub struct MqttCommon { - /// Protocol used for RisingWave to communicate with Kafka brokers. Could be tcp or ssl + /// Protocol used for RisingWave to communicate with the mqtt brokers. Could be tcp or ssl + #[serde_as(as = "Option")] #[serde(rename = "protocol")] - pub protocol: Option, + pub protocol: Option, + + /// Hostname of the mqtt broker #[serde(rename = "host")] pub host: String, + + /// Port of the mqtt broker, defaults to 1883 for tcp and 8883 for ssl #[serde(rename = "port")] pub port: Option, + + /// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. 
e.g /topic/# #[serde(rename = "topic")] pub topic: String, + + /// Username for the mqtt broker #[serde(rename = "username")] pub user: Option, + + /// Password for the mqtt broker #[serde(rename = "password")] pub password: Option, #[serde(rename = "client_prefix")] + + /// Prefix for the mqtt client id pub client_prefix: Option, + + /// `clean_start = true` removes all the state from queues & instructs the broker + /// to clean all the client state when client disconnects. + /// + /// When set `false`, broker will hold the client state and performs pending + /// operations on the client when reconnection with same `client_id` + /// happens. Local queue state is also held to retransmit packets after reconnection. + #[serde(rename = "clean_start")] + #[serde(default, deserialize_with = "deserialize_bool_from_string")] + pub clean_start: bool, + + /// The maximum number of inflight messages. Defaults to 100 + #[serde(rename = "inflight_messages")] + #[serde_as(as = "Option")] + pub inflight_messages: Option, #[serde(rename = "tls.ca")] + + /// Path to CA certificate file for verifying the broker's key. pub ca: Option, #[serde(rename = "tls.client_cert")] + + /// Path to client's certificate file (PEM). Required for client authentication. + /// Can be a file path under fs:// or a string with the certificate content. pub client_cert: Option, #[serde(rename = "tls.client_key")] + + /// Path to client's private key file (PEM). Required for client authentication. + /// Can be a file path under fs:// or a string with the private key content. 
pub client_key: Option, } @@ -719,7 +772,7 @@ impl MqttCommon { let ssl = self .protocol .as_ref() - .map(|p| p == "ssl") + .map(|p| p == &Protocol::Ssl) .unwrap_or_default(); let client_id = format!( @@ -737,6 +790,9 @@ impl MqttCommon { let mut options = rumqttc::v5::MqttOptions::new(client_id, &self.host, port); options.set_keep_alive(std::time::Duration::from_secs(10)); + + options.set_clean_start(self.clean_start); + if ssl { let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); if let Some(ca) = &self.ca { @@ -778,7 +834,10 @@ impl MqttCommon { options.set_credentials(user, self.password.as_deref().unwrap_or_default()); } - Ok(rumqttc::v5::AsyncClient::new(options, 100)) + Ok(rumqttc::v5::AsyncClient::new( + options, + self.inflight_messages.unwrap_or(100), + )) } } diff --git a/src/connector/src/sink/mqtt.rs b/src/connector/src/sink/mqtt.rs index 6deb07b5ea2af..ab001cee3eb78 100644 --- a/src/connector/src/sink/mqtt.rs +++ b/src/connector/src/sink/mqtt.rs @@ -22,24 +22,22 @@ use risingwave_common::catalog::Schema; use rumqttc::v5::mqttbytes::QoS; use rumqttc::v5::ConnectionError; use serde_derive::Deserialize; -use serde_with::serde_as; +use serde_with::{serde_as, DisplayFromStr}; use thiserror_ext::AsReport; -use tokio_retry::strategy::{jitter, ExponentialBackoff}; -use tokio_retry::Retry; use with_options::WithOptions; -use super::encoder::{DateHandlingMode, TimeHandlingMode, TimestamptzHandlingMode}; -use super::utils::chunk_to_json; +use super::catalog::SinkFormatDesc; +use super::formatter::SinkFormatterImpl; +use super::writer::FormattedSink; use super::{DummySinkCommitCoordinator, SinkWriterParam}; -use crate::common::MqttCommon; +use crate::common::{MqttCommon, QualityOfService}; use crate::sink::catalog::desc::SinkDesc; -use crate::sink::encoder::{JsonEncoder, TimestampHandlingMode}; use crate::sink::log_store::DeliveryFutureManagerAddFuture; use crate::sink::writer::{ AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter, 
AsyncTruncateSinkWriterExt, }; use crate::sink::{Result, Sink, SinkError, SinkParam, SINK_TYPE_APPEND_ONLY}; -use crate::{deserialize_bool_from_string, deserialize_u32_from_string}; +use crate::{deserialize_bool_from_string, dispatch_sink_formatter_impl}; pub const MQTT_SINK: &str = "mqtt"; @@ -49,10 +47,12 @@ pub struct MqttConfig { #[serde(flatten)] pub common: MqttCommon, - // 0 - AtLeastOnce, 1 - AtMostOnce, 2 - ExactlyOnce - #[serde(default, deserialize_with = "deserialize_u32_from_string")] - pub qos: u32, + /// The quality of service to use when publishing messages. Defaults to at_most_once. + /// Could be at_most_once, at_least_once or exactly_once + #[serde_as(as = "Option")] + pub qos: Option, + /// Whether the message should be retained by the broker #[serde(default, deserialize_with = "deserialize_bool_from_string")] pub retain: bool, @@ -64,17 +64,19 @@ pub struct MqttConfig { pub struct MqttSink { pub config: MqttConfig, schema: Schema, + pk_indices: Vec, + format_desc: SinkFormatDesc, + db_name: String, + sink_from_name: String, is_append_only: bool, } // sink write pub struct MqttSinkWriter { pub config: MqttConfig, - client: rumqttc::v5::AsyncClient, - qos: QoS, - retain: bool, + payload_writer: MqttSinkPayloadWriter, schema: Schema, - json_encoder: JsonEncoder, + formatter: SinkFormatterImpl, stopped: Arc, } @@ -102,6 +104,12 @@ impl TryFrom for MqttSink { Ok(Self { config, schema, + pk_indices: param.downstream_pk, + format_desc: param + .format_desc + .ok_or_else(|| SinkError::Config(anyhow!("missing FORMAT ... ENCODE ...")))?, + db_name: param.db_name, + sink_from_name: param.sink_from_name, is_append_only: param.sink_type.is_append_only(), }) } @@ -135,25 +143,46 @@ impl Sink for MqttSink { Ok(MqttSinkWriter::new( self.config.clone(), self.schema.clone(), + self.pk_indices.clone(), + &self.format_desc, + self.db_name.clone(), + self.sink_from_name.clone(), writer_param.executor_id, - )? + ) + .await? 
.into_log_sinker(usize::MAX)) } } impl MqttSinkWriter { - pub fn new(config: MqttConfig, schema: Schema, id: u64) -> Result { - let qos = match config.qos { - 0 => QoS::AtMostOnce, - 1 => QoS::AtLeastOnce, - 2 => QoS::ExactlyOnce, - _ => { - return Err(SinkError::Mqtt(anyhow!( - "Invalid QoS level: {}", - config.qos - ))) - } - }; + pub async fn new( + config: MqttConfig, + schema: Schema, + pk_indices: Vec, + format_desc: &SinkFormatDesc, + db_name: String, + sink_from_name: String, + id: u64, + ) -> Result { + let formatter = SinkFormatterImpl::new( + format_desc, + schema.clone(), + pk_indices.clone(), + db_name, + sink_from_name, + &config.common.topic, + ) + .await?; + + let qos = config + .qos + .as_ref() + .map(|qos| match qos { + QualityOfService::AtLeastOnce => QoS::AtMostOnce, + QualityOfService::AtMostOnce => QoS::AtLeastOnce, + QualityOfService::ExactlyOnce => QoS::ExactlyOnce, + }) + .unwrap_or(QoS::AtMostOnce); let (client, mut eventloop) = config .common @@ -175,16 +204,13 @@ impl MqttSinkWriter { if let ConnectionError::MqttState(rumqttc::v5::StateError::Io(err)) = err { if err.kind() != std::io::ErrorKind::ConnectionAborted { tracing::error!( - "[Sink] Failed to poll mqtt eventloop: {}", + "Failed to poll mqtt eventloop: {}", err.as_report() ); std::thread::sleep(std::time::Duration::from_secs(1)); } } else { - tracing::error!( - "[Sink] Failed to poll mqtt eventloop: {}", - err.as_report() - ); + tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report()); std::thread::sleep(std::time::Duration::from_secs(1)); } } @@ -192,48 +218,21 @@ impl MqttSinkWriter { } }); - Ok::<_, SinkError>(Self { - config: config.clone(), + let payload_writer = MqttSinkPayloadWriter { client, + config: config.clone(), qos, retain: config.retain, + }; + + Ok::<_, SinkError>(Self { + config: config.clone(), + payload_writer, schema: schema.clone(), stopped, - json_encoder: JsonEncoder::new( - schema, - None, - DateHandlingMode::FromCe, - 
TimestampHandlingMode::Milli, - TimestamptzHandlingMode::UtcWithoutSuffix, - TimeHandlingMode::Milli, - ), + formatter, }) } - - async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> { - Retry::spawn( - ExponentialBackoff::from_millis(100).map(jitter).take(3), - || async { - let data = chunk_to_json(chunk.clone(), &self.json_encoder).unwrap(); - for item in data { - self.client - .publish( - &self.config.common.topic, - self.qos, - self.retain, - item.into_bytes(), - ) - .await - .context("mqtt sink error") - .map_err(SinkError::Mqtt)?; - } - Ok::<_, SinkError>(()) - }, - ) - .await - .context("mqtts sink error") - .map_err(SinkError::Mqtt) - } } impl AsyncTruncateSinkWriter for MqttSinkWriter { @@ -242,7 +241,9 @@ impl AsyncTruncateSinkWriter for MqttSinkWriter { chunk: StreamChunk, _add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>, ) -> Result<()> { - self.append_only(chunk).await + dispatch_sink_formatter_impl!(&self.formatter, formatter, { + self.payload_writer.write_chunk(chunk, formatter).await + }) } } @@ -252,3 +253,28 @@ impl Drop for MqttSinkWriter { .store(true, std::sync::atomic::Ordering::Relaxed); } } + +struct MqttSinkPayloadWriter { + // connection to mqtt, one per executor + client: rumqttc::v5::AsyncClient, + config: MqttConfig, + qos: QoS, + retain: bool, +} + +impl FormattedSink for MqttSinkPayloadWriter { + type K = Vec; + type V = Vec; + + async fn write_one(&mut self, _k: Option, v: Option) -> Result<()> { + match v { + Some(v) => self + .client + .publish(&self.config.common.topic, self.qos, self.retain, v) + .await + .context("mqtt sink error") + .map_err(SinkError::Mqtt), + None => Ok(()), + } + } +} diff --git a/src/connector/src/source/mqtt/mod.rs b/src/connector/src/source/mqtt/mod.rs index 0aeef04e58ee7..033e8585f57f8 100644 --- a/src/connector/src/source/mqtt/mod.rs +++ b/src/connector/src/source/mqtt/mod.rs @@ -18,25 +18,27 @@ pub mod split; use std::collections::HashMap; -use serde::Deserialize; +use 
serde_derive::Deserialize; +use serde_with::{serde_as, DisplayFromStr}; use with_options::WithOptions; -use crate::common::MqttCommon; -use crate::deserialize_u32_from_string; +use crate::common::{MqttCommon, QualityOfService}; use crate::source::mqtt::enumerator::MqttSplitEnumerator; use crate::source::mqtt::source::{MqttSplit, MqttSplitReader}; use crate::source::SourceProperties; pub const MQTT_CONNECTOR: &str = "mqtt"; +#[serde_as] #[derive(Clone, Debug, Deserialize, WithOptions)] pub struct MqttProperties { #[serde(flatten)] pub common: MqttCommon, - // 0 - AtLeastOnce, 1 - AtMostOnce, 2 - ExactlyOnce - #[serde(default, deserialize_with = "deserialize_u32_from_string")] - pub qos: u32, + /// The quality of service to use when publishing messages. Defaults to at_most_once. + /// Could be at_most_once, at_least_once or exactly_once + #[serde_as(as = "Option")] + pub qos: Option, #[serde(flatten)] pub unknown_fields: HashMap, diff --git a/src/connector/src/source/mqtt/source/reader.rs b/src/connector/src/source/mqtt/source/reader.rs index af84373e07251..4182fe4876014 100644 --- a/src/connector/src/source/mqtt/source/reader.rs +++ b/src/connector/src/source/mqtt/source/reader.rs @@ -14,7 +14,6 @@ use async_trait::async_trait; use futures_async_stream::try_stream; -use risingwave_common::bail; use rumqttc::v5::mqttbytes::v5::Filter; use rumqttc::v5::mqttbytes::QoS; use rumqttc::v5::{ConnectionError, Event, Incoming}; @@ -22,6 +21,7 @@ use thiserror_ext::AsReport; use super::message::MqttMessage; use super::MqttSplit; +use crate::common::QualityOfService; use crate::error::ConnectorResult as Result; use crate::parser::ParserConfig; use crate::source::common::{into_chunk_stream, CommonSplitReader}; @@ -56,12 +56,15 @@ impl SplitReader for MqttSplitReader { .common .build_client(source_ctx.actor_id, source_ctx.fragment_id)?; - let qos = match properties.qos { - 0 => QoS::AtMostOnce, - 1 => QoS::AtLeastOnce, - 2 => QoS::ExactlyOnce, - _ => bail!("Invalid QoS level: 
{}", properties.qos), - }; + let qos = properties + .qos + .as_ref() + .map(|qos| match qos { + QualityOfService::AtLeastOnce => QoS::AtMostOnce, + QualityOfService::AtMostOnce => QoS::AtLeastOnce, + QualityOfService::ExactlyOnce => QoS::ExactlyOnce, + }) + .unwrap_or(QoS::AtMostOnce); client .subscribe_many( @@ -108,7 +111,7 @@ impl CommonSplitReader for MqttSplitReader { if let ConnectionError::Timeout(_) = e { continue; } - tracing::error!("[Reader] Failed to poll mqtt eventloop: {}", e.as_report()); + tracing::error!("Failed to poll mqtt eventloop: {}", e.as_report()); client .subscribe_many( splits diff --git a/src/connector/src/with_options.rs b/src/connector/src/with_options.rs index a113c9026fd0f..c233fb6e3a33b 100644 --- a/src/connector/src/with_options.rs +++ b/src/connector/src/with_options.rs @@ -52,6 +52,8 @@ impl WithOptions for i32 {} impl WithOptions for i64 {} impl WithOptions for f64 {} impl WithOptions for std::time::Duration {} +impl WithOptions for crate::common::QualityOfService {} +impl WithOptions for crate::common::Protocol {} impl WithOptions for crate::sink::kafka::CompressionCodec {} impl WithOptions for nexmark::config::RateShape {} impl WithOptions for nexmark::event::EventType {} diff --git a/src/connector/with_options_sink.yaml b/src/connector/with_options_sink.yaml index 96aba3f1d641b..e65098727c20f 100644 --- a/src/connector/with_options_sink.yaml +++ b/src/connector/with_options_sink.yaml @@ -353,42 +353,60 @@ KinesisSinkConfig: MqttConfig: fields: - name: protocol - field_type: String - comments: Protocol used for RisingWave to communicate with Kafka brokers. Could be tcp or ssl + field_type: Protocol + comments: Protocol used for RisingWave to communicate with the mqtt brokers. 
Could be tcp or ssl required: false - name: host field_type: String + comments: Hostname of the mqtt broker required: true - name: port field_type: i32 + comments: Port of the mqtt broker, defaults to 1883 for tcp and 8883 for ssl required: false - name: topic field_type: String + comments: The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/# required: true - name: username field_type: String + comments: Username for the mqtt broker required: false - name: password field_type: String + comments: Password for the mqtt broker required: false - name: client_prefix field_type: String + comments: Prefix for the mqtt client id + required: false + - name: clean_start + field_type: bool + comments: '`clean_start = true` removes all the state from queues & instructs the broker to clean all the client state when client disconnects. When set `false`, broker will hold the client state and performs pending operations on the client when reconnection with same `client_id` happens. Local queue state is also held to retransmit packets after reconnection.' + required: true + - name: inflight_messages + field_type: usize + comments: The maximum number of inflight messages. Defaults to 100 required: false - name: tls.ca field_type: String + comments: Path to CA certificate file for verifying the broker's key. required: false - name: tls.client_cert field_type: String + comments: Path to client's certificate file (PEM). Required for client authentication. Can be a file path under fs:// or a string with the certificate content. required: false - name: tls.client_key field_type: String + comments: Path to client's private key file (PEM). Required for client authentication. Can be a file path under fs:// or a string with the private key content. required: false - name: qos - field_type: u32 + field_type: QualityOfService + comments: The quality of service to use when publishing messages. Defaults to at_most_once. 
Could be at_most_once, at_least_once or exactly_once required: false - default: Default::default - name: retain field_type: bool + comments: Whether the message should be retained by the broker required: false default: Default::default - name: r#type diff --git a/src/connector/with_options_source.yaml b/src/connector/with_options_source.yaml index 74c8098cc4dcd..29b9f9355fac7 100644 --- a/src/connector/with_options_source.yaml +++ b/src/connector/with_options_source.yaml @@ -258,40 +258,57 @@ KinesisProperties: MqttProperties: fields: - name: protocol - field_type: String - comments: Protocol used for RisingWave to communicate with Kafka brokers. Could be tcp or ssl + field_type: Protocol + comments: Protocol used for RisingWave to communicate with the mqtt brokers. Could be tcp or ssl required: false - name: host field_type: String + comments: Hostname of the mqtt broker required: true - name: port field_type: i32 + comments: Port of the mqtt broker, defaults to 1883 for tcp and 8883 for ssl required: false - name: topic field_type: String + comments: The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/# required: true - name: username field_type: String + comments: Username for the mqtt broker required: false - name: password field_type: String + comments: Password for the mqtt broker required: false - name: client_prefix field_type: String + comments: Prefix for the mqtt client id + required: false + - name: clean_start + field_type: bool + comments: '`clean_start = true` removes all the state from queues & instructs the broker to clean all the client state when client disconnects. When set `false`, broker will hold the client state and performs pending operations on the client when reconnection with same `client_id` happens. Local queue state is also held to retransmit packets after reconnection.' + required: true + - name: inflight_messages + field_type: usize + comments: The maximum number of inflight messages. 
Defaults to 100 required: false - name: tls.ca field_type: String + comments: Path to CA certificate file for verifying the broker's key. required: false - name: tls.client_cert field_type: String + comments: Path to client's certificate file (PEM). Required for client authentication. Can be a file path under fs:// or a string with the certificate content. required: false - name: tls.client_key field_type: String + comments: Path to client's private key file (PEM). Required for client authentication. Can be a file path under fs:// or a string with the private key content. required: false - name: qos - field_type: u32 + field_type: QualityOfService + comments: The quality of service to use when publishing messages. Defaults to at_most_once. Could be at_most_once, at_least_once or exactly_once required: false - default: Default::default NatsProperties: fields: - name: server_url diff --git a/src/frontend/src/handler/create_sink.rs b/src/frontend/src/handler/create_sink.rs index 9de145dc47801..d743fb2d3f235 100644 --- a/src/frontend/src/handler/create_sink.rs +++ b/src/frontend/src/handler/create_sink.rs @@ -759,6 +759,7 @@ static CONNECTORS_COMPATIBLE_FORMATS: LazyLock vec![Encode::Json], Format::Debezium => vec![Encode::Json], ), + MqttSink::SINK_NAME => hashmap!( + Format::Plain => vec![Encode::Json,Encode::Bytes], + Format::Upsert => vec![Encode::Json,Encode::Bytes], + Format::Debezium => vec![Encode::Json,Encode::Bytes], + ), PulsarSink::SINK_NAME => hashmap!( Format::Plain => vec![Encode::Json], Format::Upsert => vec![Encode::Json], From 2d226d099d9f56517c8872774abd3cfe63be81cf Mon Sep 17 00:00:00 2001 From: Gio Gutierrez Date: Tue, 5 Mar 2024 10:15:39 -0500 Subject: [PATCH 05/12] fix: QoS conversion --- src/connector/src/sink/mqtt.rs | 4 ++-- src/connector/src/source/mqtt/enumerator/mod.rs | 4 ++-- src/connector/src/source/mqtt/source/reader.rs | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/connector/src/sink/mqtt.rs 
b/src/connector/src/sink/mqtt.rs index ab001cee3eb78..6442789b12617 100644 --- a/src/connector/src/sink/mqtt.rs +++ b/src/connector/src/sink/mqtt.rs @@ -178,8 +178,8 @@ impl MqttSinkWriter { .qos .as_ref() .map(|qos| match qos { - QualityOfService::AtLeastOnce => QoS::AtMostOnce, - QualityOfService::AtMostOnce => QoS::AtLeastOnce, + QualityOfService::AtMostOnce => QoS::AtMostOnce, + QualityOfService::AtLeastOnce => QoS::AtLeastOnce, QualityOfService::ExactlyOnce => QoS::ExactlyOnce, }) .unwrap_or(QoS::AtMostOnce); diff --git a/src/connector/src/source/mqtt/enumerator/mod.rs b/src/connector/src/source/mqtt/enumerator/mod.rs index 1a88603cedde2..5cfd952ab0121 100644 --- a/src/connector/src/source/mqtt/enumerator/mod.rs +++ b/src/connector/src/source/mqtt/enumerator/mod.rs @@ -54,7 +54,7 @@ impl SplitEnumerator for MqttSplitEnumerator { } client - .subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtLeastOnce) + .subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtMostOnce) .await?; let cloned_client = client.clone(); @@ -98,7 +98,7 @@ impl SplitEnumerator for MqttSplitEnumerator { ); connected_clone.store(false, std::sync::atomic::Ordering::Relaxed); cloned_client - .subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtLeastOnce) + .subscribe(topic.clone(), rumqttc::v5::mqttbytes::QoS::AtMostOnce) .await .unwrap(); } diff --git a/src/connector/src/source/mqtt/source/reader.rs b/src/connector/src/source/mqtt/source/reader.rs index 4182fe4876014..41e9036d11ff2 100644 --- a/src/connector/src/source/mqtt/source/reader.rs +++ b/src/connector/src/source/mqtt/source/reader.rs @@ -60,8 +60,8 @@ impl SplitReader for MqttSplitReader { .qos .as_ref() .map(|qos| match qos { - QualityOfService::AtLeastOnce => QoS::AtMostOnce, - QualityOfService::AtMostOnce => QoS::AtLeastOnce, + QualityOfService::AtMostOnce => QoS::AtMostOnce, + QualityOfService::AtLeastOnce => QoS::AtLeastOnce, QualityOfService::ExactlyOnce => QoS::ExactlyOnce, }) .unwrap_or(QoS::AtMostOnce); From 
8fead890b3eab15e47ed5173f87c673bf2fff75c Mon Sep 17 00:00:00 2001 From: Gio Gutierrez Date: Tue, 5 Mar 2024 12:25:20 -0500 Subject: [PATCH 06/12] feat: Address PR comments --- src/connector/src/common.rs | 22 +++++++--------------- src/connector/with_options_sink.yaml | 9 +++++---- src/connector/with_options_source.yaml | 9 +++++---- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/src/connector/src/common.rs b/src/connector/src/common.rs index 0aaeb3f06bf42..79be90a4d4ec3 100644 --- a/src/connector/src/common.rs +++ b/src/connector/src/common.rs @@ -698,28 +698,24 @@ pub enum QualityOfService { #[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)] #[strum(serialize_all = "snake_case")] pub enum Protocol { - Tls, + Tcp, Ssl, } #[serde_as] #[derive(Deserialize, Debug, Clone, WithOptions)] pub struct MqttCommon { - /// Protocol used for RisingWave to communicate with the mqtt brokers. Could be tcp or ssl + /// Protocol used for RisingWave to communicate with the mqtt brokers. Could be `tcp` or `ssl`, defaults to `tcp` #[serde_as(as = "Option")] - #[serde(rename = "protocol")] pub protocol: Option, /// Hostname of the mqtt broker - #[serde(rename = "host")] pub host: String, /// Port of the mqtt broker, defaults to 1883 for tcp and 8883 for ssl - #[serde(rename = "port")] pub port: Option, /// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/# - #[serde(rename = "topic")] pub topic: String, /// Username for the mqtt broker @@ -727,11 +723,10 @@ pub struct MqttCommon { pub user: Option, /// Password for the mqtt broker - #[serde(rename = "password")] pub password: Option, - #[serde(rename = "client_prefix")] - /// Prefix for the mqtt client id + /// Prefix for the mqtt client id. + /// The client id will be generated as `client_prefix`_`id`_`timestamp`. 
Defaults to risingwave pub client_prefix: Option, /// `clean_start = true` removes all the state from queues & instructs the broker @@ -740,27 +735,24 @@ pub struct MqttCommon { /// When set `false`, broker will hold the client state and performs pending /// operations on the client when reconnection with same `client_id` /// happens. Local queue state is also held to retransmit packets after reconnection. - #[serde(rename = "clean_start")] #[serde(default, deserialize_with = "deserialize_bool_from_string")] pub clean_start: bool, /// The maximum number of inflight messages. Defaults to 100 - #[serde(rename = "inflight_messages")] #[serde_as(as = "Option")] pub inflight_messages: Option, - #[serde(rename = "tls.ca")] /// Path to CA certificate file for verifying the broker's key. + #[serde(rename = "tls.client_key")] pub ca: Option, - #[serde(rename = "tls.client_cert")] - /// Path to client's certificate file (PEM). Required for client authentication. /// Can be a file path under fs:// or a string with the certificate content. + #[serde(rename = "tls.client_cert")] pub client_cert: Option, - #[serde(rename = "tls.client_key")] /// Path to client's private key file (PEM). Required for client authentication. /// Can be a file path under fs:// or a string with the private key content. + #[serde(rename = "tls.client_key")] pub client_key: Option, } diff --git a/src/connector/with_options_sink.yaml b/src/connector/with_options_sink.yaml index e65098727c20f..a119224a8d580 100644 --- a/src/connector/with_options_sink.yaml +++ b/src/connector/with_options_sink.yaml @@ -354,7 +354,7 @@ MqttConfig: fields: - name: protocol field_type: Protocol - comments: Protocol used for RisingWave to communicate with the mqtt brokers. Could be tcp or ssl + comments: Protocol used for RisingWave to communicate with the mqtt brokers. 
Could be `tcp` or `ssl`, defaults to `tcp` required: false - name: host field_type: String @@ -378,17 +378,18 @@ MqttConfig: required: false - name: client_prefix field_type: String - comments: Prefix for the mqtt client id + comments: Prefix for the mqtt client id. The client id will be generated as `client_prefix`_`id`_`timestamp`. Defaults to risingwave required: false - name: clean_start field_type: bool comments: '`clean_start = true` removes all the state from queues & instructs the broker to clean all the client state when client disconnects. When set `false`, broker will hold the client state and performs pending operations on the client when reconnection with same `client_id` happens. Local queue state is also held to retransmit packets after reconnection.' - required: true + required: false + default: Default::default - name: inflight_messages field_type: usize comments: The maximum number of inflight messages. Defaults to 100 required: false - - name: tls.ca + - name: tls.client_key field_type: String comments: Path to CA certificate file for verifying the broker's key. required: false diff --git a/src/connector/with_options_source.yaml b/src/connector/with_options_source.yaml index 29b9f9355fac7..2ae218d7e2a39 100644 --- a/src/connector/with_options_source.yaml +++ b/src/connector/with_options_source.yaml @@ -259,7 +259,7 @@ MqttProperties: fields: - name: protocol field_type: Protocol - comments: Protocol used for RisingWave to communicate with the mqtt brokers. Could be tcp or ssl + comments: Protocol used for RisingWave to communicate with the mqtt brokers. Could be `tcp` or `ssl`, defaults to `tcp` required: false - name: host field_type: String @@ -283,17 +283,18 @@ MqttProperties: required: false - name: client_prefix field_type: String - comments: Prefix for the mqtt client id + comments: Prefix for the mqtt client id. The client id will be generated as `client_prefix`_`id`_`timestamp`. 
Defaults to risingwave required: false - name: clean_start field_type: bool comments: '`clean_start = true` removes all the state from queues & instructs the broker to clean all the client state when client disconnects. When set `false`, broker will hold the client state and performs pending operations on the client when reconnection with same `client_id` happens. Local queue state is also held to retransmit packets after reconnection.' - required: true + required: false + default: Default::default - name: inflight_messages field_type: usize comments: The maximum number of inflight messages. Defaults to 100 required: false - - name: tls.ca + - name: tls.client_key field_type: String comments: Path to CA certificate file for verifying the broker's key. required: false From 4ebd97ea11cb21d85a0fd0c1af9ca1c1d5474445 Mon Sep 17 00:00:00 2001 From: Gio Gutierrez Date: Tue, 5 Mar 2024 15:40:30 -0500 Subject: [PATCH 07/12] feat: Address PR comments --- Cargo.lock | 1 + integration_tests/mqtt/create_source.sql | 4 +- src/connector/Cargo.toml | 2 +- src/connector/src/common.rs | 132 ++++++++++-------- src/connector/src/error.rs | 1 + src/connector/src/sink/mqtt.rs | 48 +++---- .../src/source/mqtt/enumerator/mod.rs | 6 +- .../src/source/mqtt/source/reader.rs | 13 +- src/connector/src/with_options.rs | 1 - src/connector/with_options_sink.yaml | 20 +-- src/connector/with_options_source.yaml | 16 +-- 11 files changed, 109 insertions(+), 135 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0414976469cc1..99627dd7e9401 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10349,6 +10349,7 @@ dependencies = [ "thiserror", "tokio", "tokio-rustls", + "url", ] [[package]] diff --git a/integration_tests/mqtt/create_source.sql b/integration_tests/mqtt/create_source.sql index 04be605700c76..8c63216125280 100644 --- a/integration_tests/mqtt/create_source.sql +++ b/integration_tests/mqtt/create_source.sql @@ -8,7 +8,7 @@ CREATE TABLE mqtt_source_table ) WITH ( connector='mqtt', - 
host='mqtt-server', + url='tcp://mqtt-server', topic= 'test', qos = 'at_least_once', ) FORMAT PLAIN ENCODE JSON; @@ -20,7 +20,7 @@ FROM WITH ( connector='mqtt', - host='mqtt-server', + url='tcp://mqtt-server', topic= 'test', type = 'append-only', retain = 'false', diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index e997e71201242..e18b5d2bb2c8f 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -117,7 +117,7 @@ risingwave_common = { workspace = true } risingwave_jni_core = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } -rumqttc = "0.22.0" +rumqttc = { version = "0.22.0", features = ["url"] } rust_decimal = "1" rustls-native-certs = "0.6" rustls-pemfile = "1" diff --git a/src/connector/src/common.rs b/src/connector/src/common.rs index 79be90a4d4ec3..5600097f9c7ca 100644 --- a/src/connector/src/common.rs +++ b/src/connector/src/common.rs @@ -695,29 +695,23 @@ pub enum QualityOfService { ExactlyOnce, } -#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)] -#[strum(serialize_all = "snake_case")] -pub enum Protocol { - Tcp, - Ssl, -} - #[serde_as] #[derive(Deserialize, Debug, Clone, WithOptions)] pub struct MqttCommon { - /// Protocol used for RisingWave to communicate with the mqtt brokers. Could be `tcp` or `ssl`, defaults to `tcp` - #[serde_as(as = "Option")] - pub protocol: Option, - - /// Hostname of the mqtt broker - pub host: String, - - /// Port of the mqtt broker, defaults to 1883 for tcp and 8883 for ssl - pub port: Option, + /// The url of the broker to connect to. e.g. tcp://localhost. + /// Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`, + /// `ws://` or `wss://` to denote the protocol for establishing a connection with the broker. + /// `mqtts://`, `ssl://`, `wss://` + pub url: String, /// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. 
e.g /topic/# pub topic: String, + /// The quality of service to use when publishing messages. Defaults to at_most_once. + /// Could be at_most_once, at_least_once or exactly_once + #[serde_as(as = "Option")] + pub qos: Option, + /// Username for the mqtt broker #[serde(rename = "username")] pub user: Option, @@ -759,64 +753,32 @@ pub struct MqttCommon { impl MqttCommon { pub(crate) fn build_client( &self, + actor_id: u32, id: u32, ) -> ConnectorResult<(rumqttc::v5::AsyncClient, rumqttc::v5::EventLoop)> { - let ssl = self - .protocol - .as_ref() - .map(|p| p == &Protocol::Ssl) - .unwrap_or_default(); - let client_id = format!( - "{}_{}{}", + "{}_{}_{}", self.client_prefix.as_deref().unwrap_or("risingwave"), - id, - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis() - % 100000, + actor_id, + id ); - let port = self.port.unwrap_or(if ssl { 8883 } else { 1883 }) as u16; + let mut url = url::Url::parse(&self.url)?; - let mut options = rumqttc::v5::MqttOptions::new(client_id, &self.host, port); + let ssl = match url.scheme() { + "mqtts" | "ssl" | "wss" => true, + _ => false, + }; + + url.query_pairs_mut().append_pair("client_id", &client_id); + + let mut options = rumqttc::v5::MqttOptions::try_from(url)?; options.set_keep_alive(std::time::Duration::from_secs(10)); options.set_clean_start(self.clean_start); if ssl { - let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); - if let Some(ca) = &self.ca { - let certificates = load_certs(ca)?; - for cert in certificates { - root_cert_store.add(&cert).unwrap(); - } - } else { - for cert in - rustls_native_certs::load_native_certs().expect("could not load platform certs") - { - root_cert_store - .add(&tokio_rustls::rustls::Certificate(cert.0)) - .unwrap(); - } - } - - let builder = tokio_rustls::rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_cert_store); - - let tls_config = if let (Some(client_cert), Some(client_key)) = - 
(self.client_cert.as_ref(), self.client_key.as_ref()) - { - let certs = load_certs(client_cert)?; - let key = load_private_key(client_key)?; - - builder.with_client_auth_cert(certs, key)? - } else { - builder.with_no_client_auth() - }; - + let tls_config = self.get_tls_config()?; options.set_transport(rumqttc::Transport::tls_with_config( rumqttc::TlsConfiguration::Rustls(std::sync::Arc::new(tls_config)), )); @@ -831,6 +793,52 @@ impl MqttCommon { self.inflight_messages.unwrap_or(100), )) } + + pub(crate) fn qos(&self) -> rumqttc::v5::mqttbytes::QoS { + self.qos + .as_ref() + .map(|qos| match qos { + QualityOfService::AtMostOnce => rumqttc::v5::mqttbytes::QoS::AtMostOnce, + QualityOfService::AtLeastOnce => rumqttc::v5::mqttbytes::QoS::AtLeastOnce, + QualityOfService::ExactlyOnce => rumqttc::v5::mqttbytes::QoS::ExactlyOnce, + }) + .unwrap_or(rumqttc::v5::mqttbytes::QoS::AtMostOnce) + } + + fn get_tls_config(&self) -> ConnectorResult { + let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); + if let Some(ca) = &self.ca { + let certificates = load_certs(ca)?; + for cert in certificates { + root_cert_store.add(&cert).unwrap(); + } + } else { + for cert in + rustls_native_certs::load_native_certs().expect("could not load platform certs") + { + root_cert_store + .add(&tokio_rustls::rustls::Certificate(cert.0)) + .unwrap(); + } + } + + let builder = tokio_rustls::rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store); + + let tls_config = if let (Some(client_cert), Some(client_key)) = + (self.client_cert.as_ref(), self.client_key.as_ref()) + { + let certs = load_certs(client_cert)?; + let key = load_private_key(client_key)?; + + builder.with_client_auth_cert(certs, key)? 
+ } else { + builder.with_no_client_auth() + }; + + Ok(tls_config) + } } fn load_certs(certificates: &str) -> ConnectorResult> { diff --git a/src/connector/src/error.rs b/src/connector/src/error.rs index 1317981f88919..603b2c01f5810 100644 --- a/src/connector/src/error.rs +++ b/src/connector/src/error.rs @@ -60,6 +60,7 @@ def_anyhow_newtype! { google_cloud_pubsub::client::google_cloud_auth::error::Error => "Google Cloud error", tokio_rustls::rustls::Error => "TLS error", rumqttc::v5::ClientError => "MQTT error", + rumqttc::v5::OptionError => "MQTT error", } pub type ConnectorResult = std::result::Result; diff --git a/src/connector/src/sink/mqtt.rs b/src/connector/src/sink/mqtt.rs index 6442789b12617..040e62bfe9634 100644 --- a/src/connector/src/sink/mqtt.rs +++ b/src/connector/src/sink/mqtt.rs @@ -22,7 +22,7 @@ use risingwave_common::catalog::Schema; use rumqttc::v5::mqttbytes::QoS; use rumqttc::v5::ConnectionError; use serde_derive::Deserialize; -use serde_with::{serde_as, DisplayFromStr}; +use serde_with::serde_as; use thiserror_ext::AsReport; use with_options::WithOptions; @@ -30,7 +30,7 @@ use super::catalog::SinkFormatDesc; use super::formatter::SinkFormatterImpl; use super::writer::FormattedSink; use super::{DummySinkCommitCoordinator, SinkWriterParam}; -use crate::common::{MqttCommon, QualityOfService}; +use crate::common::MqttCommon; use crate::sink::catalog::desc::SinkDesc; use crate::sink::log_store::DeliveryFutureManagerAddFuture; use crate::sink::writer::{ @@ -47,11 +47,6 @@ pub struct MqttConfig { #[serde(flatten)] pub common: MqttCommon, - /// The quality of service to use when publishing messages. Defaults to at_most_once. 
- /// Could be at_most_once, at_least_once or exactly_once - #[serde_as(as = "Option")] - pub qos: Option, - /// Whether the message should be retained by the broker #[serde(default, deserialize_with = "deserialize_bool_from_string")] pub retain: bool, @@ -132,7 +127,7 @@ impl Sink for MqttSink { ))); } - let _client = (self.config.common.build_client(0)) + let _client = (self.config.common.build_client(0, 0)) .context("validate mqtt sink error") .map_err(SinkError::Mqtt)?; @@ -174,19 +169,11 @@ impl MqttSinkWriter { ) .await?; - let qos = config - .qos - .as_ref() - .map(|qos| match qos { - QualityOfService::AtMostOnce => QoS::AtMostOnce, - QualityOfService::AtLeastOnce => QoS::AtLeastOnce, - QualityOfService::ExactlyOnce => QoS::ExactlyOnce, - }) - .unwrap_or(QoS::AtMostOnce); + let qos = config.common.qos(); let (client, mut eventloop) = config .common - .build_client(id as u32) + .build_client(0, id as u32) .map_err(|e| SinkError::Mqtt(anyhow!(e)))?; let stopped = Arc::new(AtomicBool::new(false)); @@ -196,24 +183,23 @@ impl MqttSinkWriter { while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) { match eventloop.poll().await { Ok(_) => (), - Err(err) => { - if let ConnectionError::Timeout(_) = err { + Err(err) => match err { + ConnectionError::Timeout(_) => { continue; } - - if let ConnectionError::MqttState(rumqttc::v5::StateError::Io(err)) = err { - if err.kind() != std::io::ErrorKind::ConnectionAborted { - tracing::error!( - "Failed to poll mqtt eventloop: {}", - err.as_report() - ); - std::thread::sleep(std::time::Duration::from_secs(1)); - } - } else { + ConnectionError::MqttState(rumqttc::v5::StateError::Io(err)) + | ConnectionError::Io(err) + if err.kind() == std::io::ErrorKind::ConnectionAborted + || err.kind() == std::io::ErrorKind::ConnectionReset => + { + continue; + } + err => { + println!("Err: {:?}", err); tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report()); std::thread::sleep(std::time::Duration::from_secs(1)); } - } + 
}, } } }); diff --git a/src/connector/src/source/mqtt/enumerator/mod.rs b/src/connector/src/source/mqtt/enumerator/mod.rs index 5cfd952ab0121..1013f31e07d5e 100644 --- a/src/connector/src/source/mqtt/enumerator/mod.rs +++ b/src/connector/src/source/mqtt/enumerator/mod.rs @@ -45,7 +45,7 @@ impl SplitEnumerator for MqttSplitEnumerator { properties: Self::Properties, context: SourceEnumeratorContextRef, ) -> ConnectorResult { - let (client, mut eventloop) = properties.common.build_client(context.info.source_id)?; + let (client, mut eventloop) = properties.common.build_client(context.info.source_id, 0)?; let topic = properties.common.topic.clone(); let mut topics = HashSet::new(); @@ -92,7 +92,7 @@ impl SplitEnumerator for MqttSplitEnumerator { continue; } tracing::error!( - "[Enumerator] Failed to subscribe to topic {}: {}", + "Failed to subscribe to topic {}: {}", topic, err.as_report(), ); @@ -127,7 +127,7 @@ impl SplitEnumerator for MqttSplitEnumerator { bail!("Failed to connect to mqtt broker"); } - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + tokio::time::sleep(std::time::Duration::from_millis(500)).await; } } diff --git a/src/connector/src/source/mqtt/source/reader.rs b/src/connector/src/source/mqtt/source/reader.rs index 41e9036d11ff2..75af670c879c9 100644 --- a/src/connector/src/source/mqtt/source/reader.rs +++ b/src/connector/src/source/mqtt/source/reader.rs @@ -21,7 +21,6 @@ use thiserror_ext::AsReport; use super::message::MqttMessage; use super::MqttSplit; -use crate::common::QualityOfService; use crate::error::ConnectorResult as Result; use crate::parser::ParserConfig; use crate::source::common::{into_chunk_stream, CommonSplitReader}; @@ -54,17 +53,9 @@ impl SplitReader for MqttSplitReader { ) -> Result { let (client, eventloop) = properties .common - .build_client(source_ctx.actor_id, source_ctx.fragment_id)?; + .build_client(source_ctx.source_info.actor_id, source_ctx.source_info.fragment_id)?; - let qos = properties - .qos - 
.as_ref() - .map(|qos| match qos { - QualityOfService::AtMostOnce => QoS::AtMostOnce, - QualityOfService::AtLeastOnce => QoS::AtLeastOnce, - QualityOfService::ExactlyOnce => QoS::ExactlyOnce, - }) - .unwrap_or(QoS::AtMostOnce); + let qos = properties.common.qos(); client .subscribe_many( diff --git a/src/connector/src/with_options.rs b/src/connector/src/with_options.rs index c233fb6e3a33b..1bdda0b484ce4 100644 --- a/src/connector/src/with_options.rs +++ b/src/connector/src/with_options.rs @@ -53,7 +53,6 @@ impl WithOptions for i64 {} impl WithOptions for f64 {} impl WithOptions for std::time::Duration {} impl WithOptions for crate::common::QualityOfService {} -impl WithOptions for crate::common::Protocol {} impl WithOptions for crate::sink::kafka::CompressionCodec {} impl WithOptions for nexmark::config::RateShape {} impl WithOptions for nexmark::event::EventType {} diff --git a/src/connector/with_options_sink.yaml b/src/connector/with_options_sink.yaml index a119224a8d580..78aaa950f1fb2 100644 --- a/src/connector/with_options_sink.yaml +++ b/src/connector/with_options_sink.yaml @@ -352,22 +352,18 @@ KinesisSinkConfig: alias: kinesis.assumerole.external_id MqttConfig: fields: - - name: protocol - field_type: Protocol - comments: Protocol used for RisingWave to communicate with the mqtt brokers. Could be `tcp` or `ssl`, defaults to `tcp` - required: false - - name: host + - name: url field_type: String - comments: Hostname of the mqtt broker + comments: The url of the broker to connect to. e.g. tcp://localhost. Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`, `ws://` or `wss://` to denote the protocol for establishing a connection with the broker. `mqtts://`, `ssl://`, `wss://` required: true - - name: port - field_type: i32 - comments: Port of the mqtt broker, defaults to 1883 for tcp and 8883 for ssl - required: false - name: topic field_type: String comments: The topic name to subscribe or publish to. 
When subscribing, it can be a wildcard topic. e.g /topic/# required: true + - name: qos + field_type: QualityOfService + comments: The quality of service to use when publishing messages. Defaults to at_most_once. Could be at_most_once, at_least_once or exactly_once + required: false - name: username field_type: String comments: Username for the mqtt broker @@ -401,10 +397,6 @@ MqttConfig: field_type: String comments: Path to client's private key file (PEM). Required for client authentication. Can be a file path under fs:// or a string with the private key content. required: false - - name: qos - field_type: QualityOfService - comments: The quality of service to use when publishing messages. Defaults to at_most_once. Could be at_most_once, at_least_once or exactly_once - required: false - name: retain field_type: bool comments: Whether the message should be retained by the broker diff --git a/src/connector/with_options_source.yaml b/src/connector/with_options_source.yaml index 2ae218d7e2a39..d3c8639d6748d 100644 --- a/src/connector/with_options_source.yaml +++ b/src/connector/with_options_source.yaml @@ -257,22 +257,18 @@ KinesisProperties: alias: kinesis.assumerole.external_id MqttProperties: fields: - - name: protocol - field_type: Protocol - comments: Protocol used for RisingWave to communicate with the mqtt brokers. Could be `tcp` or `ssl`, defaults to `tcp` - required: false - - name: host + - name: url field_type: String - comments: Hostname of the mqtt broker + comments: The url of the broker to connect to. e.g. tcp://localhost. Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`, `ws://` or `wss://` to denote the protocol for establishing a connection with the broker. `mqtts://`, `ssl://`, `wss://` required: true - - name: port - field_type: i32 - comments: Port of the mqtt broker, defaults to 1883 for tcp and 8883 for ssl - required: false - name: topic field_type: String comments: The topic name to subscribe or publish to. 
When subscribing, it can be a wildcard topic. e.g /topic/# required: true + - name: qos + field_type: QualityOfService + comments: The quality of service to use when publishing messages. Defaults to at_most_once. Could be at_most_once, at_least_once or exactly_once + required: false - name: username field_type: String comments: Username for the mqtt broker From 81e208ce89ccfbc7d6d9d865764a933d5971fcc9 Mon Sep 17 00:00:00 2001 From: Gio Gutierrez Date: Tue, 5 Mar 2024 16:43:35 -0500 Subject: [PATCH 08/12] chore: Remove println --- src/connector/src/sink/mqtt.rs | 2 -- src/connector/src/source/mqtt/source/reader.rs | 7 ++++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/connector/src/sink/mqtt.rs b/src/connector/src/sink/mqtt.rs index 040e62bfe9634..ab162b4a1a621 100644 --- a/src/connector/src/sink/mqtt.rs +++ b/src/connector/src/sink/mqtt.rs @@ -178,7 +178,6 @@ impl MqttSinkWriter { let stopped = Arc::new(AtomicBool::new(false)); let stopped_clone = stopped.clone(); - tokio::spawn(async move { while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) { match eventloop.poll().await { @@ -195,7 +194,6 @@ impl MqttSinkWriter { continue; } err => { - println!("Err: {:?}", err); tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report()); std::thread::sleep(std::time::Duration::from_secs(1)); } diff --git a/src/connector/src/source/mqtt/source/reader.rs b/src/connector/src/source/mqtt/source/reader.rs index 75af670c879c9..09e74763e416a 100644 --- a/src/connector/src/source/mqtt/source/reader.rs +++ b/src/connector/src/source/mqtt/source/reader.rs @@ -51,9 +51,10 @@ impl SplitReader for MqttSplitReader { source_ctx: SourceContextRef, _columns: Option>, ) -> Result { - let (client, eventloop) = properties - .common - .build_client(source_ctx.source_info.actor_id, source_ctx.source_info.fragment_id)?; + let (client, eventloop) = properties.common.build_client( + source_ctx.source_info.actor_id, + source_ctx.source_info.fragment_id, + 
)?; let qos = properties.common.qos(); From 801c6104f0cc277b1afccf51bc467adc902a7e53 Mon Sep 17 00:00:00 2001 From: Gio Gutierrez Date: Tue, 5 Mar 2024 17:43:36 -0500 Subject: [PATCH 09/12] feat: Avoid truncating executor_id --- src/connector/src/common.rs | 15 +++++++-------- src/connector/src/sink/mqtt.rs | 6 ++---- src/connector/src/source/mqtt/source/reader.rs | 2 +- src/connector/with_options_sink.yaml | 4 ++-- src/connector/with_options_source.yaml | 4 ++-- 5 files changed, 14 insertions(+), 17 deletions(-) diff --git a/src/connector/src/common.rs b/src/connector/src/common.rs index 5600097f9c7ca..fd0e125b12a03 100644 --- a/src/connector/src/common.rs +++ b/src/connector/src/common.rs @@ -700,8 +700,8 @@ pub enum QualityOfService { pub struct MqttCommon { /// The url of the broker to connect to. e.g. tcp://localhost. /// Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`, - /// `ws://` or `wss://` to denote the protocol for establishing a connection with the broker. - /// `mqtts://`, `ssl://`, `wss://` + /// to denote the protocol for establishing a connection with the broker. + /// `mqtts://`, `ssl://` will use the native certificates if no ca is specified pub url: String, /// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/# @@ -720,7 +720,7 @@ pub struct MqttCommon { pub password: Option, /// Prefix for the mqtt client id. - /// The client id will be generated as `client_prefix`_`id`_`timestamp`. Defaults to risingwave + /// The client id will be generated as `client_prefix`_`actor_id`_`(executor_id|source_id)`. 
Defaults to risingwave pub client_prefix: Option, /// `clean_start = true` removes all the state from queues & instructs the broker @@ -754,7 +754,7 @@ impl MqttCommon { pub(crate) fn build_client( &self, actor_id: u32, - id: u32, + id: u64, ) -> ConnectorResult<(rumqttc::v5::AsyncClient, rumqttc::v5::EventLoop)> { let client_id = format!( "{}_{}_{}", @@ -765,13 +765,12 @@ impl MqttCommon { let mut url = url::Url::parse(&self.url)?; - let ssl = match url.scheme() { - "mqtts" | "ssl" | "wss" => true, - _ => false, - }; + let ssl = matches!(url.scheme(), "mqtts" | "ssl"); url.query_pairs_mut().append_pair("client_id", &client_id); + tracing::debug!("connecting mqtt using url: {}", url.as_str()); + let mut options = rumqttc::v5::MqttOptions::try_from(url)?; options.set_keep_alive(std::time::Duration::from_secs(10)); diff --git a/src/connector/src/sink/mqtt.rs b/src/connector/src/sink/mqtt.rs index ab162b4a1a621..54c483c7076a0 100644 --- a/src/connector/src/sink/mqtt.rs +++ b/src/connector/src/sink/mqtt.rs @@ -173,7 +173,7 @@ impl MqttSinkWriter { let (client, mut eventloop) = config .common - .build_client(0, id as u32) + .build_client(0, id) .map_err(|e| SinkError::Mqtt(anyhow!(e)))?; let stopped = Arc::new(AtomicBool::new(false)); @@ -183,9 +183,7 @@ impl MqttSinkWriter { match eventloop.poll().await { Ok(_) => (), Err(err) => match err { - ConnectionError::Timeout(_) => { - continue; - } + ConnectionError::Timeout(_) => (), ConnectionError::MqttState(rumqttc::v5::StateError::Io(err)) | ConnectionError::Io(err) if err.kind() == std::io::ErrorKind::ConnectionAborted diff --git a/src/connector/src/source/mqtt/source/reader.rs b/src/connector/src/source/mqtt/source/reader.rs index 09e74763e416a..74e315b6cd656 100644 --- a/src/connector/src/source/mqtt/source/reader.rs +++ b/src/connector/src/source/mqtt/source/reader.rs @@ -53,7 +53,7 @@ impl SplitReader for MqttSplitReader { ) -> Result { let (client, eventloop) = properties.common.build_client( 
source_ctx.source_info.actor_id, - source_ctx.source_info.fragment_id, + source_ctx.source_info.fragment_id as u64, )?; let qos = properties.common.qos(); diff --git a/src/connector/with_options_sink.yaml b/src/connector/with_options_sink.yaml index 78aaa950f1fb2..f9d459fddfd9c 100644 --- a/src/connector/with_options_sink.yaml +++ b/src/connector/with_options_sink.yaml @@ -354,7 +354,7 @@ MqttConfig: fields: - name: url field_type: String - comments: The url of the broker to connect to. e.g. tcp://localhost. Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`, `ws://` or `wss://` to denote the protocol for establishing a connection with the broker. `mqtts://`, `ssl://`, `wss://` + comments: The url of the broker to connect to. e.g. tcp://localhost. Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`, to denote the protocol for establishing a connection with the broker. `mqtts://`, `ssl://` will use the native certificates if no ca is specified required: true - name: topic field_type: String @@ -374,7 +374,7 @@ MqttConfig: required: false - name: client_prefix field_type: String - comments: Prefix for the mqtt client id. The client id will be generated as `client_prefix`_`id`_`timestamp`. Defaults to risingwave + comments: Prefix for the mqtt client id. The client id will be generated as `client_prefix`_`actor_id`_`(executor_id|source_id)`. Defaults to risingwave required: false - name: clean_start field_type: bool diff --git a/src/connector/with_options_source.yaml b/src/connector/with_options_source.yaml index d3c8639d6748d..a02f3b2168650 100644 --- a/src/connector/with_options_source.yaml +++ b/src/connector/with_options_source.yaml @@ -259,7 +259,7 @@ MqttProperties: fields: - name: url field_type: String - comments: The url of the broker to connect to. e.g. tcp://localhost. 
Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`, `ws://` or `wss://` to denote the protocol for establishing a connection with the broker. `mqtts://`, `ssl://`, `wss://` + comments: The url of the broker to connect to. e.g. tcp://localhost. Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`, to denote the protocol for establishing a connection with the broker. `mqtts://`, `ssl://` will use the native certificates if no ca is specified required: true - name: topic field_type: String @@ -279,7 +279,7 @@ MqttProperties: required: false - name: client_prefix field_type: String - comments: Prefix for the mqtt client id. The client id will be generated as `client_prefix`_`id`_`timestamp`. Defaults to risingwave + comments: Prefix for the mqtt client id. The client id will be generated as `client_prefix`_`actor_id`_`(executor_id|source_id)`. Defaults to risingwave required: false - name: clean_start field_type: bool From 4a8dbe8938b0f451e9859c7dd8957126232564d8 Mon Sep 17 00:00:00 2001 From: Gio Gutierrez Date: Wed, 6 Mar 2024 09:33:08 -0500 Subject: [PATCH 10/12] feat: Address PR comments --- .typos.toml | 1 + integration_tests/mqtt/create_sink.sql | 28 ++++ integration_tests/mqtt/create_source.sql | 30 ---- integration_tests/mqtt/docker-compose.yml | 6 +- integration_tests/mqtt/sink_check.py | 14 ++ src/connector/src/common.rs | 165 +--------------------- src/connector/src/lib.rs | 1 + src/connector/src/mqtt_common.rs | 164 +++++++++++++++++++++ src/connector/src/sink/mqtt.rs | 4 +- src/connector/src/source/mqtt/mod.rs | 2 +- src/connector/src/with_options.rs | 2 +- src/frontend/src/handler/create_sink.rs | 6 +- 12 files changed, 227 insertions(+), 196 deletions(-) create mode 100644 integration_tests/mqtt/create_sink.sql create mode 100644 integration_tests/mqtt/sink_check.py create mode 100644 src/connector/src/mqtt_common.rs diff --git a/.typos.toml b/.typos.toml index b19d0a08c541d..052a051bd1e00 100644 --- 
a/.typos.toml +++ b/.typos.toml @@ -9,6 +9,7 @@ steam = "stream" # You played with Steam games too much. # Some weird short variable names ot = "ot" bui = "bui" +mosquitto = "mosquitto" # This is a MQTT broker. [default.extend-identifiers] diff --git a/integration_tests/mqtt/create_sink.sql b/integration_tests/mqtt/create_sink.sql new file mode 100644 index 0000000000000..69b6886943944 --- /dev/null +++ b/integration_tests/mqtt/create_sink.sql @@ -0,0 +1,28 @@ +CREATE SINK mqtt_sink +FROM + personnel +WITH +( + connector='mqtt', + url='tcp://mqtt-server', + topic= 'test', + type = 'append-only', + retain = 'true', + qos = 'at_least_once', +) FORMAT PLAIN ENCODE JSON ( + force_append_only='true', +); + +INSERT INTO + personnel +VALUES + (1, 'Alice'), + (2, 'Bob'), + (3, 'Tom'), + (4, 'Jerry'), + (5, 'Araminta'), + (6, 'Clover'), + (7, 'Posey'), + (8, 'Waverly'); + +FLUSH; \ No newline at end of file diff --git a/integration_tests/mqtt/create_source.sql b/integration_tests/mqtt/create_source.sql index 8c63216125280..068d7e0a6cb46 100644 --- a/integration_tests/mqtt/create_source.sql +++ b/integration_tests/mqtt/create_source.sql @@ -12,33 +12,3 @@ WITH ( topic= 'test', qos = 'at_least_once', ) FORMAT PLAIN ENCODE JSON; - - -CREATE SINK mqtt_sink -FROM - personnel -WITH - ( - connector='mqtt', - url='tcp://mqtt-server', - topic= 'test', - type = 'append-only', - retain = 'false', - qos = 'at_least_once', - ) FORMAT PLAIN ENCODE JSON ( - force_append_only='true', - ); - -INSERT INTO - personnel -VALUES - (1, 'Alice'), - (2, 'Bob'), - (3, 'Tom'), - (4, 'Jerry'), - (5, 'Araminta'), - (6, 'Clover'), - (7, 'Posey'), - (8, 'Waverly'); - -FLUSH; diff --git a/integration_tests/mqtt/docker-compose.yml b/integration_tests/mqtt/docker-compose.yml index 87969f8ad9044..9db7e7c04f8fc 100644 --- a/integration_tests/mqtt/docker-compose.yml +++ b/integration_tests/mqtt/docker-compose.yml @@ -6,7 +6,11 @@ services: file: ../../docker/docker-compose.yml service: risingwave-standalone 
mqtt-server: - image: emqx/emqx:5.2.1 + image: eclipse-mosquitto + command: + - sh + - -c + - echo "running command"; printf 'allow_anonymous true\nlistener 1883 0.0.0.0' > /mosquitto/config/mosquitto.conf; echo "starting service..."; cat /mosquitto/config/mosquitto.conf;/docker-entrypoint.sh;/usr/sbin/mosquitto -c /mosquitto/config/mosquitto.conf ports: - 1883:1883 etcd-0: diff --git a/integration_tests/mqtt/sink_check.py b/integration_tests/mqtt/sink_check.py new file mode 100644 index 0000000000000..cb74a12e9fe29 --- /dev/null +++ b/integration_tests/mqtt/sink_check.py @@ -0,0 +1,14 @@ +import sys +import subprocess + + +output = subprocess.Popen(["docker", "compose", "exec", "mqtt-server", "mosquitto_sub", "-h", "localhost", "-t", "test", "-p", "1883", "-C", "1", "-W", "120"], + stdout=subprocess.PIPE) +rows = subprocess.check_output(["wc", "-l"], stdin=output.stdout) +output.stdout.close() +output.wait() +rows = int(rows.decode('utf8').strip()) +print(f"{rows} rows in 'test'") +if rows < 1: + print(f"Data check failed for case 'test'") + sys.exit(1) diff --git a/src/connector/src/common.rs b/src/connector/src/common.rs index fd0e125b12a03..4b198bf6caf1b 100644 --- a/src/connector/src/common.rs +++ b/src/connector/src/common.rs @@ -28,17 +28,16 @@ use risingwave_common::bail; use serde_derive::Deserialize; use serde_with::json::JsonString; use serde_with::{serde_as, DisplayFromStr}; -use strum_macros::{Display, EnumString}; use tempfile::NamedTempFile; use time::OffsetDateTime; use url::Url; use with_options::WithOptions; use crate::aws_utils::load_file_descriptor_from_s3; +use crate::deserialize_duration_from_string; use crate::error::ConnectorResult; use crate::sink::SinkError; use crate::source::nats::source::NatsOffset; -use crate::{deserialize_bool_from_string, deserialize_duration_from_string}; // The file describes the common abstractions for each connector and can be used in both source and // sink. 
@@ -686,161 +685,9 @@ impl NatsCommon { } } -#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)] -#[strum(serialize_all = "snake_case")] -#[allow(clippy::enum_variant_names)] -pub enum QualityOfService { - AtLeastOnce, - AtMostOnce, - ExactlyOnce, -} - -#[serde_as] -#[derive(Deserialize, Debug, Clone, WithOptions)] -pub struct MqttCommon { - /// The url of the broker to connect to. e.g. tcp://localhost. - /// Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`, - /// to denote the protocol for establishing a connection with the broker. - /// `mqtts://`, `ssl://` will use the native certificates if no ca is specified - pub url: String, - - /// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/# - pub topic: String, - - /// The quality of service to use when publishing messages. Defaults to at_most_once. - /// Could be at_most_once, at_least_once or exactly_once - #[serde_as(as = "Option")] - pub qos: Option, - - /// Username for the mqtt broker - #[serde(rename = "username")] - pub user: Option, - - /// Password for the mqtt broker - pub password: Option, - - /// Prefix for the mqtt client id. - /// The client id will be generated as `client_prefix`_`actor_id`_`(executor_id|source_id)`. Defaults to risingwave - pub client_prefix: Option, - - /// `clean_start = true` removes all the state from queues & instructs the broker - /// to clean all the client state when client disconnects. - /// - /// When set `false`, broker will hold the client state and performs pending - /// operations on the client when reconnection with same `client_id` - /// happens. Local queue state is also held to retransmit packets after reconnection. - #[serde(default, deserialize_with = "deserialize_bool_from_string")] - pub clean_start: bool, - - /// The maximum number of inflight messages. 
Defaults to 100 - #[serde_as(as = "Option")] - pub inflight_messages: Option, - - /// Path to CA certificate file for verifying the broker's key. - #[serde(rename = "tls.client_key")] - pub ca: Option, - /// Path to client's certificate file (PEM). Required for client authentication. - /// Can be a file path under fs:// or a string with the certificate content. - #[serde(rename = "tls.client_cert")] - pub client_cert: Option, - - /// Path to client's private key file (PEM). Required for client authentication. - /// Can be a file path under fs:// or a string with the private key content. - #[serde(rename = "tls.client_key")] - pub client_key: Option, -} - -impl MqttCommon { - pub(crate) fn build_client( - &self, - actor_id: u32, - id: u64, - ) -> ConnectorResult<(rumqttc::v5::AsyncClient, rumqttc::v5::EventLoop)> { - let client_id = format!( - "{}_{}_{}", - self.client_prefix.as_deref().unwrap_or("risingwave"), - actor_id, - id - ); - - let mut url = url::Url::parse(&self.url)?; - - let ssl = matches!(url.scheme(), "mqtts" | "ssl"); - - url.query_pairs_mut().append_pair("client_id", &client_id); - - tracing::debug!("connecting mqtt using url: {}", url.as_str()); - - let mut options = rumqttc::v5::MqttOptions::try_from(url)?; - options.set_keep_alive(std::time::Duration::from_secs(10)); - - options.set_clean_start(self.clean_start); - - if ssl { - let tls_config = self.get_tls_config()?; - options.set_transport(rumqttc::Transport::tls_with_config( - rumqttc::TlsConfiguration::Rustls(std::sync::Arc::new(tls_config)), - )); - } - - if let Some(user) = &self.user { - options.set_credentials(user, self.password.as_deref().unwrap_or_default()); - } - - Ok(rumqttc::v5::AsyncClient::new( - options, - self.inflight_messages.unwrap_or(100), - )) - } - - pub(crate) fn qos(&self) -> rumqttc::v5::mqttbytes::QoS { - self.qos - .as_ref() - .map(|qos| match qos { - QualityOfService::AtMostOnce => rumqttc::v5::mqttbytes::QoS::AtMostOnce, - QualityOfService::AtLeastOnce => 
rumqttc::v5::mqttbytes::QoS::AtLeastOnce, - QualityOfService::ExactlyOnce => rumqttc::v5::mqttbytes::QoS::ExactlyOnce, - }) - .unwrap_or(rumqttc::v5::mqttbytes::QoS::AtMostOnce) - } - - fn get_tls_config(&self) -> ConnectorResult { - let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); - if let Some(ca) = &self.ca { - let certificates = load_certs(ca)?; - for cert in certificates { - root_cert_store.add(&cert).unwrap(); - } - } else { - for cert in - rustls_native_certs::load_native_certs().expect("could not load platform certs") - { - root_cert_store - .add(&tokio_rustls::rustls::Certificate(cert.0)) - .unwrap(); - } - } - - let builder = tokio_rustls::rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_cert_store); - - let tls_config = if let (Some(client_cert), Some(client_key)) = - (self.client_cert.as_ref(), self.client_key.as_ref()) - { - let certs = load_certs(client_cert)?; - let key = load_private_key(client_key)?; - - builder.with_client_auth_cert(certs, key)? - } else { - builder.with_no_client_auth() - }; - - Ok(tls_config) - } -} - -fn load_certs(certificates: &str) -> ConnectorResult> { +pub(crate) fn load_certs( + certificates: &str, +) -> ConnectorResult> { let cert_bytes = if let Some(path) = certificates.strip_prefix("fs://") { std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())? } else { @@ -855,7 +702,9 @@ fn load_certs(certificates: &str) -> ConnectorResult ConnectorResult { +pub(crate) fn load_private_key( + certificate: &str, +) -> ConnectorResult { let cert_bytes = if let Some(path) = certificate.strip_prefix("fs://") { std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())? 
} else { diff --git a/src/connector/src/lib.rs b/src/connector/src/lib.rs index 4a437755c5185..bbbf8fa5c3b69 100644 --- a/src/connector/src/lib.rs +++ b/src/connector/src/lib.rs @@ -52,6 +52,7 @@ pub mod sink; pub mod source; pub mod common; +pub mod mqtt_common; pub use paste::paste; diff --git a/src/connector/src/mqtt_common.rs b/src/connector/src/mqtt_common.rs new file mode 100644 index 0000000000000..d66f534a2e16f --- /dev/null +++ b/src/connector/src/mqtt_common.rs @@ -0,0 +1,164 @@ +use rumqttc::v5::mqttbytes::QoS; +use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions}; +use serde_derive::Deserialize; +use serde_with::{serde_as, DisplayFromStr}; +use strum_macros::{Display, EnumString}; +use with_options::WithOptions; + +use crate::common::{load_certs, load_private_key}; +use crate::deserialize_bool_from_string; +use crate::error::ConnectorResult; + +#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)] +#[strum(serialize_all = "snake_case")] +#[allow(clippy::enum_variant_names)] +pub enum QualityOfService { + AtLeastOnce, + AtMostOnce, + ExactlyOnce, +} + +#[serde_as] +#[derive(Deserialize, Debug, Clone, WithOptions)] +pub struct MqttCommon { + /// The url of the broker to connect to. e.g. tcp://localhost. + /// Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`, + /// to denote the protocol for establishing a connection with the broker. + /// `mqtts://`, `ssl://` will use the native certificates if no ca is specified + pub url: String, + + /// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/# + pub topic: String, + + /// The quality of service to use when publishing messages. Defaults to at_most_once. 
+ /// Could be at_most_once, at_least_once or exactly_once + #[serde_as(as = "Option")] + pub qos: Option, + + /// Username for the mqtt broker + #[serde(rename = "username")] + pub user: Option, + + /// Password for the mqtt broker + pub password: Option, + + /// Prefix for the mqtt client id. + /// The client id will be generated as `client_prefix`_`actor_id`_`(executor_id|source_id)`. Defaults to risingwave + pub client_prefix: Option, + + /// `clean_start = true` removes all the state from queues & instructs the broker + /// to clean all the client state when client disconnects. + /// + /// When set `false`, broker will hold the client state and performs pending + /// operations on the client when reconnection with same `client_id` + /// happens. Local queue state is also held to retransmit packets after reconnection. + #[serde(default, deserialize_with = "deserialize_bool_from_string")] + pub clean_start: bool, + + /// The maximum number of inflight messages. Defaults to 100 + #[serde_as(as = "Option")] + pub inflight_messages: Option, + + /// Path to CA certificate file for verifying the broker's key. + #[serde(rename = "tls.client_key")] + pub ca: Option, + /// Path to client's certificate file (PEM). Required for client authentication. + /// Can be a file path under fs:// or a string with the certificate content. + #[serde(rename = "tls.client_cert")] + pub client_cert: Option, + + /// Path to client's private key file (PEM). Required for client authentication. + /// Can be a file path under fs:// or a string with the private key content. 
+ #[serde(rename = "tls.client_key")] + pub client_key: Option, +} + +impl MqttCommon { + pub(crate) fn build_client( + &self, + actor_id: u32, + id: u64, + ) -> ConnectorResult<(AsyncClient, EventLoop)> { + let client_id = format!( + "{}_{}_{}", + self.client_prefix.as_deref().unwrap_or("risingwave"), + actor_id, + id + ); + + let mut url = url::Url::parse(&self.url)?; + + let ssl = matches!(url.scheme(), "mqtts" | "ssl"); + + url.query_pairs_mut().append_pair("client_id", &client_id); + + tracing::debug!("connecting mqtt using url: {}", url.as_str()); + + let mut options = MqttOptions::try_from(url)?; + options.set_keep_alive(std::time::Duration::from_secs(10)); + + options.set_clean_start(self.clean_start); + + if ssl { + let tls_config = self.get_tls_config()?; + options.set_transport(rumqttc::Transport::tls_with_config( + rumqttc::TlsConfiguration::Rustls(std::sync::Arc::new(tls_config)), + )); + } + + if let Some(user) = &self.user { + options.set_credentials(user, self.password.as_deref().unwrap_or_default()); + } + + Ok(rumqttc::v5::AsyncClient::new( + options, + self.inflight_messages.unwrap_or(100), + )) + } + + pub(crate) fn qos(&self) -> QoS { + self.qos + .as_ref() + .map(|qos| match qos { + QualityOfService::AtMostOnce => QoS::AtMostOnce, + QualityOfService::AtLeastOnce => QoS::AtLeastOnce, + QualityOfService::ExactlyOnce => QoS::ExactlyOnce, + }) + .unwrap_or(QoS::AtMostOnce) + } + + fn get_tls_config(&self) -> ConnectorResult { + let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); + if let Some(ca) = &self.ca { + let certificates = load_certs(ca)?; + for cert in certificates { + root_cert_store.add(&cert).unwrap(); + } + } else { + for cert in + rustls_native_certs::load_native_certs().expect("could not load platform certs") + { + root_cert_store + .add(&tokio_rustls::rustls::Certificate(cert.0)) + .unwrap(); + } + } + + let builder = tokio_rustls::rustls::ClientConfig::builder() + .with_safe_defaults() + 
.with_root_certificates(root_cert_store); + + let tls_config = if let (Some(client_cert), Some(client_key)) = + (self.client_cert.as_ref(), self.client_key.as_ref()) + { + let certs = load_certs(client_cert)?; + let key = load_private_key(client_key)?; + + builder.with_client_auth_cert(certs, key)? + } else { + builder.with_no_client_auth() + }; + + Ok(tls_config) + } +} diff --git a/src/connector/src/sink/mqtt.rs b/src/connector/src/sink/mqtt.rs index 54c483c7076a0..1aebdf4f70062 100644 --- a/src/connector/src/sink/mqtt.rs +++ b/src/connector/src/sink/mqtt.rs @@ -30,7 +30,7 @@ use super::catalog::SinkFormatDesc; use super::formatter::SinkFormatterImpl; use super::writer::FormattedSink; use super::{DummySinkCommitCoordinator, SinkWriterParam}; -use crate::common::MqttCommon; +use crate::mqtt_common::MqttCommon; use crate::sink::catalog::desc::SinkDesc; use crate::sink::log_store::DeliveryFutureManagerAddFuture; use crate::sink::writer::{ @@ -123,7 +123,7 @@ impl Sink for MqttSink { async fn validate(&self) -> Result<()> { if !self.is_append_only { return Err(SinkError::Mqtt(anyhow!( - "Nats sink only support append-only mode" + "Mqtt sink only support append-only mode" ))); } diff --git a/src/connector/src/source/mqtt/mod.rs b/src/connector/src/source/mqtt/mod.rs index 033e8585f57f8..aec17f0454f18 100644 --- a/src/connector/src/source/mqtt/mod.rs +++ b/src/connector/src/source/mqtt/mod.rs @@ -22,7 +22,7 @@ use serde_derive::Deserialize; use serde_with::{serde_as, DisplayFromStr}; use with_options::WithOptions; -use crate::common::{MqttCommon, QualityOfService}; +use crate::mqtt_common::{MqttCommon, QualityOfService}; use crate::source::mqtt::enumerator::MqttSplitEnumerator; use crate::source::mqtt::source::{MqttSplit, MqttSplitReader}; use crate::source::SourceProperties; diff --git a/src/connector/src/with_options.rs b/src/connector/src/with_options.rs index 1bdda0b484ce4..a5c810834727a 100644 --- a/src/connector/src/with_options.rs +++ 
b/src/connector/src/with_options.rs @@ -52,7 +52,7 @@ impl WithOptions for i32 {} impl WithOptions for i64 {} impl WithOptions for f64 {} impl WithOptions for std::time::Duration {} -impl WithOptions for crate::common::QualityOfService {} +impl WithOptions for crate::mqtt_common::QualityOfService {} impl WithOptions for crate::sink::kafka::CompressionCodec {} impl WithOptions for nexmark::config::RateShape {} impl WithOptions for nexmark::event::EventType {} diff --git a/src/frontend/src/handler/create_sink.rs b/src/frontend/src/handler/create_sink.rs index d743fb2d3f235..e720c62b53042 100644 --- a/src/frontend/src/handler/create_sink.rs +++ b/src/frontend/src/handler/create_sink.rs @@ -776,9 +776,9 @@ static CONNECTORS_COMPATIBLE_FORMATS: LazyLock vec![Encode::Json], ), MqttSink::SINK_NAME => hashmap!( - Format::Plain => vec![Encode::Json,Encode::Bytes], - Format::Upsert => vec![Encode::Json,Encode::Bytes], - Format::Debezium => vec![Encode::Json,Encode::Bytes], + Format::Plain => vec![Encode::Json], + Format::Upsert => vec![Encode::Json], + Format::Debezium => vec![Encode::Json], ), PulsarSink::SINK_NAME => hashmap!( Format::Plain => vec![Encode::Json], From 45f2c206b7b4a7553c7a5f903973508264c6df5a Mon Sep 17 00:00:00 2001 From: Gio Gutierrez Date: Wed, 6 Mar 2024 09:41:05 -0500 Subject: [PATCH 11/12] chore: Include license for new mqtt_common file --- src/connector/src/mqtt_common.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/connector/src/mqtt_common.rs b/src/connector/src/mqtt_common.rs index d66f534a2e16f..1c9a83f6787f6 100644 --- a/src/connector/src/mqtt_common.rs +++ b/src/connector/src/mqtt_common.rs @@ -1,3 +1,17 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use rumqttc::v5::mqttbytes::QoS; use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions}; use serde_derive::Deserialize; From 98f478286bea81606b1bd59359abdd65b5d67bb1 Mon Sep 17 00:00:00 2001 From: Gio Gutierrez Date: Wed, 6 Mar 2024 10:13:58 -0500 Subject: [PATCH 12/12] fix: Update with_options to include mqtt_common.rs --- src/connector/src/source/mqtt/source/reader.rs | 11 ++++------- src/connector/src/with_options_test.rs | 5 +++++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/connector/src/source/mqtt/source/reader.rs b/src/connector/src/source/mqtt/source/reader.rs index 74e315b6cd656..50f90c816390c 100644 --- a/src/connector/src/source/mqtt/source/reader.rs +++ b/src/connector/src/source/mqtt/source/reader.rs @@ -25,9 +25,7 @@ use crate::error::ConnectorResult as Result; use crate::parser::ParserConfig; use crate::source::common::{into_chunk_stream, CommonSplitReader}; use crate::source::mqtt::MqttProperties; -use crate::source::{ - self, BoxChunkSourceStream, Column, SourceContextRef, SourceMessage, SplitReader, -}; +use crate::source::{BoxChunkSourceStream, Column, SourceContextRef, SourceMessage, SplitReader}; pub struct MqttSplitReader { eventloop: rumqttc::v5::EventLoop, @@ -51,10 +49,9 @@ impl SplitReader for MqttSplitReader { source_ctx: SourceContextRef, _columns: Option>, ) -> Result { - let (client, eventloop) = properties.common.build_client( - source_ctx.source_info.actor_id, - source_ctx.source_info.fragment_id as u64, - )?; + let (client, eventloop) = properties + .common + 
.build_client(source_ctx.actor_id, source_ctx.fragment_id as u64)?; let qos = properties.common.qos(); diff --git a/src/connector/src/with_options_test.rs b/src/connector/src/with_options_test.rs index 4ead1685244d8..fd234e880e469 100644 --- a/src/connector/src/with_options_test.rs +++ b/src/connector/src/with_options_test.rs @@ -38,6 +38,10 @@ fn common_mod_path() -> PathBuf { connector_crate_path().join("src").join("common.rs") } +fn mqtt_common_mod_path() -> PathBuf { + connector_crate_path().join("src").join("mqtt_common.rs") +} + pub fn generate_with_options_yaml_source() -> String { generate_with_options_yaml_inner(&source_mod_path()) } @@ -63,6 +67,7 @@ fn generate_with_options_yaml_inner(path: &Path) -> String { for entry in walkdir::WalkDir::new(path) .into_iter() .chain(walkdir::WalkDir::new(common_mod_path())) + .chain(walkdir::WalkDir::new(mqtt_common_mod_path())) { let entry = entry.expect("Failed to read directory entry"); if entry.path().extension() == Some("rs".as_ref()) {