diff --git a/proxy/src/cache/endpoints.rs b/proxy/src/cache/endpoints.rs index 12c33169bf12..400c76291e3b 100644 --- a/proxy/src/cache/endpoints.rs +++ b/proxy/src/cache/endpoints.rs @@ -1,7 +1,7 @@ use std::convert::Infallible; +use std::future::pending; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::time::Duration; use dashmap::DashSet; use redis::streams::{StreamReadOptions, StreamReadReply}; @@ -19,25 +19,38 @@ use crate::rate_limiter::GlobalRateLimiter; use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::types::EndpointId; +#[allow(clippy::enum_variant_names)] #[derive(Deserialize, Debug, Clone)] -pub(crate) struct ControlPlaneEventKey { - endpoint_created: Option, - branch_created: Option, - project_created: Option, +#[serde(tag = "type", rename_all(deserialize = "snake_case"))] +enum ControlPlaneEvent { + EndpointCreated { endpoint_created: EndpointCreated }, + BranchCreated { branch_created: BranchCreated }, + ProjectCreated { project_created: ProjectCreated }, } + #[derive(Deserialize, Debug, Clone)] struct EndpointCreated { endpoint_id: String, } + #[derive(Deserialize, Debug, Clone)] struct BranchCreated { branch_id: String, } + #[derive(Deserialize, Debug, Clone)] struct ProjectCreated { project_id: String, } +impl TryFrom<&Value> for ControlPlaneEvent { + type Error = anyhow::Error; + fn try_from(value: &Value) -> Result { + let json = String::from_redis_value(value)?; + Ok(serde_json::from_str(&json)?) + } +} + pub struct EndpointsCache { config: EndpointCacheConfig, endpoints: DashSet, @@ -60,6 +73,7 @@ impl EndpointsCache { ready: AtomicBool::new(false), } } + pub(crate) async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool { if !self.ready.load(Ordering::Acquire) { return true; @@ -74,6 +88,7 @@ impl EndpointsCache { } !rejected } + fn should_reject(&self, endpoint: &EndpointId) -> bool { if endpoint.is_endpoint() { !self.endpoints.contains(&EndpointIdInt::from(endpoint)) @@ -87,33 +102,28 @@ impl EndpointsCache { .contains(&ProjectIdInt::from(&endpoint.as_project())) } } - fn insert_event(&self, key: ControlPlaneEventKey) { - // Do not do normalization here, we expect the events to be normalized. - if let Some(endpoint_created) = key.endpoint_created { - self.endpoints - .insert(EndpointIdInt::from(&endpoint_created.endpoint_id.into())); - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::EndpointCreated); - } - if let Some(branch_created) = key.branch_created { - self.branches - .insert(BranchIdInt::from(&branch_created.branch_id.into())); - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::BranchCreated); - } - if let Some(project_created) = key.project_created { - self.projects - .insert(ProjectIdInt::from(&project_created.project_id.into())); - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::ProjectCreated); - } + + fn insert_event(&self, event: ControlPlaneEvent) { + let counter = match event { + ControlPlaneEvent::EndpointCreated { endpoint_created } => { + self.endpoints + .insert(EndpointIdInt::from(&endpoint_created.endpoint_id.into())); + RedisEventsCount::EndpointCreated + } + ControlPlaneEvent::BranchCreated { branch_created } => { + self.branches + .insert(BranchIdInt::from(&branch_created.branch_id.into())); + RedisEventsCount::BranchCreated + } + ControlPlaneEvent::ProjectCreated { project_created } => { + self.projects + .insert(ProjectIdInt::from(&project_created.project_id.into())); + RedisEventsCount::ProjectCreated + } + }; + Metrics::get().proxy.redis_events_count.inc(counter); } + pub async fn do_read( &self, mut con: ConnectionWithCredentialsProvider, @@ -131,12 +141,13 @@ impl EndpointsCache { } if cancellation_token.is_cancelled() { info!("cancellation token is cancelled, exiting"); - tokio::time::sleep(Duration::from_secs(60 * 60 * 24 * 7)).await; - // 1 week. + // Maintenance tasks run forever. Sleep forever when canceled. + pending::<()>().await; } tokio::time::sleep(self.config.retry_interval).await; } } + async fn read_from_stream( &self, con: &mut ConnectionWithCredentialsProvider, @@ -162,10 +173,7 @@ impl EndpointsCache { ) .await } - fn parse_key_value(value: &Value) -> anyhow::Result { - let s: String = FromRedisValue::from_redis_value(value)?; - Ok(serde_json::from_str(&s)?) - } + async fn batch_read( &self, conn: &mut ConnectionWithCredentialsProvider, @@ -196,27 +204,25 @@ impl EndpointsCache { anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name); } - let res = res.keys.pop().expect("Checked length above"); - let len = res.ids.len(); - for x in res.ids { + let key = res.keys.pop().expect("Checked length above"); + let len = key.ids.len(); + for stream_id in key.ids { total += 1; - for (_, v) in x.map { - let key = match Self::parse_key_value(&v) { - Ok(x) => x, - Err(e) => { + for value in stream_id.map.values() { + match value.try_into() { + Ok(event) => self.insert_event(event), + Err(err) => { Metrics::get().proxy.redis_errors_total.inc(RedisErrors { channel: &self.config.stream_name, }); - tracing::error!("error parsing value {v:?}: {e:?}"); - continue; + tracing::error!("error parsing value {value:?}: {err:?}"); } }; - self.insert_event(key); } if total.is_power_of_two() { tracing::debug!("endpoints read {}", total); } - *last_id = x.id; + *last_id = stream_id.id; } if return_when_finish && len <= self.config.default_batch_size { break; @@ -229,11 +235,11 @@ impl EndpointsCache { #[cfg(test)] mod tests { - use super::ControlPlaneEventKey; + use super::ControlPlaneEvent; #[test] - fn test() { - let s = "{\"branch_created\":null,\"endpoint_created\":{\"endpoint_id\":\"ep-rapid-thunder-w0qqw2q9\"},\"project_created\":null,\"type\":\"endpoint_created\"}"; - serde_json::from_str::(s).unwrap(); + fn test_parse_control_plane_event() { + let s = r#"{"branch_created":null,"endpoint_created":{"endpoint_id":"ep-rapid-thunder-w0qqw2q9"},"project_created":null,"type":"endpoint_created"}"#; + serde_json::from_str::(s).unwrap(); } }