diff --git a/Cargo.lock b/Cargo.lock index d415ce7..84a5f0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -169,7 +169,7 @@ dependencies = [ "async-lock", "async-task", "concurrent-queue", - "fastrand", + "fastrand 1.9.0", "futures-lite", "slab", ] @@ -220,7 +220,7 @@ dependencies = [ "log", "parking", "polling", - "rustix", + "rustix 0.37.23", "slab", "socket2 0.4.9", "waker-fn", @@ -248,7 +248,7 @@ dependencies = [ "cfg-if 1.0.0", "event-listener", "futures-lite", - "rustix", + "rustix 0.37.23", "signal-hook", "windows-sys 0.48.0", ] @@ -456,7 +456,7 @@ dependencies = [ "async-lock", "async-task", "atomic-waker", - "fastrand", + "fastrand 1.9.0", "futures-lite", "log", ] @@ -724,6 +724,12 @@ dependencies = [ "instant", ] +[[package]] +name = "fastrand" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" + [[package]] name = "femme" version = "2.2.1" @@ -807,7 +813,7 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49a9d51ce47660b1e808d3c990b4709f2f415d928835a17dfd16991515c46bce" dependencies = [ - "fastrand", + "fastrand 1.9.0", "futures-core", "futures-io", "memchr", @@ -1225,6 +1231,12 @@ version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" +[[package]] +name = "linux-raw-sys" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" + [[package]] name = "lock_api" version = "0.4.10" @@ -1295,6 +1307,7 @@ dependencies = [ "serde", "sled", "structopt", + "tempfile", "tide", ] @@ -1734,7 +1747,20 @@ dependencies = [ "errno", "io-lifetimes", "libc", - "linux-raw-sys", + "linux-raw-sys 0.3.8", + "windows-sys 0.48.0", +] + +[[package]] +name = "rustix" +version = "0.38.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b426b0506e5d50a7d8dafcf2e81471400deb602392c7dd110815afb4eaf02a3" +dependencies = [ + "bitflags 2.4.0", + "errno", + "libc", + "linux-raw-sys 0.4.13", "windows-sys 0.48.0", ] @@ -2162,6 +2188,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tempfile" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ef1adac450ad7f4b3c28589471ade84f25f731a7a0fe30d71dfa9f60fd808e5" +dependencies = [ + "cfg-if 1.0.0", + "fastrand 2.0.2", + "redox_syscall 0.4.1", + "rustix 0.38.21", + "windows-sys 0.48.0", +] + [[package]] name = "textwrap" version = "0.11.0" diff --git a/Cargo.toml b/Cargo.toml index 43e3d26..1fca882 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,3 +20,6 @@ tide = "0.16.0" # Workaround for hyper = { version = "0.14", features = ["tcp"] } + +[dev-dependencies] +tempfile = "3" diff --git a/src/lib.rs b/src/lib.rs index c365371..b058073 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod metrics; pub mod notifier; +pub mod schedule; pub mod server; pub mod state; diff --git a/src/main.rs b/src/main.rs index b258131..b3c5617 100644 --- a/src/main.rs +++ b/src/main.rs @@ -50,6 +50,7 @@ async fn main() -> Result<()> { &opt.password, opt.topic.clone(), metrics_state.clone(), + opt.interval, )?; let state2 = state.clone(); diff --git a/src/notifier.rs b/src/notifier.rs index f1a175b..257009c 100644 --- a/src/notifier.rs +++ b/src/notifier.rs @@ -1,15 +1,18 @@ +use std::time::{Duration, SystemTime}; + use a2::{ Client, DefaultNotificationBuilder, Error::ResponseError, NotificationBuilder, NotificationOptions, Priority, }; -use anyhow::Result; +use anyhow::{bail, Context as _, Result}; use log::*; use crate::metrics::Metrics; +use crate::schedule::Schedule; use crate::state::State; pub async fn start(state: State, interval: std::time::Duration) -> Result<()> { - let db = state.db(); + let schedule = state.schedule(); let metrics = state.metrics(); let production_client = state.production_client(); let sandbox_client = state.sandbox_client(); @@ -21,88 +24,114 @@ pub async fn start(state: State, interval: std::time::Duration) -> Result<()> { ); loop { - let wakeup_start = std::time::Instant::now(); - wakeup(db, metrics, production_client, sandbox_client, topic).await; - let elapsed = wakeup_start.elapsed(); - info!( - "Waking up all devices took {}", - humantime::format_duration(elapsed) - ); - async_std::task::sleep(interval.saturating_sub(elapsed)).await; + metrics + .heartbeat_token_count + .set(schedule.token_count() as i64); + + let Some((timestamp, token)) = schedule.pop() else { + info!("No tokens to notify, sleeping for a minute."); + async_std::task::sleep(Duration::from_secs(60)).await; + continue; + }; + + // Sleep until we need to notify the token. + let now = SystemTime::now(); + let timestamp: SystemTime = SystemTime::UNIX_EPOCH + .checked_add(Duration::from_secs(timestamp)) + .unwrap_or(now); + let timestamp = std::cmp::min(timestamp, now); + let delay = timestamp + .checked_add(interval) + .unwrap_or(now) + .duration_since(now) + .unwrap_or_default(); + async_std::task::sleep(delay).await; + + if let Err(err) = wakeup( + schedule, + metrics, + production_client, + sandbox_client, + topic, + token, + ) + .await + { + error!("Failed to notify token: {err:#}"); + + // Sleep to avoid busy looping and flooding APNS + // with requests in case of database errors. + async_std::task::sleep(Duration::from_secs(60)).await; + } } } async fn wakeup( - db: &sled::Db, + schedule: &Schedule, metrics: &Metrics, production_client: &Client, sandbox_client: &Client, topic: Option<&str>, -) { - let tokens = db - .iter() - .filter_map(|entry| match entry { - Ok((key, _)) => Some(String::from_utf8(key.to_vec()).unwrap()), - Err(_) => None, - }) - .collect::>(); - - info!("sending notifications to {} devices", tokens.len()); - metrics.heartbeat_token_count.set(tokens.len() as i64); + key_device_token: String, +) -> Result<()> { + info!("notify: {}", key_device_token); - for key_device_token in tokens { - info!("notify: {}", key_device_token); + let (client, device_token) = + if let Some(sandbox_token) = key_device_token.strip_prefix("sandbox:") { + (sandbox_client, sandbox_token) + } else { + (production_client, key_device_token.as_str()) + }; - let (client, device_token) = - if let Some(sandbox_token) = key_device_token.strip_prefix("sandbox:") { - (sandbox_client, sandbox_token) - } else { - (production_client, key_device_token.as_str()) - }; - - // Send silent notification. - // According to - // to send a silent notification you need to set background notification flag `content-available` to 1 - // and don't include `alert`, `badge` or `sound`. - let payload = DefaultNotificationBuilder::new() - .set_content_available() - .build( - device_token, - NotificationOptions { - // Normal priority (5) means - // "send the notification based on power considerations on the user’s device". - // - apns_priority: Some(Priority::Normal), - apns_topic: topic, - ..Default::default() - }, - ); - - match client.send(payload).await { - Ok(res) => match res.code { - 200 => { - info!("delivered notification for {}", device_token); - metrics.heartbeat_notifications_total.inc(); - } - _ => { - warn!("unexpected status: {:?}", res); - } + // Send silent notification. + // According to + // to send a silent notification you need to set background notification flag `content-available` to 1 + // and don't include `alert`, `badge` or `sound`. + let payload = DefaultNotificationBuilder::new() + .set_content_available() + .build( + device_token, + NotificationOptions { + // Normal priority (5) means + // "send the notification based on power considerations on the user’s device". + // + apns_priority: Some(Priority::Normal), + apns_topic: topic, + ..Default::default() }, - Err(ResponseError(res)) => { - info!( - "Removing token {} due to error {:?}.", - &key_device_token, res - ); - if let Err(err) = db.remove(&key_device_token) { - error!("failed to remove {}: {:?}", &key_device_token, err); - } + ); + + match client.send(payload).await { + Ok(res) => match res.code { + 200 => { + info!("delivered notification for {}", device_token); + schedule + .insert_token_now(&key_device_token) + .await + .context("Failed to update latest notification timestamp")?; + metrics.heartbeat_notifications_total.inc(); } - Err(err) => { - error!( - "failed to send notification: {}, {:?}", - key_device_token, err - ); + _ => { + bail!("unexpected status: {:?}", res); } + }, + Err(ResponseError(res)) => { + info!( + "Removing token {} due to error {:?}.", + &key_device_token, res + ); + schedule + .remove_token(&key_device_token) + .with_context(|| format!("Failed to remove {}", &key_device_token))?; + } + Err(err) => { + // Update notification time regardless of success + // to avoid busy looping. + schedule + .insert_token_now(&key_device_token) + .await + .with_context(|| format!("Failed to update token timestamp: {err:?}"))?; } } + Ok(()) } diff --git a/src/schedule.rs b/src/schedule.rs new file mode 100644 index 0000000..bac9fb3 --- /dev/null +++ b/src/schedule.rs @@ -0,0 +1,117 @@ +use std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::path::Path; +use std::sync::Mutex; +use std::time::SystemTime; + +use anyhow::Result; + +#[derive(Debug)] +pub struct Schedule { + /// Database to persist tokens and latest notification time. + db: sled::Db, + + /// Min-heap of tokens prioritized by the latest notification timestamp. + heap: Mutex, String)>>, +} + +impl Schedule { + pub fn new(db_path: &Path) -> Result { + let db = sled::open(db_path)?; + let mut heap = BinaryHeap::new(); + for entry in db.iter() { + let (key, value) = entry?; + let token = String::from_utf8(key.to_vec()).unwrap(); + + let timestamp = if let Some(value) = value.get(..8) { + let mut buf: [u8; 8] = [0; 8]; + buf.copy_from_slice(&value[..8]); + u64::from_be_bytes(buf) + } else { + 0 + }; + heap.push((Reverse(timestamp), token)) + } + let heap = Mutex::new(heap); + Ok(Self { db, heap }) + } + + /// Registers a new heartbeat notification token. + /// + /// This should also be called after successful notification + /// to update latest notification time. + pub async fn insert_token(&self, token: &str, now: u64) -> Result<()> { + self.db.insert(token.as_bytes(), &u64::to_be_bytes(now))?; + self.db.flush_async().await?; + let mut heap = self.heap.lock().unwrap(); + heap.push((Reverse(now), token.to_owned())); + Ok(()) + } + + pub async fn insert_token_now(&self, token: &str) -> Result<()> { + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + self.insert_token(token, now).await + } + + /// Removes token from the schedule. + pub fn remove_token(&self, token: &str) -> Result<()> { + self.db.remove(token)?; + Ok(()) + } + + pub fn pop(&self) -> Option<(u64, String)> { + let mut heap = self.heap.lock().unwrap(); + let (timestamp, token) = heap.pop()?; + Some((timestamp.0, token)) + } + + /// Returns the number of tokens in the schedule. + pub fn token_count(&self) -> usize { + let heap = self.heap.lock().unwrap(); + heap.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use tempfile::tempdir; + + #[async_std::test] + async fn test_schedule() -> Result<()> { + let dir = tempdir()?; + let db_path = dir.path().join("db.sled"); + let schedule = Schedule::new(&db_path)?; + + schedule.insert_token("foo", 10).await?; + schedule.insert_token("bar", 20).await?; + + let (first_timestamp, first_token) = schedule.pop().unwrap(); + assert_eq!(first_timestamp, 10); + assert_eq!(first_token, "foo"); + schedule.insert_token("foo", 30).await?; + + // Reopen to test persistence. + drop(schedule); + let schedule = Schedule::new(&db_path)?; + + let (second_timestamp, second_token) = schedule.pop().unwrap(); + assert_eq!(second_timestamp, 20); + assert_eq!(second_token, "bar"); + + // Simulate restart or crash, token "bar" was not reinserted or removed by the app. + drop(schedule); + let schedule = Schedule::new(&db_path)?; + + // Token "bar" is still there. + let (second_timestamp, second_token) = schedule.pop().unwrap(); + assert_eq!(second_timestamp, 20); + assert_eq!(second_token, "bar"); + + Ok(()) + } +} diff --git a/src/server.rs b/src/server.rs index 6505f56..da4a17e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -29,9 +29,8 @@ async fn register_device(mut req: tide::Request) -> tide::Result) -> tide::Result) -> tide::Result // // Unsubscribe invalid token from heartbeat notification if it is subscribed. - if let Err(err) = db.remove(device_token) { + if let Err(err) = schedule.remove_token(device_token) { error!("failed to remove {}: {:?}", &device_token, err); } // Return 410 Gone response so email server can remove the token. diff --git a/src/state.rs b/src/state.rs index 0c58495..860a359 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,12 +1,13 @@ +use std::io::Seek; use std::path::Path; +use std::time::Duration; use a2::{Client, Endpoint}; use anyhow::{Context as _, Result}; use async_std::sync::Arc; -use log::*; -use std::io::Seek; use crate::metrics::Metrics; +use crate::schedule::Schedule; #[derive(Debug, Clone)] pub struct State { @@ -15,7 +16,7 @@ pub struct State { #[derive(Debug)] pub struct InnerState { - db: sled::Db, + schedule: Schedule, production_client: Client, @@ -24,6 +25,9 @@ pub struct InnerState { topic: Option, metrics: Arc, + + /// Heartbeat notification interval. + interval: Duration, } impl State { @@ -33,8 +37,9 @@ impl State { password: &str, topic: Option, metrics: Arc, + interval: Duration, ) -> Result { - let db = sled::open(db)?; + let schedule = Schedule::new(db)?; let production_client = Client::certificate(&mut certificate, password, Endpoint::Production) .context("Failed to create production client")?; @@ -42,21 +47,20 @@ impl State { let sandbox_client = Client::certificate(&mut certificate, password, Endpoint::Sandbox) .context("Failed to create sandbox client")?; - info!("{} devices registered currently", db.len()); - Ok(State { inner: Arc::new(InnerState { - db, + schedule, production_client, sandbox_client, topic, metrics, + interval, }), }) } - pub fn db(&self) -> &sled::Db { - &self.inner.db + pub fn schedule(&self) -> &Schedule { + &self.inner.schedule } pub fn production_client(&self) -> &Client { @@ -74,4 +78,8 @@ impl State { pub fn metrics(&self) -> &Metrics { self.inner.metrics.as_ref() } + + pub fn interval(&self) -> Duration { + self.inner.interval + } }