diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index b84e293..bf2e210 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -31,4 +31,12 @@ jobs: run: cargo test --test integration env: ENVIRONMENT: ${{ inputs.environment }} - TEST_PROJECT_ID: ${{ secrets.TEST_PROJECT_ID }} \ No newline at end of file + TEST_PROJECT_ID: ${{ secrets.TEST_PROJECT_ID }} + CAST_PROJECT_SECRET: ${{ secrets.CAST_PROJECT_SECRET }} + + # Run validate swift + - name: Run validate swift + uses: ./.github/workflows/validate_swift.yml + with: + notify-endpoint: ${{ inputs.environment == 'PROD' && 'cast.walletconnect.com' || 'staging.cast.walletconnect.com' }} + relay-endpoint: ${{ inputs.environment == 'PROD' && 'relay.walletconnect.com' || 'staging.relay.walletconnect.com' }} diff --git a/Cargo.lock b/Cargo.lock index d818773..06cea53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -482,6 +482,7 @@ dependencies = [ "bitflags", "bytes", "futures-util", + "headers", "http", "http-body", "hyper", @@ -766,10 +767,12 @@ dependencies = [ "bs58", "build-info", "build-info-build", + "cerberus", "chacha20poly1305", "chrono", "dashmap", "data-encoding", + "deadpool-redis", "derive_more", "dotenv", "ed25519-dalek 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", @@ -796,6 +799,7 @@ dependencies = [ "relay_rpc", "reqwest", "ring", + "rmp-serde", "serde", "serde_bson", "serde_json", @@ -827,6 +831,19 @@ dependencies = [ "jobserver", ] +[[package]] +name = "cerberus" +version = "0.2.0" +source = "git+https://github.com/WalletConnect/cerberus.git#2c46f9575011258bc6eb9c3866720ef7acdfdca9" +dependencies = [ + "async-trait", + "once_cell", + "regex", + "reqwest", + "serde", + "thiserror", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -899,6 +916,20 @@ name = "collections" version = "0.1.0" source = "git+https://github.com/WalletConnect/utils-rs.git?tag=v0.1.0#69d982a1269dbc6649e153e78f486e97d6a60815" +[[package]] +name = "combine" +version = "4.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4" +dependencies = [ + "bytes", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "tokio-util", +] + [[package]] name = "const-oid" version = "0.9.2" @@ -1159,6 +1190,38 @@ version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23d8666cb01533c39dde32bcbab8e227b4ed6679b2c925eba05feabea39508fb" +[[package]] +name = "deadpool" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "421fe0f90f2ab22016f32a9881be5134fdd71c65298917084b0c7477cbc3856e" +dependencies = [ + "async-trait", + "deadpool-runtime", + "num_cpus", + "retain_mut", + "tokio", +] + +[[package]] +name = "deadpool-redis" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f1760f60ffc6653b4afd924c5792098d8c00d9a3deb6b3d989eac17949dc422" +dependencies = [ + "deadpool", + "redis", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaa37046cc0f6c3cc6090fbdbf73ef0b8ef4cfcc37f6befc0020f63e8cf121e1" +dependencies = [ + "tokio", +] + [[package]] name = "der" version = "0.7.6" @@ -1637,6 +1700,31 @@ version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +[[package]] +name = "headers" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584" +dependencies = [ + "base64 0.13.1", + "bitflags", + "bytes", + "headers-core", + "http", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429" +dependencies = [ + "http", +] + [[package]] name = "heck" version = "0.4.1" @@ -2851,6 +2939,25 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "redis" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ea8c51b5dc1d8e5fd3350ec8167f464ec0995e79f2e90a075b63371500d557f" +dependencies = [ + "async-trait", + "bytes", + "combine", + "futures-util", + "itoa", + "percent-encoding", + "pin-project-lite", + "ryu", + "tokio", + "tokio-util", + "url", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -2994,6 +3101,12 @@ dependencies = [ "quick-error", ] +[[package]] +name = "retain_mut" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4389f1d5789befaf6029ebd9f7dac4af7f7e3d61b69d4f30e2ac02b57e7712b0" + [[package]] name = "rfc6979" version = "0.4.0" @@ -3019,6 +3132,28 @@ dependencies = [ "winapi", ] +[[package]] +name = "rmp" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44519172358fd6d58656c86ab8e7fbc9e1490c3e8f14d35ed78ca0dd07403c9f" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5b13be192e0220b8afb7222aa5813cb62cc269ebb5cac346ca6487681d2913e" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + [[package]] name = "rustc_version" version = "0.2.3" diff --git a/Cargo.toml b/Cargo.toml index b119d2e..6a83ce5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,9 +9,10 @@ build = "build.rs" [dependencies] wc = { git = "https://github.com/WalletConnect/utils-rs.git", tag = "v0.1.0", features = ["full"] } +cerberus = { git = "https://github.com/WalletConnect/cerberus.git"} tokio = { version = "1", features = ["full"] } -axum = { version = "0.6", features = ["json"] } +axum = { version = "0.6", features = ["json", "headers"] } tower = "0.4" tower-http = { version = "0.3", features = ["trace", "cors"] } hyper = "0.14" @@ -90,6 +91,8 @@ pnet_datalink = "0.33.0" ipnet = "2.8.0" once_cell = "1.18.0" lazy_static = "1.4.0" +rmp-serde = "1.1.1" +deadpool-redis = "0.12.0" [dev-dependencies] test-context = "0.1" diff --git a/src/config.rs b/src/config.rs index d2e68b5..99db78a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,5 +1,5 @@ use { - crate::networking, + crate::{networking, storage::redis::Addr as RedisAddr}, serde::Deserialize, std::{net::IpAddr, str::FromStr}, }; @@ -17,6 +17,14 @@ pub struct Configuration { pub project_id: String, pub relay_url: String, pub cast_url: String, + + pub registry_url: String, + pub registry_auth_token: String, + + pub auth_redis_addr_read: Option, + pub auth_redis_addr_write: Option, + pub redis_pool_size: u32, + #[serde(default = "default_is_test", skip)] /// This is an internal flag to disable logging, cannot be defined by user pub is_test: bool, @@ -41,6 +49,13 @@ impl Configuration { pub fn log_level(&self) -> tracing::Level { tracing::Level::from_str(self.log_level.as_str()).expect("Invalid log level") } + + pub fn auth_redis_addr(&self) -> Option { + match (&self.auth_redis_addr_read, &self.auth_redis_addr_write) { + (None, None) => None, + (addr_read, addr_write) => Some(RedisAddr::from((addr_read, addr_write))), + } + } } fn default_port() -> u16 { diff --git a/src/error.rs b/src/error.rs index 1cc93fc..5b1b6bd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -125,6 +125,15 @@ pub enum Error { #[error(transparent)] Other(#[from] anyhow::Error), + + #[error(transparent)] + Redis(#[from] crate::storage::error::StorageError), + + #[error(transparent)] + InvalidHeaderValue(#[from] hyper::header::InvalidHeaderValue), + + #[error(transparent)] + ToStrError(#[from] hyper::header::ToStrError), } impl IntoResponse for Error { diff --git a/src/extractors.rs b/src/extractors.rs new file mode 100644 index 0000000..37b8eed --- /dev/null +++ b/src/extractors.rs @@ -0,0 +1,84 @@ +use { + crate::state::AppState, + async_trait::async_trait, + axum::{ + extract::{FromRequestParts, Path}, + headers::{authorization::Bearer, Authorization}, + http::request::Parts, + TypedHeader, + }, + hyper::StatusCode, + serde_json::json, + std::{collections::HashMap, sync::Arc}, + tracing::warn, +}; + +/// Extracts project_id from uri and project_secret from Authorization header. +/// Verifies their correctness against registry and returns AuthedProjectId +/// struct. +pub struct AuthedProjectId(pub String, pub String); + +#[async_trait] +impl FromRequestParts> for AuthedProjectId { + type Rejection = (StatusCode, String); + + async fn from_request_parts( + parts: &mut Parts, + state: &Arc, + ) -> Result { + let Path(path_args) = Path::>::from_request_parts(parts, state) + .await + .map_err(|_| { + ( + StatusCode::BAD_REQUEST, + json!({ + "reason": "Invalid project_id. Please make sure to include project_id in uri. " + }).to_string(), + ) + })?; + + let TypedHeader(project_secret) = TypedHeader::>::from_request_parts(parts, state).await.map_err(|_| { + ( + StatusCode::UNAUTHORIZED, + json!({ + "reason": "Unauthorized. Please make sure to include project secret in Authorization header. " + }).to_string(), + ) + })?; + + let project_id = path_args + .get("project_id") + .ok_or(( + StatusCode::BAD_REQUEST, + json!({"reason": "Invalid data for authentication".to_string()}).to_string(), + ))? + .to_string(); + + let authenticated = state + .registry + .is_authenticated(&project_id, project_secret.token()) + .await + .map_err(|e| { + warn!(?e, "Failed to authenticate project"); + ( + StatusCode::BAD_REQUEST, + "Invalid data for authentication".to_string(), + ) + })?; + + if !authenticated { + return Err(( + StatusCode::UNAUTHORIZED, + json!({ + "reason": "Invalid project_secret. Please make sure to include proper project secret in Authorization header." + }) + .to_string(), + )); + }; + + Ok(AuthedProjectId( + project_id, + project_secret.token().to_string(), + )) + } +} diff --git a/src/handlers/get_subscribers.rs b/src/handlers/get_subscribers.rs index 68abfa2..915f462 100644 --- a/src/handlers/get_subscribers.rs +++ b/src/handlers/get_subscribers.rs @@ -1,19 +1,14 @@ use { - crate::{error::Result, state::AppState, types::ClientData}, - axum::{ - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, - Json, - }, + crate::{error::Result, extractors::AuthedProjectId, state::AppState, types::ClientData}, + axum::{extract::State, http::StatusCode, response::IntoResponse, Json}, futures::TryStreamExt, log::info, std::sync::Arc, }; pub async fn handler( - Path(project_id): Path, State(state): State>, + AuthedProjectId(project_id, _): AuthedProjectId, ) -> Result { info!("Getting subscribers for project: {}", project_id); diff --git a/src/handlers/notify.rs b/src/handlers/notify.rs index c41a3b4..e88c4cb 100644 --- a/src/handlers/notify.rs +++ b/src/handlers/notify.rs @@ -2,12 +2,13 @@ use { crate::{ analytics::message_info::MessageInfo, error, + extractors::AuthedProjectId, jsonrpc::{JsonRpcParams, JsonRpcPayload}, state::AppState, types::{ClientData, Envelope, EnvelopeType0, Notification}, }, axum::{ - extract::{ConnectInfo, Path, State}, + extract::{ConnectInfo, State}, http::StatusCode, response::IntoResponse, Json, @@ -56,9 +57,9 @@ pub struct Response { } pub async fn handler( - Path(project_id): Path, State(state): State>, ConnectInfo(addr): ConnectInfo, + AuthedProjectId(project_id, _): AuthedProjectId, Json(cast_args): Json, ) -> Result { // Request id for logs diff --git a/src/handlers/subscribe_topic.rs b/src/handlers/subscribe_topic.rs index fa4283e..d0e25cd 100644 --- a/src/handlers/subscribe_topic.rs +++ b/src/handlers/subscribe_topic.rs @@ -1,11 +1,6 @@ use { - crate::state::AppState, - axum::{ - extract::{Path, State}, - response::IntoResponse, - Json, - }, - hyper::HeaderMap, + crate::{extractors::AuthedProjectId, state::AppState}, + axum::{self, extract::State, response::IntoResponse, Json}, log::info, mongodb::{bson::doc, options::ReplaceOptions}, rand::{rngs::StdRng, Rng}, @@ -27,57 +22,48 @@ pub struct ProjectData { } pub async fn handler( - headers: HeaderMap, - Path(project_id): Path, State(state): State>, + AuthedProjectId(project_id, project_secret): AuthedProjectId, ) -> Result { info!("Generating keypair for project: {}", project_id); let db = state.database.clone(); - match headers.get("Authorization") { - Some(project_secret) => { - let mut hasher = sha2::Sha256::new(); - hasher.update(project_secret.as_bytes()); - hasher.update(project_id.as_bytes()); - let seed = hasher.finalize(); + let mut hasher = sha2::Sha256::new(); + hasher.update(project_secret.as_bytes()); + hasher.update(project_id.as_bytes()); + let seed = hasher.finalize(); - let mut rng: StdRng = SeedableRng::from_seed(seed.into()); + let mut rng: StdRng = SeedableRng::from_seed(seed.into()); - let secret = StaticSecret::from(rng.gen::<[u8; 32]>()); - let public = PublicKey::from(&secret); + let secret = StaticSecret::from(rng.gen::<[u8; 32]>()); + let public = PublicKey::from(&secret); - let public_key = hex::encode(public.as_bytes()); + let public_key = hex::encode(public.as_bytes()); - let topic = sha256::digest(public.as_bytes()); - let project_data = ProjectData { - id: project_id.clone(), - private_key: hex::encode(secret.to_bytes()), - public_key: public_key.clone(), - topic: topic.clone(), - }; + let topic = sha256::digest(public.as_bytes()); + let project_data = ProjectData { + id: project_id.clone(), + private_key: hex::encode(secret.to_bytes()), + public_key: public_key.clone(), + topic: topic.clone(), + }; - info!( - "Saving project_info to database for project: {} with pubkey: {}", - project_id, public_key - ); + info!( + "Saving project_info to database for project: {} with pubkey: {}", + project_id, public_key + ); - db.collection::("project_data") - .replace_one( - doc! { "_id": project_id.clone()}, - project_data, - ReplaceOptions::builder().upsert(true).build(), - ) - .await?; + db.collection::("project_data") + .replace_one( + doc! { "_id": project_id.clone()}, + project_data, + ReplaceOptions::builder().upsert(true).build(), + ) + .await?; - info!("Subscribing to project topic: {}", &topic); + info!("Subscribing to project topic: {}", &topic); - state.wsclient.subscribe(topic.into()).await?; + state.wsclient.subscribe(topic.into()).await?; - Ok(Json(json!({ "publicKey": public_key })).into_response()) - } - None => Ok(Json(json!({ - "reason": "Unauthorized. Please make sure to include project secret in Authorization header. " - })) - .into_response()), - } + Ok(Json(json!({ "publicKey": public_key })).into_response()) } diff --git a/src/handlers/webhooks/delete_webhook.rs b/src/handlers/webhooks/delete_webhook.rs index a8d233e..51202b9 100644 --- a/src/handlers/webhooks/delete_webhook.rs +++ b/src/handlers/webhooks/delete_webhook.rs @@ -1,5 +1,5 @@ use { - crate::{error::Result, state::AppState, types::WebhookInfo}, + crate::{error::Result, extractors::AuthedProjectId, state::AppState, types::WebhookInfo}, axum::{ extract::{Path, State}, response::IntoResponse, @@ -11,7 +11,8 @@ use { }; pub async fn handler( - Path((project_id, webhook_id)): Path<(String, Uuid)>, + AuthedProjectId(project_id, _): AuthedProjectId, + Path((_, webhook_id)): Path<(String, Uuid)>, State(state): State>, ) -> Result { let request_id = uuid::Uuid::new_v4(); @@ -26,5 +27,5 @@ pub async fn handler( ) .await?; - Ok(axum::http::StatusCode::OK) + Ok(axum::http::StatusCode::OK.into_response()) } diff --git a/src/handlers/webhooks/get_webhooks.rs b/src/handlers/webhooks/get_webhooks.rs index cf8b18c..d3e0edc 100644 --- a/src/handlers/webhooks/get_webhooks.rs +++ b/src/handlers/webhooks/get_webhooks.rs @@ -1,11 +1,7 @@ use { super::WebhookConfig, - crate::{error::Result, state::AppState, types::WebhookInfo}, - axum::{ - extract::{Path, State}, - response::IntoResponse, - Json, - }, + crate::{error::Result, extractors::AuthedProjectId, state::AppState, types::WebhookInfo}, + axum::{extract::State, response::IntoResponse, Json}, futures::TryStreamExt, log::info, mongodb::bson::doc, @@ -13,7 +9,7 @@ use { }; pub async fn handler( - Path(project_id): Path, + AuthedProjectId(project_id, _): AuthedProjectId, State(state): State>, ) -> Result { let request_id = uuid::Uuid::new_v4(); @@ -36,5 +32,5 @@ pub async fn handler( .try_collect() .await?; - Ok((axum::http::StatusCode::OK, Json(webhooks))) + Ok((axum::http::StatusCode::OK, Json(webhooks)).into_response()) } diff --git a/src/handlers/webhooks/register_webhook.rs b/src/handlers/webhooks/register_webhook.rs index 818584a..cd43f6e 100644 --- a/src/handlers/webhooks/register_webhook.rs +++ b/src/handlers/webhooks/register_webhook.rs @@ -1,11 +1,13 @@ use { super::WebhookConfig, - crate::{error::Result, handlers::webhooks::validate_url, state::AppState, types::WebhookInfo}, - axum::{ - extract::{Path, State}, - response::IntoResponse, - Json, + crate::{ + error::Result, + extractors::AuthedProjectId, + handlers::webhooks::validate_url, + state::AppState, + types::WebhookInfo, }, + axum::{extract::State, response::IntoResponse, Json}, log::info, mongodb::bson::doc, serde::Serialize, @@ -19,7 +21,7 @@ struct RegisterWebhookResponse { } pub async fn handler( - Path(project_id): Path, + AuthedProjectId(project_id, _): AuthedProjectId, State(state): State>, Json(webhook_info): Json, ) -> Result { @@ -47,5 +49,6 @@ pub async fn handler( Ok(( axum::http::StatusCode::CREATED, Json(RegisterWebhookResponse { id: webhook_id }), - )) + ) + .into_response()) } diff --git a/src/handlers/webhooks/update_webhook.rs b/src/handlers/webhooks/update_webhook.rs index 2bb6730..8773131 100644 --- a/src/handlers/webhooks/update_webhook.rs +++ b/src/handlers/webhooks/update_webhook.rs @@ -1,6 +1,6 @@ use { super::WebhookConfig, - crate::{error::Result, state::AppState, types::WebhookInfo}, + crate::{error::Result, extractors::AuthedProjectId, state::AppState, types::WebhookInfo}, axum::{ extract::{Path, State}, response::IntoResponse, @@ -13,7 +13,8 @@ use { }; pub async fn handler( - Path((project_id, webhook_id)): Path<(String, Uuid)>, + Path((_, webhook_id)): Path<(String, Uuid)>, + AuthedProjectId(project_id, _): AuthedProjectId, State(state): State>, Json(webhook_info): Json, ) -> Result { @@ -29,5 +30,5 @@ pub async fn handler( ) .await?; - Ok(axum::http::StatusCode::NO_CONTENT) + Ok(axum::http::StatusCode::NO_CONTENT.into_response()) } diff --git a/src/lib.rs b/src/lib.rs index 8c1e2eb..814d60f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,11 +26,14 @@ pub mod analytics; pub mod auth; pub mod config; pub mod error; +pub mod extractors; pub mod handlers; pub mod jsonrpc; mod metrics; mod networking; +pub mod registry; mod state; +mod storage; pub mod types; pub mod websocket_service; pub mod wsclient; @@ -80,6 +83,12 @@ pub async fn bootstap(mut shutdown: broadcast::Receiver<()>, config: Configurati &config.project_id, )); + let registry = Arc::new(registry::Registry::new( + &config.registry_url, + &config.registry_auth_token, + &config, + )?); + // Creating state let state = AppState::new( analytics, @@ -89,6 +98,7 @@ pub async fn bootstap(mut shutdown: broadcast::Receiver<()>, config: Configurati wsclient.clone(), http_client, Some(Metrics::default()), + registry, )?; let port = state.config.port; diff --git a/src/registry.rs b/src/registry.rs new file mode 100644 index 0000000..37e0eb0 --- /dev/null +++ b/src/registry.rs @@ -0,0 +1,118 @@ +use { + crate::{ + config::Configuration, + error::Result, + storage::{redis::Redis, KeyValueStorage}, + }, + hyper::header, + serde::{Deserialize, Serialize}, + sha2::{Digest, Sha256}, + std::{sync::Arc, time::Duration}, + tracing::error, + tungstenite::http::HeaderValue, +}; + +pub struct RegistryHttpClient { + addr: String, + http_client: reqwest::Client, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RegistryAuthResponse { + pub is_valid: bool, +} + +impl RegistryHttpClient { + pub fn new(base_url: impl Into, auth_token: &str) -> Result { + let mut auth_value = HeaderValue::from_str(&format!("Bearer {}", auth_token))?; + + // Make sure we're not leaking auth token in debug output. + auth_value.set_sensitive(true); + + let mut headers = header::HeaderMap::new(); + headers.insert(header::AUTHORIZATION, auth_value); + + let http_client = reqwest::Client::builder() + .default_headers(headers) + .build()?; + + Ok(Self { + addr: base_url.into(), + http_client, + }) + } + + pub async fn authenticate(&self, id: &str, secret: &str) -> Result { + let res: RegistryAuthResponse = self + .http_client + .get(format!( + "{}/internal/project/validate-cast-keys?projectId={id}&secret={secret}", + self.addr + )) + .send() + .await? + .json() + .await?; + + Ok(if res.is_valid { + hyper::StatusCode::OK + } else { + hyper::StatusCode::UNAUTHORIZED + }) + } +} + +pub struct Registry { + client: Arc, + cache: Option>, +} + +impl Registry { + pub fn new(url: &str, auth_token: &str, config: &Configuration) -> Result { + let client = Arc::new(RegistryHttpClient::new(url, auth_token)?); + + let cache = if let Some(redis_addr) = &config.auth_redis_addr() { + Some(Arc::new(Redis::new( + redis_addr, + config.redis_pool_size as usize, + )?)) + } else { + None + }; + Ok(Self { client, cache }) + } + + pub async fn is_authenticated(&self, id: &str, secret: &str) -> Result { + self.is_authenticated_internal(id, secret) + .await + .map_err(|e| { + error!("Failed to authenticate project: {}", e); + e + }) + } + + async fn is_authenticated_internal(&self, id: &str, secret: &str) -> Result { + let mut hasher = Sha256::new(); + hasher.update(id); + hasher.update(secret); + let hash = hasher.finalize(); + let hash = hex::encode(hash); + + if let Some(cache) = &self.cache { + if let Some(validity) = cache.get(&hash).await? { + return Ok(validity); + } + } + + let validity = self.client.authenticate(id, secret).await?.is_success(); + + if let Some(cache) = &self.cache { + cache.set(&hash, &validity, Some(CACHE_TTL)).await?; + } + + Ok(validity) + } +} + +const CACHE_TTL: Duration = Duration::from_secs(60 * 30); diff --git a/src/state.rs b/src/state.rs index 89d2669..8982e61 100644 --- a/src/state.rs +++ b/src/state.rs @@ -3,6 +3,7 @@ use { analytics::CastAnalytics, error::Result, metrics::Metrics, + registry::Registry, types::{ClientData, LookupEntry, WebhookInfo}, Configuration, }, @@ -25,11 +26,13 @@ pub struct AppState { pub keypair: Keypair, pub wsclient: Arc, pub http_relay_client: Arc, + pub registry: Arc, } build_info::build_info!(fn build_info); impl AppState { + #[allow(clippy::too_many_arguments)] pub fn new( analytics: CastAnalytics, config: Configuration, @@ -38,6 +41,7 @@ impl AppState { wsclient: Arc, http_relay_client: Arc, metrics: Option, + registry: Arc, ) -> crate::Result { let build_info: &BuildInfo = build_info(); @@ -50,6 +54,7 @@ impl AppState { keypair, wsclient, http_relay_client, + registry, }) } diff --git a/src/storage/error.rs b/src/storage/error.rs new file mode 100644 index 0000000..2091fa3 --- /dev/null +++ b/src/storage/error.rs @@ -0,0 +1,23 @@ +//! Error typedefs used by this crate + +use thiserror::Error as ThisError; + +/// The error produced from most Storage functions +#[derive(Debug, ThisError)] +pub enum StorageError { + /// Couldn't set the expiration for the given key + #[error("couldn't set the expiry to the key")] + SetExpiry, + /// Unable to serialize data to store + #[error("error on serialize data")] + Serialize, + /// Unable to deserialize data from store + #[error("error on deserialize data")] + Deserialize, + /// Error on establishing a connection with the storage + #[error("error on open connection")] + Connection(String), + /// An unexpected error occurred + #[error("{0:?}")] + Other(String), +} diff --git a/src/storage/mod.rs b/src/storage/mod.rs new file mode 100644 index 0000000..e18913a --- /dev/null +++ b/src/storage/mod.rs @@ -0,0 +1,44 @@ +use { + crate::storage::error::StorageError, + async_trait::async_trait, + serde::{de::DeserializeOwned, Serialize}, + std::{fmt::Debug, time::Duration}, +}; + +pub mod error; +pub mod redis; + +/// The Result type returned by Storage functions +pub type StorageResult = Result; + +#[async_trait] +pub trait KeyValueStorage: 'static + Send + Sync + Debug +where + T: Serialize + DeserializeOwned + Send + Sync, +{ + /// Retrieve the data associated with the given key. + async fn get(&self, key: &str) -> StorageResult>; + + /// Set the value for the given key. + async fn set(&self, key: &str, value: &T, ttl: Option) -> StorageResult<()>; + + /// Delete the value associated with the given key. + async fn del(&self, key: &str) -> StorageResult<()>; +} + +/// Holder the type of data will be serialized to be stored. +pub type Data = Vec; + +pub fn serialize(data: &T) -> StorageResult +where + T: Serialize, +{ + rmp_serde::to_vec(data).map_err(|_| StorageError::Serialize) +} + +pub fn deserialize(data: &[u8]) -> StorageResult +where + T: DeserializeOwned, +{ + rmp_serde::from_slice(data).map_err(|_| StorageError::Deserialize) +} diff --git a/src/storage/redis/mod.rs b/src/storage/redis/mod.rs new file mode 100644 index 0000000..6465333 --- /dev/null +++ b/src/storage/redis/mod.rs @@ -0,0 +1,158 @@ +use { + crate::storage::{deserialize, serialize, KeyValueStorage, StorageError, StorageResult}, + async_trait::async_trait, + deadpool_redis::{ + redis::{AsyncCommands, Value}, + Config, + Pool, + }, + serde::{de::DeserializeOwned, Serialize}, + std::{fmt::Debug, time::Duration}, +}; + +const LOCAL_REDIS_ADDR: &str = "redis://localhost:6379/0"; + +#[derive(Debug, Clone)] +pub enum Addr<'a> { + Combined(&'a str), + Separate { read: &'a str, write: &'a str }, +} + +impl<'a> Default for Addr<'a> { + fn default() -> Self { + Self::Combined(LOCAL_REDIS_ADDR) + } +} + +impl<'a> Addr<'a> { + pub fn read(&self) -> &str { + match self { + Self::Combined(addr) => addr, + Self::Separate { read, .. } => read, + } + } + + pub fn write(&self) -> &str { + match self { + Self::Combined(addr) => addr, + Self::Separate { write, .. } => write, + } + } +} + +impl<'a> From<(&'a Option, &'a Option)> for Addr<'a> { + fn from(val: (&'a Option, &'a Option)) -> Self { + match val { + (Some(read), Some(write)) => Self::Separate { read, write }, + (Some(addr), None) => Self::Combined(addr), + (None, Some(addr)) => Self::Combined(addr), + _ => Default::default(), + } + } +} + +/// A interface to interact with Redis cache. +#[derive(Clone)] +pub struct Redis { + read_pool: Pool, + write_pool: Pool, +} + +impl Debug for Redis { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Redis").finish() + } +} + +impl Redis { + /// Instantiate a new Redis. + pub fn new(addr: &Addr<'_>, pool_size: usize) -> StorageResult { + let get_pool = |cfg: Config| -> Result<_, StorageError> { + let pool = cfg + .builder() + .map_err(|e| StorageError::Other(format!("{e}")))? + .max_size(pool_size) + .build() + .map_err(|e| StorageError::Other(format!("{e}")))?; + + Ok(pool) + }; + + let read_config = Config::from_url(addr.read()); + let read_pool = get_pool(read_config)?; + + let write_config = Config::from_url(addr.write()); + let write_pool = get_pool(write_config)?; + + Ok(Self { + read_pool, + write_pool, + }) + } + + async fn set_internal( + &self, + key: &str, + data: &[u8], + ttl: Option, + ) -> StorageResult<()> { + let mut conn = self + .write_pool + .get() + .await + .map_err(|e| StorageError::Connection(format!("{e}")))?; + + let res_fut = if let Some(ttl) = ttl { + let ttl = ttl + .as_secs() + .try_into() + .map_err(|_| StorageError::SetExpiry)?; + + conn.set_ex(key, data, ttl) + } else { + conn.set(key, data) + }; + + res_fut + .await + .map_err(|e| StorageError::Other(format!("{e}")))?; + + Ok(()) + } +} + +#[async_trait] +impl KeyValueStorage for Redis +where + T: Serialize + DeserializeOwned + Send + Sync, +{ + async fn get(&self, key: &str) -> StorageResult> { + self.read_pool + .get() + .await + .map_err(|e| StorageError::Connection(format!("{e}")))? + .get::<_, Value>(key) + .await + .map_err(|e| StorageError::Other(format!("{e}"))) + .map(|data| match data { + Value::Nil => Ok(None), + Value::Data(data) => Ok(Some(deserialize(&data)?)), + _ => Err(StorageError::Deserialize), + })? + } + + async fn set(&self, key: &str, value: &T, ttl: Option) -> StorageResult<()> { + let data = serialize(value)?; + self.set_internal(key, &data, ttl).await + } + + async fn del(&self, key: &str) -> StorageResult<()> { + self.write_pool + .get() + .await + .map_err(|e| StorageError::Connection(format!("{e}")))? + .del(key) + .await + .map_err(|e| StorageError::Other(format!("{e}"))) + } +} diff --git a/tests/context/server.rs b/tests/context/server.rs index 99d0c37..c9351d4 100644 --- a/tests/context/server.rs +++ b/tests/context/server.rs @@ -35,6 +35,8 @@ impl CastServer { let relay_url = std::env::var("RELAY_URL").unwrap(); let cast_url = std::env::var("CAST_URL").unwrap(); let test_keypair_seed = std::env::var("TEST_KEYPAIR_SEED").unwrap(); + let registry_url = std::env::var("REGISTRY_URL").unwrap(); + let registry_auth_token = std::env::var("REGISTRY_AUTH_TOKEN").unwrap(); std::thread::spawn(move || { rt.block_on(async move { @@ -54,6 +56,11 @@ impl CastServer { analytics_export_bucket: "".to_string(), analytics_geoip_db_bucket: None, analytics_geoip_db_key: None, + auth_redis_addr_read: None, + auth_redis_addr_write: None, + redis_pool_size: 0, + registry_url, + registry_auth_token, }; cast_server::bootstap(shutdown, config).await diff --git a/tests/integration.rs b/tests/integration.rs index e3b03ae..a1e78cd 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -81,12 +81,12 @@ PROJECT_ID to be set", // Eat up the "connected" message _ = rx.recv().await.unwrap(); - let project_secret = uuid::Uuid::new_v4().to_string(); + let project_secret = std::env::var("CAST_PROJECT_SECRET").expect("CAST_PROJECT_SECRET not set"); // Register project - generating subscribe topic let dapp_pubkey_response: serde_json::Value = http_client .get(format!("{}/{}/subscribe-topic", &cast_url, &project_id)) - .bearer_auth(project_secret) + .bearer_auth(&project_secret) .send() .await .unwrap() @@ -165,7 +165,7 @@ PROJECT_ID to be set", .unwrap(); let resp = rx.recv().await.unwrap(); - // wsclient.fetch(response_topic.clone().into()).await.unwrap(); + let RelayClientEvent::Message(msg) = resp else { panic!("Expected message, got {:?}", resp); }; @@ -193,6 +193,7 @@ PROJECT_ID to be set", .subscribe(notify_topic.clone().into()) .await .unwrap(); + let notification = Notification { title: "string".to_owned(), body: "string".to_owned(), @@ -211,6 +212,7 @@ PROJECT_ID to be set", let _res = http_client .post(format!("{}/{}/notify", &cast_url, &project_id)) + .bearer_auth(&project_secret) .json(¬ify_body) .send() .await @@ -274,6 +276,7 @@ PROJECT_ID to be set", let resp = http_client .post(format!("{}/{}/notify", &cast_url, &project_id)) + .bearer_auth(project_secret) .json(¬ify_body) .send() .await