From f6f514d7f1b7d1fddab60336e6a29fd901f05591 Mon Sep 17 00:00:00 2001 From: Fangdun Tsai Date: Tue, 1 Oct 2024 02:32:51 +0800 Subject: [PATCH] refactor(core): use BodyDataStream --- viz-core/src/body.rs | 42 +------------------------- viz-core/src/middleware/compression.rs | 7 ++++- viz-core/src/request.rs | 7 +++-- viz-core/src/types/multipart.rs | 3 +- viz-core/tests/response.rs | 4 +-- viz-test/tests/body.rs | 33 ++++++++++++-------- viz-test/tests/response.rs | 2 +- 7 files changed, 38 insertions(+), 60 deletions(-) diff --git a/viz-core/src/body.rs b/viz-core/src/body.rs index 52d82500..9426b74b 100644 --- a/viz-core/src/body.rs +++ b/viz-core/src/body.rs @@ -4,7 +4,7 @@ use std::{ }; use bytes::Bytes; -use futures_util::{Stream, TryStream, TryStreamExt}; +use futures_util::{TryStream, TryStreamExt}; use http_body_util::{combinators::UnsyncBoxBody, BodyExt, BodyStream, Full, StreamBody}; use hyper::body::{Frame, Incoming, SizeHint}; use sync_wrapper::SyncWrapper; @@ -53,7 +53,6 @@ impl Body { B::Data: Into, B::Error: Into, { - // Copied from Axum, thanks. let mut body = Some(body); ::downcast_mut::>>(&mut body) .and_then(Option::take) @@ -133,45 +132,6 @@ impl HttpBody for Body { } } -impl Stream for Body { - type Item = Result; - - #[inline] - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match match self.get_mut() { - Self::Empty => return Poll::Ready(None), - Self::Full(inner) => Pin::new(inner) - .poll_frame(cx) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?, - Self::Boxed(inner) => Pin::new(inner) - .get_pin_mut() - .poll_frame(cx) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?, - Self::Incoming(inner) => Pin::new(inner) - .poll_frame(cx) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?, - } { - Poll::Pending => Poll::Pending, - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(frame)) => Poll::Ready(frame.into_data().map(Ok).ok()), - } - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - let sh = match self { - Self::Empty => return (0, Some(0)), - Self::Full(inner) => inner.size_hint(), - Self::Boxed(_) => return (0, None), - Self::Incoming(inner) => inner.size_hint(), - }; - ( - usize::try_from(sh.lower()).unwrap_or(usize::MAX), - sh.upper().map(|v| usize::try_from(v).unwrap_or(usize::MAX)), - ) - } -} - impl From<()> for Body { fn from((): ()) -> Self { Self::Empty diff --git a/viz-core/src/middleware/compression.rs b/viz-core/src/middleware/compression.rs index cf623d70..2b5b4885 100644 --- a/viz-core/src/middleware/compression.rs +++ b/viz-core/src/middleware/compression.rs @@ -3,6 +3,8 @@ use std::str::FromStr; use async_compression::tokio::bufread; +use futures_util::TryStreamExt; +use http_body_util::BodyExt; use tokio_util::io::{ReaderStream, StreamReader}; use crate::{ @@ -78,7 +80,10 @@ impl IntoResponse for Compress { match self.algo { ContentCoding::Gzip | ContentCoding::Deflate | ContentCoding::Brotli => { res = res.map(|body| { - let body = StreamReader::new(body); + let body = StreamReader::new( + body.into_data_stream() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)), + ); if self.algo == ContentCoding::Gzip { Body::from_stream(ReaderStream::new(bufread::GzipEncoder::new(body))) } else if self.algo == ContentCoding::Deflate { diff --git a/viz-core/src/request.rs b/viz-core/src/request.rs index c248842c..a672703f 100644 --- a/viz-core/src/request.rs +++ b/viz-core/src/request.rs @@ -308,7 +308,10 @@ impl RequestExt for Request { .ok_or(PayloadError::MissingBoundary)? .as_str(); - Ok(Multipart::new(self.incoming()?, boundary)) + Ok(Multipart::new( + self.incoming()?.into_data_stream(), + boundary, + )) } #[cfg(feature = "state")] @@ -495,7 +498,7 @@ impl RequestLimitsExt for Request { .ok_or(PayloadError::MissingBoundary)? .as_str(); Ok(Multipart::with_limits( - self.incoming()?, + self.incoming()?.into_data_stream(), boundary, self.extensions() .get::>() diff --git a/viz-core/src/types/multipart.rs b/viz-core/src/types/multipart.rs index 5693a186..8f020db6 100644 --- a/viz-core/src/types/multipart.rs +++ b/viz-core/src/types/multipart.rs @@ -1,6 +1,7 @@ //! Represents a Multipart extractor. use form_data::FormData; +use http_body_util::BodyDataStream; use crate::{Body, Error, FromRequest, IntoResponse, Request, RequestExt, Response, StatusCode}; @@ -9,7 +10,7 @@ use super::{Payload, PayloadError}; pub use form_data::{Error as MultipartError, Limits as MultipartLimits}; /// Extracts the data from the multipart body of a request. -pub type Multipart = FormData; +pub type Multipart> = FormData; impl Payload for Multipart { const NAME: &'static str = "multipart"; diff --git a/viz-core/tests/response.rs b/viz-core/tests/response.rs index ea6e3b1e..6a8c459c 100644 --- a/viz-core/tests/response.rs +++ b/viz-core/tests/response.rs @@ -7,7 +7,7 @@ use test::Bencher; use futures_util::{stream, Stream, StreamExt}; use headers::{ContentDisposition, ContentType, HeaderMapExt}; -use http_body_util::{BodyExt, Full}; +use http_body_util::{BodyDataStream, BodyExt, Full}; use serde::{Deserialize, Serialize}; use viz_core::{ header::{CONTENT_DISPOSITION, CONTENT_LOCATION, LOCATION}, @@ -78,7 +78,7 @@ async fn response_ext() -> Result<()> { let resp = Response::stream(stream::repeat("viz").take(2).map(Result::<_, Error>::Ok)); assert!(resp.ok()); - let body: Body = resp.into_body(); + let body: BodyDataStream<_> = resp.into_body().into_data_stream(); assert_eq!(Stream::size_hint(&body), (0, None)); let (item, stream) = body.into_future().await; assert_eq!(item.unwrap().unwrap().to_vec(), b"viz"); diff --git a/viz-test/tests/body.rs b/viz-test/tests/body.rs index 60f3aff6..7194d990 100644 --- a/viz-test/tests/body.rs +++ b/viz-test/tests/body.rs @@ -63,18 +63,18 @@ async fn incoming_stream() -> Result<()> { use viz::Router; use viz_test::TestServer; - let empty = Body::Empty; + let empty = Body::Empty + .into_data_stream() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)); assert_eq!(Stream::size_hint(&empty), (0, Some(0))); - let mut reader = - TryStreamExt::map_err(empty, |e| std::io::Error::new(std::io::ErrorKind::Other, e)) - .into_async_read(); + let mut reader = empty.into_async_read(); let mut buf = Vec::new(); reader.read_to_end(&mut buf).await?; assert!(buf.is_empty()); let router = Router::new() .post("/login-empty", |mut req: Request| async move { - let mut body = req.incoming()?; + let mut body = req.incoming()?.into_data_stream(); let size_hint = Stream::size_hint(&body); assert_eq!(size_hint.0, 0); assert_eq!(size_hint.1, Some(0)); @@ -82,7 +82,7 @@ async fn incoming_stream() -> Result<()> { Ok(()) }) .post("/login", |mut req: Request| async move { - let mut body = req.incoming()?; + let mut body = req.incoming()?.into_data_stream(); let size_hint = Stream::size_hint(&body); assert_eq!(size_hint.0, 12); assert_eq!(size_hint.1, Some(12)); @@ -195,35 +195,44 @@ async fn outgoing_body() -> Result<()> { async fn outgoing_stream() -> Result<()> { use futures_util::{AsyncReadExt, Stream, StreamExt, TryStreamExt}; - let empty = Body::::Empty; + let empty = Body::::Empty + .into_data_stream() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)); assert_eq!(Stream::size_hint(&empty), (0, Some(0))); let mut reader = empty.into_async_read(); let mut buf = Vec::new(); reader.read_to_end(&mut buf).await?; assert!(buf.is_empty()); - let full_none = Body::from(Full::new(Bytes::new())); + let full_none = Body::from(Full::new(Bytes::new())) + .into_data_stream() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)); assert_eq!(Stream::size_hint(&full_none), (0, Some(0))); let mut reader = full_none.into_async_read(); let mut buf = Vec::new(); reader.read_to_end(&mut buf).await?; assert!(buf.is_empty()); - let mut full_some: Body = Full::new(Bytes::from(vec![1, 0, 2, 4])).into(); + let mut full_some = Full::new(Bytes::from(vec![1, 0, 2, 4])) + .into_data_stream() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)); assert_eq!(Stream::size_hint(&full_some), (4, Some(4))); assert_eq!(full_some.next().await.unwrap().unwrap(), vec![1, 0, 2, 4]); assert_eq!(Stream::size_hint(&full_some), (0, Some(0))); assert!(full_some.next().await.is_none()); - let boxed: Body = UnsyncBoxBody::new(Full::new(Bytes::new()).map_err(Into::into)).into(); + let boxed = UnsyncBoxBody::new(Full::new(Bytes::new())) + .into_data_stream() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)); assert_eq!(Stream::size_hint(&boxed), (0, None)); let mut reader = boxed.into_async_read(); let mut buf = Vec::new(); reader.read_to_end(&mut buf).await?; assert!(buf.is_empty()); - let mut boxed: Body = - UnsyncBoxBody::new(Full::new(Bytes::from(vec![2, 0, 4, 8])).map_err(Into::into)).into(); + let mut boxed = UnsyncBoxBody::new(Full::new(Bytes::from(vec![2, 0, 4, 8]))) + .into_data_stream() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)); assert_eq!(Stream::size_hint(&boxed), (0, None)); assert_eq!(boxed.next().await.unwrap().unwrap(), vec![2, 0, 4, 8]); assert_eq!(Stream::size_hint(&boxed), (0, None)); diff --git a/viz-test/tests/response.rs b/viz-test/tests/response.rs index d32a595b..743a85d6 100644 --- a/viz-test/tests/response.rs +++ b/viz-test/tests/response.rs @@ -71,7 +71,7 @@ async fn response_ext() -> Result<()> { let resp = Response::stream(stream::repeat("viz").take(2).map(Result::<_, Error>::Ok)); assert!(resp.ok()); - let body: Body = resp.into_body(); + let body = resp.into_body().into_data_stream(); assert_eq!(Stream::size_hint(&body), (0, None)); let (item, stream) = body.into_future().await; assert_eq!(item.unwrap().unwrap().to_vec(), b"viz");