Skip to content

Commit

Permalink
Switch from async-std and tide to tokio and axum
Browse files Browse the repository at this point in the history
  • Loading branch information
link2xt authored Aug 2, 2024
1 parent f45b046 commit a7d7312
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 1,178 deletions.
1,283 changes: 168 additions & 1,115 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@ license = "MIT OR Apache-2.0"
[dependencies]
a2 = { git = "https://github.com/WalletConnect/a2/", branch = "master" }
anyhow = "1.0.32"
async-std = { version = "1.9", features = ["tokio1", "attributes", "unstable"] }
axum = "0.7.5"
femme = "2.1.0"
humantime = "2.0.1"
log = "0.4.11"
prometheus-client = "0.22.2"
rand = "0.8.5"
reqwest = "0.12.4"
serde = { version = "1.0.114", features = ["derive"] }
serde_json = "1.0.122"
sled = "0.34.2"
structopt = "0.3.15"
tide = "0.16.0"
tokio = { version = "1.39.2", features = ["full"] }
yup-oauth2 = "9.0.0"

[dev-dependencies]
Expand Down
6 changes: 3 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct Opt {
fcm_key_path: String,
}

#[async_std::main]
#[tokio::main]
async fn main() -> Result<()> {
femme::start();

Expand All @@ -63,7 +63,7 @@ async fn main() -> Result<()> {

if let Some(metrics_address) = opt.metrics.clone() {
let state = state.clone();
async_std::task::spawn(async move { metrics::start(state, metrics_address).await });
tokio::task::spawn(async move { metrics::start(state, metrics_address).await });
}

// Setup mulitple parallel notifiers.
Expand All @@ -72,7 +72,7 @@ async fn main() -> Result<()> {
// and use the same HTTP/2 clients, one for production and one for sandbox server.
for _ in 0..50 {
let state = state.clone();
async_std::task::spawn(async move { notifier::start(state, interval).await });
tokio::task::spawn(async move { notifier::start(state, interval).await });
}

server::start(state, host, port).await?;
Expand Down
31 changes: 19 additions & 12 deletions src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
use std::sync::atomic::AtomicI64;

use anyhow::Result;
use axum::http::{header, HeaderMap};
use axum::response::IntoResponse;
use axum::routing::get;
use prometheus_client::encoding::text::encode;
use prometheus_client::metrics::counter::Counter;
use prometheus_client::metrics::gauge::Gauge;
use prometheus_client::registry::Registry;

use anyhow::Result;

use crate::state::State;

#[derive(Debug, Default)]
Expand Down Expand Up @@ -86,18 +88,23 @@ impl Metrics {
}

pub async fn start(state: State, server: String) -> Result<()> {
let mut app = tide::with_state(state);
app.at("/metrics").get(metrics);
app.listen(server).await?;
let app = axum::Router::new()
.route("/metrics", get(metrics))
.with_state(state);
let listener = tokio::net::TcpListener::bind(server).await?;
axum::serve(listener, app).await?;
Ok(())
}

async fn metrics(req: tide::Request<State>) -> tide::Result<tide::Response> {
async fn metrics(axum::extract::State(state): axum::extract::State<State>) -> impl IntoResponse {
let mut encoded = String::new();
encode(&mut encoded, &req.state().metrics().registry).unwrap();
let response = tide::Response::builder(tide::StatusCode::Ok)
.body(encoded)
.content_type("application/openmetrics-text; version=1.0.0; charset=utf-8")
.build();
Ok(response)
encode(&mut encoded, &state.metrics().registry).unwrap();
let mut headers = HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
"application/openmetrics-text; version=1.0.0; charset=utf-8"
.parse()
.unwrap(),
);
(headers, encoded)
}
6 changes: 3 additions & 3 deletions src/notifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub async fn start(state: State, interval: std::time::Duration) -> Result<()> {

let Some((timestamp, token)) = schedule.pop()? else {
info!("No tokens to notify, sleeping for a minute.");
async_std::task::sleep(Duration::from_secs(60)).await;
tokio::time::sleep(Duration::from_secs(60)).await;
continue;
};

Expand All @@ -49,7 +49,7 @@ pub async fn start(state: State, interval: std::time::Duration) -> Result<()> {
"Sleeping for {} before next notification.",
humantime::format_duration(delay)
);
async_std::task::sleep(delay).await;
tokio::time::sleep(delay).await;
}

if let Err(err) = wakeup(
Expand All @@ -66,7 +66,7 @@ pub async fn start(state: State, interval: std::time::Duration) -> Result<()> {

// Sleep to avoid busy looping and flooding APNS
// with requests in case of database errors.
async_std::task::sleep(Duration::from_secs(60)).await;
tokio::time::sleep(Duration::from_secs(60)).await;
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ mod tests {

use tempfile::tempdir;

#[async_std::test]
#[tokio::test]
async fn test_schedule() -> Result<()> {
let dir = tempdir()?;
let db_path = dir.path().join("db.sled");
Expand Down
108 changes: 67 additions & 41 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use a2::{
Priority, PushType,
};
use anyhow::{bail, Error, Result};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use log::*;
use serde::Deserialize;
use std::str::FromStr;
Expand All @@ -11,13 +14,13 @@ use crate::metrics::Metrics;
use crate::state::State;

pub async fn start(state: State, server: String, port: u16) -> Result<()> {
let mut app = tide::with_state(state);
app.at("/").get(|_| async { Ok("Hello, world!") });
app.at("/register").post(register_device);
app.at("/notify").post(notify_device);

info!("Listening on {server}:port");
app.listen((server, port)).await?;
let app = axum::Router::new()
.route("/", get(|| async { "Hello, world!" }))
.route("/register", post(register_device))
.route("/notify", post(notify_device))
.with_state(state);
let listener = tokio::net::TcpListener::bind((server, port)).await?;
axum::serve(listener, app).await?;
Ok(())
}

Expand All @@ -26,20 +29,44 @@ struct DeviceQuery {
token: String,
}

struct AppError(anyhow::Error);

impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}

impl IntoResponse for AppError {
fn into_response(self) -> Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Something went wrong: {}", self.0),
)
.into_response()
}
}

/// Registers a device for heartbeat notifications.
async fn register_device(mut req: tide::Request<State>) -> tide::Result<tide::Response> {
let query: DeviceQuery = req.body_json().await?;
async fn register_device(
axum::extract::State(state): axum::extract::State<State>,
body: String,
) -> Result<(), AppError> {
let query: DeviceQuery = serde_json::from_str(&body)?;
info!("register_device {}", query.token);

let schedule = req.state().schedule();
let schedule = state.schedule();
schedule.insert_token_now(&query.token)?;

// Flush database to ensure we don't lose this token in case of restart.
schedule.flush().await?;

req.state().metrics().heartbeat_registrations_total.inc();
state.metrics().heartbeat_registrations_total.inc();

Ok(tide::Response::new(tide::StatusCode::Ok))
Ok(())
}

enum NotificationToken {
Expand Down Expand Up @@ -90,17 +117,17 @@ async fn notify_fcm(
_package_name: &str,
token: &str,
metrics: &Metrics,
) -> tide::Result<tide::Response> {
) -> Result<StatusCode> {
let Some(fcm_api_key) = fcm_api_key else {
warn!("Cannot notify FCM because key is not set");
return Ok(tide::Response::new(tide::StatusCode::InternalServerError));
return Ok(StatusCode::INTERNAL_SERVER_ERROR);
};

if !token
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == ':' || c == '-')
{
return Ok(tide::Response::new(tide::StatusCode::Gone));
return Ok(StatusCode::GONE);
}

let url = "https://fcm.googleapis.com/v1/projects/delta-chat-fcm/messages:send";
Expand All @@ -118,23 +145,19 @@ async fn notify_fcm(
warn!("Failed to deliver FCM notification to {token}");
warn!("BODY: {body:?}");
warn!("RES: {res:?}");
return Ok(tide::Response::new(tide::StatusCode::Gone));
return Ok(StatusCode::GONE);
}
if status.is_server_error() {
warn!("Internal server error while attempting to deliver FCM notification to {token}");
return Ok(tide::Response::new(tide::StatusCode::InternalServerError));
return Ok(StatusCode::INTERNAL_SERVER_ERROR);
}
info!("Delivered notification to FCM token {token}");
metrics.fcm_notifications_total.inc();
Ok(tide::Response::new(tide::StatusCode::Ok))
Ok(StatusCode::OK)
}

async fn notify_apns(
req: tide::Request<State>,
client: a2::Client,
device_token: String,
) -> tide::Result<tide::Response> {
let schedule = req.state().schedule();
async fn notify_apns(state: State, client: a2::Client, device_token: String) -> Result<StatusCode> {
let schedule = state.schedule();
let payload = DefaultNotificationBuilder::new()
.set_title("New messages")
.set_title_loc_key("new_messages") // Localization key for the title.
Expand All @@ -148,7 +171,7 @@ async fn notify_apns(
// High priority (10).
// <https://developer.apple.com/documentation/usernotifications/sending-notification-requests-to-apns>
apns_priority: Some(Priority::High),
apns_topic: req.state().topic(),
apns_topic: state.topic(),
apns_push_type: Some(PushType::Alert),
..Default::default()
},
Expand All @@ -159,14 +182,14 @@ async fn notify_apns(
match res.code {
200 => {
info!("delivered notification for {}", device_token);
req.state().metrics().direct_notifications_total.inc();
state.metrics().direct_notifications_total.inc();
}
_ => {
warn!("unexpected status: {:?}", res);
}
}

Ok(tide::Response::new(tide::StatusCode::Ok))
Ok(StatusCode::OK)
}
Err(ResponseError(res)) => {
info!("Removing token {} due to error {:?}.", &device_token, res);
Expand All @@ -179,21 +202,23 @@ async fn notify_apns(
error!("failed to remove {}: {:?}", &device_token, err);
}
// Return 410 Gone response so email server can remove the token.
Ok(tide::Response::new(tide::StatusCode::Gone))
Ok(StatusCode::GONE)
} else {
Ok(tide::Response::new(tide::StatusCode::InternalServerError))
Ok(StatusCode::INTERNAL_SERVER_ERROR)
}
}
Err(err) => {
error!("failed to send notification: {}, {:?}", device_token, err);
Ok(tide::Response::new(tide::StatusCode::InternalServerError))
Ok(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}

/// Notifies a single device with a visible notification.
async fn notify_device(mut req: tide::Request<State>) -> tide::Result<tide::Response> {
let device_token = req.body_string().await?;
async fn notify_device(
axum::extract::State(state): axum::extract::State<State>,
device_token: String,
) -> Result<StatusCode, AppError> {
info!("Got direct notification for {device_token}.");

let device_token: NotificationToken = device_token.as_str().parse()?;
Expand All @@ -203,27 +228,28 @@ async fn notify_device(mut req: tide::Request<State>) -> tide::Result<tide::Resp
package_name,
token,
} => {
let client = req.state().fcm_client().clone();
let Ok(fcm_token) = req.state().fcm_token().await else {
return Ok(tide::Response::new(tide::StatusCode::InternalServerError));
let client = state.fcm_client().clone();
let Ok(fcm_token) = state.fcm_token().await else {
return Ok(StatusCode::INTERNAL_SERVER_ERROR);
};
let metrics = req.state().metrics();
let metrics = state.metrics();
notify_fcm(
&client,
fcm_token.as_deref(),
&package_name,
&token,
metrics,
)
.await
.await?;
}
NotificationToken::ApnsSandbox(token) => {
let client = req.state().sandbox_client().clone();
notify_apns(req, client, token).await
let client = state.sandbox_client().clone();
notify_apns(state, client, token).await?;
}
NotificationToken::ApnsProduction(token) => {
let client = req.state().production_client().clone();
notify_apns(req, client, token).await
let client = state.production_client().clone();
notify_apns(state, client, token).await?;
}
}
Ok(StatusCode::OK)
}
2 changes: 1 addition & 1 deletion src/state.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::io::Seek;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;

use a2::{Client, Endpoint};
use anyhow::{Context as _, Result};
use async_std::sync::Arc;

use crate::metrics::Metrics;
use crate::schedule::Schedule;
Expand Down

0 comments on commit a7d7312

Please sign in to comment.