From 2aef0751ba370cb03a99ecadede78680618d00ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Ml=C3=A1dek?= Date: Fri, 31 Jan 2025 18:00:29 +0100 Subject: [PATCH] axum: allow body types other than `axum::body::Body` in `Service`s passed to `serve` --- axum/CHANGELOG.md | 5 +++ axum/src/serve/mod.rs | 80 +++++++++++++++++++++++++++++++++---------- 2 files changed, 66 insertions(+), 19 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 9749052ff3..eb22c7741a 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +# Unreleased + +- **changed:** `serve` has an additional generic argument and can now work with any response body + type, not just `axum::body::Body` ([3205]) + # 0.8.2 - **added:** Implement `OptionalFromRequest` for `Json` ([#3142]) diff --git a/axum/src/serve/mod.rs b/axum/src/serve/mod.rs index ca673b44eb..156626b106 100644 --- a/axum/src/serve/mod.rs +++ b/axum/src/serve/mod.rs @@ -2,6 +2,7 @@ use std::{ convert::Infallible, + error::Error as StdError, fmt::Debug, future::{poll_fn, Future, IntoFuture}, io, @@ -11,6 +12,7 @@ use std::{ use axum_core::{body::Body, extract::Request, response::Response}; use futures_util::{pin_mut, FutureExt}; +use http_body::Body as HttpBody; use hyper::body::Incoming; use hyper_util::rt::{TokioExecutor, TokioIo}; #[cfg(any(feature = "http1", feature = "http2"))] @@ -94,12 +96,15 @@ pub use self::listener::{Listener, ListenerExt, TapIo}; /// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info /// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -pub fn serve(listener: L, make_service: M) -> Serve +pub fn serve(listener: L, make_service: M) -> Serve where L: Listener, M: for<'a> Service, Error = Infallible, Response = S>, - S: Service + Clone + Send + 'static, + S: Service, Error = Infallible> + Clone + Send + 'static, S::Future: Send, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into>, { Serve { listener, @@ -111,14 +116,14 @@ where /// Future returned by [`serve`]. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[must_use = "futures must be awaited or polled"] -pub struct Serve { +pub struct Serve { listener: L, make_service: M, - _marker: PhantomData, + _marker: PhantomData<(S, B)>, } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Serve +impl Serve where L: Listener, { @@ -148,7 +153,7 @@ where /// /// Similarly to [`serve`], although this future resolves to `io::Result<()>`, it will never /// error. It returns `Ok(())` only after the `signal` future completes. - pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown + pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown where F: Future + Send + 'static, { @@ -167,7 +172,7 @@ where } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Debug for Serve +impl Debug for Serve where L: Debug + 'static, M: Debug, @@ -188,14 +193,17 @@ where } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl IntoFuture for Serve +impl IntoFuture for Serve where L: Listener, L::Addr: Debug, M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, for<'a> >>::Future: Send, - S: Service + Clone + Send + 'static, + S: Service, Error = Infallible> + Clone + Send + 'static, S::Future: Send, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into>, { type Output = io::Result<()>; type IntoFuture = private::ServeFuture; @@ -209,15 +217,15 @@ where /// Serve future with graceful shutdown enabled. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[must_use = "futures must be awaited or polled"] -pub struct WithGracefulShutdown { +pub struct WithGracefulShutdown { listener: L, make_service: M, signal: F, - _marker: PhantomData, + _marker: PhantomData<(S, B)>, } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl WithGracefulShutdown +impl WithGracefulShutdown where L: Listener, { @@ -228,7 +236,7 @@ where } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Debug for WithGracefulShutdown +impl Debug for WithGracefulShutdown where L: Debug + 'static, M: Debug, @@ -252,15 +260,18 @@ where } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl IntoFuture for WithGracefulShutdown +impl IntoFuture for WithGracefulShutdown where L: Listener, L::Addr: Debug, M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, for<'a> >>::Future: Send, - S: Service + Clone + Send + 'static, + S: Service, Error = Infallible> + Clone + Send + 'static, S::Future: Send, F: Future + Send + 'static, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into>, { type Output = io::Result<()>; type IntoFuture = private::ServeFuture; @@ -274,15 +285,18 @@ where } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl WithGracefulShutdown +impl WithGracefulShutdown where L: Listener, L::Addr: Debug, M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, for<'a> >>::Future: Send, - S: Service + Clone + Send + 'static, + S: Service, Error = Infallible> + Clone + Send + 'static, S::Future: Send, F: Future + Send + 'static, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into>, { async fn run(self) { let Self { @@ -439,7 +453,7 @@ mod tests { }; use axum_core::{body::Body, extract::Request}; - use http::StatusCode; + use http::{Response, StatusCode}; use hyper_util::rt::TokioIo; #[cfg(unix)] use tokio::net::UnixListener; @@ -447,6 +461,7 @@ mod tests { io::{self, AsyncRead, AsyncWrite}, net::TcpListener, }; + use tower::ServiceBuilder; #[cfg(unix)] use super::IncomingStream; @@ -458,7 +473,7 @@ mod tests { handler::{Handler, HandlerWithoutStateExt}, routing::get, serve::ListenerExt, - Router, + Router, ServiceExt, }; #[allow(dead_code, unused_must_use)] @@ -686,4 +701,31 @@ mod tests { let body = String::from_utf8(body.to_vec()).unwrap(); assert_eq!(body, "Hello, World!"); } + + #[crate::test] + async fn serving_with_custom_body_type() { + struct CustomBody; + impl http_body::Body for CustomBody { + type Data = bytes::Bytes; + type Error = std::convert::Infallible; + fn poll_frame( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>>> + { + #![allow(clippy::unreachable)] // The implementation is not used, we just need to provide one. + unreachable!(); + } + } + + let app = ServiceBuilder::new() + .layer_fn(|_| tower::service_fn(|_| std::future::ready(Ok(Response::new(CustomBody))))) + .service(Router::<()>::new().route("/hello", get(|| async {}))); + let addr = "0.0.0.0:0"; + + _ = serve( + TcpListener::bind(addr).await.unwrap(), + app.into_make_service(), + ); + } }