diff --git a/src/main.rs b/src/main.rs index 600e8d6..4ddfda9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,17 @@ #![warn(clippy::all, clippy::pedantic, clippy::nursery)] +use std::fmt::{Display, Formatter}; use std::{ net::{IpAddr, SocketAddr}, str::FromStr, sync::Arc, }; +use axum::extract::FromRequestParts; use axum::http::header::{ACCESS_CONTROL_ALLOW_ORIGIN, CACHE_CONTROL}; +use axum::http::request::Parts; use axum::{ - extract::{ConnectInfo, Request, State}, + extract::{ConnectInfo, Request}, http::{HeaderMap, HeaderName, HeaderValue, StatusCode}, middleware::Next, response::{IntoResponse, Response}, @@ -23,9 +26,9 @@ async fn main() { .unwrap_or_else(|_| 8080.to_string()) .parse::() .unwrap_or(8080); + let client_ip_var = std::env::var("CLIENT_IP_HEADER").ok(); let v4_addr = SocketAddr::from(([0; 4], port)); let v6_addr = SocketAddr::from(([0; 16], port)); - let client_ip_var = std::env::var("CLIENT_IP_HEADER").ok(); let state = AppState::new(client_ip_var); let app = Router::new() .route("/raw", any(raw)) @@ -48,15 +51,10 @@ async fn svc(tcp: TcpListener, app: Router) { .unwrap(); } #[allow(clippy::unused_async)] -async fn home( - ConnectInfo(sock_addr): ConnectInfo, - headers: HeaderMap, - State(state): State, -) -> Result { +async fn home(IpAddress(ip): IpAddress, headers: HeaderMap) -> Result { let accept = headers .get("Accept") .map_or("*/*", |x| x.to_str().unwrap_or("invalid header value")); - let ip = get_ip(sock_addr, &headers, state)?; if accept.contains("text/html") { Ok(HtmlOrRaw::Html(include_str!("index.html"))) } else { @@ -65,28 +63,10 @@ async fn home( } #[allow(clippy::unused_async)] -async fn raw( - ConnectInfo(sock_addr): ConnectInfo, - headers: HeaderMap, - State(state): State, -) -> Result { - let ip = get_ip(sock_addr, &headers, state)?; +async fn raw(IpAddress(ip): IpAddress) -> Result { Ok(format!("{ip}\n")) } -fn get_ip(addr: SocketAddr, headers: &HeaderMap, state: AppState) -> Result { - if let Some(header_name) = state.header { - if let Some(header) = headers.get(&*header_name) { - let sock_str = header.to_str()?; - Ok(IpAddr::from_str(sock_str)?) - } else { - Err(Error::NoHeader) - } - } else { - Ok(addr.ip()) - } -} - static CORS_STAR: HeaderValue = HeaderValue::from_static("*"); async fn nocors(request: Request, next: Next) -> Response { @@ -145,13 +125,47 @@ impl IntoResponse for HtmlOrRaw { } } +#[derive(Clone, Debug)] +pub struct IpAddress(IpAddr); + +impl Display for IpAddress { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[axum::async_trait] +impl FromRequestParts for IpAddress { + type Rejection = Error; + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + if let Some(header_name) = state.header.clone() { + if let Some(header) = parts.headers.get(&*header_name) { + let sock_str = header.to_str()?; + Ok(Self(IpAddr::from_str(sock_str)?)) + } else { + Err(Error::NoHeader) + } + } else { + let conn_info: ConnectInfo = ConnectInfo::from_request_parts(parts, state) + .await + .map_err(|_| Error::ConnectInfo)?; + Ok(Self(conn_info.0.ip())) + } + } +} + #[derive(thiserror::Error, Debug)] pub enum Error { #[error("No header found")] NoHeader, - #[error("Could not convert supplied header to string")] + #[error("Could not extract connection info")] + ConnectInfo, + #[error("Could not convert supplied header to string (this is a configuration issue)")] ToStr(#[from] axum::http::header::ToStrError), - #[error("Could not convert supplied header to IP address")] + #[error("Could not convert supplied header to IP address (this is a configuration issue)")] ToAddr(#[from] std::net::AddrParseError), }