Skip to content

Commit c1b7ff2

Browse files
authored
fix: flaky package extract error (#535)
This adds a test to check that the flaky download behavior has been fixed. The flakiness was caused by a Read implementation that consistently returned an error at a certain point in the stream, which caused the zip crate to panic.
1 parent 54a0e43 commit c1b7ff2

File tree

5 files changed

+142
-20
lines changed

5 files changed

+142
-20
lines changed

crates/rattler/src/package_cache.rs

+75-11
Original file line numberDiff line numberDiff line change
@@ -311,11 +311,16 @@ mod test {
311311
routing::get_service,
312312
Router,
313313
};
314+
use bytes::Bytes;
315+
use futures::stream;
314316
use rattler_conda_types::package::{ArchiveIdentifier, PackageFile, PathsJson};
315317
use rattler_networking::retry_policies::{DoNotRetryPolicy, ExponentialBackoffBuilder};
316-
use std::{fs::File, future::IntoFuture, net::SocketAddr, path::Path, sync::Arc};
318+
use std::{
319+
convert::Infallible, fs::File, future::IntoFuture, net::SocketAddr, path::Path, sync::Arc,
320+
};
317321
use tempfile::tempdir;
318322
use tokio::sync::Mutex;
323+
use tokio_stream::StreamExt;
319324
use tower_http::services::ServeDir;
320325
use url::Url;
321326

@@ -382,22 +387,69 @@ mod test {
382387
Ok(next.run(req).await)
383388
}
384389

385-
#[tokio::test]
386-
pub async fn test_flaky_package_cache() {
390+
/// A helper middleware function that fails the first two requests.
391+
async fn fail_with_half_package(
392+
State((count, bytes)): State<(Arc<Mutex<i32>>, Arc<Mutex<usize>>)>,
393+
req: Request<Body>,
394+
next: Next,
395+
) -> Result<Response, StatusCode> {
396+
let count = {
397+
let mut count = count.lock().await;
398+
*count += 1;
399+
*count
400+
};
401+
402+
println!("Running middleware for request #{count} for {}", req.uri());
403+
let response = next.run(req).await;
404+
405+
if count <= 2 {
406+
// println!("Cutting response body in half");
407+
let body = response.into_body();
408+
let mut body = body.into_data_stream();
409+
let mut buffer = Vec::new();
410+
while let Some(Ok(chunk)) = body.next().await {
411+
buffer.extend(chunk);
412+
}
413+
414+
let byte_count = bytes.lock().await.clone();
415+
let bytes = buffer.into_iter().take(byte_count).collect::<Vec<u8>>();
416+
// Create a stream that ends prematurely
417+
let stream = stream::iter(vec![
418+
Ok::<_, Infallible>(Bytes::from_iter(bytes.into_iter())),
419+
// The stream ends after sending partial data, simulating a premature close
420+
]);
421+
let body = Body::from_stream(stream);
422+
return Ok(Response::new(body));
423+
}
424+
425+
Ok(response)
426+
}
427+
428+
enum Middleware {
429+
FailTheFirstTwoRequests,
430+
FailAfterBytes(usize),
431+
}
432+
433+
async fn test_flaky_package_cache(archive_name: &str, middleware: Middleware) {
387434
let static_dir = get_test_data_dir();
388435
println!("Serving files from {}", static_dir.display());
389436
// Construct a service that serves raw files from the test directory
390437
let service = get_service(ServeDir::new(static_dir));
391438

392439
// Construct a router that returns data from the static dir but fails the first try.
393440
let request_count = Arc::new(Mutex::new(0));
394-
let router =
395-
Router::new()
396-
.route_service("/*key", service)
397-
.layer(middleware::from_fn_with_state(
398-
request_count.clone(),
399-
fail_the_first_two_requests,
400-
));
441+
let router = Router::new().route_service("/*key", service);
442+
443+
let router = match middleware {
444+
Middleware::FailTheFirstTwoRequests => router.layer(middleware::from_fn_with_state(
445+
request_count.clone(),
446+
fail_the_first_two_requests,
447+
)),
448+
Middleware::FailAfterBytes(size) => router.layer(middleware::from_fn_with_state(
449+
(request_count.clone(), Arc::new(Mutex::new(size))),
450+
fail_with_half_package,
451+
)),
452+
};
401453

402454
// Construct the server that will listen on localhost but with a *random port*. The random
403455
// port is very important because it enables creating multiple instances at the same time.
@@ -412,7 +464,6 @@ mod test {
412464
let packages_dir = tempdir().unwrap();
413465
let cache = PackageCache::new(packages_dir.path());
414466

415-
let archive_name = "ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2";
416467
let server_url = Url::parse(&format!("http://localhost:{}", addr.port())).unwrap();
417468

418469
// Do the first request without
@@ -448,4 +499,17 @@ mod test {
448499
assert_eq!(*request_count_lock, 3, "Expected there to be 3 requests");
449500
}
450501
}
502+
503+
#[tokio::test]
504+
async fn test_flaky() {
505+
let tar_bz2 = "ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2";
506+
let conda = "conda-22.11.1-py38haa244fe_1.conda";
507+
508+
test_flaky_package_cache(tar_bz2, Middleware::FailTheFirstTwoRequests).await;
509+
test_flaky_package_cache(conda, Middleware::FailTheFirstTwoRequests).await;
510+
511+
test_flaky_package_cache(tar_bz2, Middleware::FailAfterBytes(1000)).await;
512+
test_flaky_package_cache(conda, Middleware::FailAfterBytes(1000)).await;
513+
test_flaky_package_cache(conda, Middleware::FailAfterBytes(50)).await;
514+
}
451515
}

crates/rattler_package_streaming/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,4 @@ tokio = { workspace = true, features = ["rt", "macros"] }
4343
walkdir = { workspace = true }
4444
rstest = { workspace = true }
4545
rstest_reuse = { workspace = true }
46+
assert_matches = { workspace = true }

crates/rattler_package_streaming/src/lib.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
//! This crate provides the ability to extract a Conda package archive or specific parts of it.
44
55
use std::path::PathBuf;
6+
use zip::result::ZipError;
67

78
use rattler_digest::{Md5Hash, Sha256Hash};
89

@@ -27,7 +28,7 @@ pub enum ExtractError {
2728
CouldNotCreateDestination(#[source] std::io::Error),
2829

2930
#[error("invalid zip archive")]
30-
ZipError(#[from] zip::result::ZipError),
31+
ZipError(#[source] zip::result::ZipError),
3132

3233
#[error("a component is missing from the Conda archive")]
3334
MissingComponent,
@@ -49,6 +50,15 @@ pub enum ExtractError {
4950
ArchiveMemberParseError(PathBuf, #[source] std::io::Error),
5051
}
5152

53+
impl From<ZipError> for ExtractError {
54+
fn from(value: ZipError) -> Self {
55+
match value {
56+
ZipError::Io(io) => Self::IoError(io),
57+
e => Self::ZipError(e),
58+
}
59+
}
60+
}
61+
5262
#[cfg(feature = "reqwest")]
5363
impl From<::reqwest::Error> for ExtractError {
5464
fn from(err: ::reqwest::Error) -> Self {

crates/rattler_package_streaming/src/read.rs

+14-8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
//! [`std::io::Read`] trait.
33
44
use super::{ExtractError, ExtractResult};
5+
use std::mem::ManuallyDrop;
56
use std::{ffi::OsStr, io::Read, path::Path};
67
use zip::read::read_zipfile_from_stream;
78

@@ -55,24 +56,29 @@ pub fn extract_conda(reader: impl Read, destination: &Path) -> Result<ExtractRes
5556

5657
// Iterate over all entries in the zip-file and extract them one-by-one
5758
while let Some(file) = read_zipfile_from_stream(&mut md5_reader)? {
59+
// If an error occurs while we are reading the contents of the zip we don't want to
60+
// seek to the end of the file. Using [`ManuallyDrop`] we prevent `drop` to be called on
61+
// the `file` in case the stack unwinds.
62+
let mut file = ManuallyDrop::new(file);
63+
5864
if file
5965
.mangled_name()
6066
.file_name()
6167
.map(OsStr::to_string_lossy)
6268
.map_or(false, |file_name| file_name.ends_with(".tar.zst"))
6369
{
64-
stream_tar_zst(file)?.unpack(destination)?;
70+
stream_tar_zst(&mut *file)?.unpack(destination)?;
71+
} else {
72+
// Manually read to the end of the stream if that didn't happen.
73+
std::io::copy(&mut *file, &mut std::io::sink())?;
6574
}
75+
76+
// Take the file out of the [`ManuallyDrop`] to properly drop it.
77+
let _ = ManuallyDrop::into_inner(file);
6678
}
6779

6880
// Read the file to the end to make sure the hash is properly computed.
69-
let mut buf = [0; 1 << 14];
70-
loop {
71-
let bytes_read = md5_reader.read(&mut buf)?;
72-
if bytes_read == 0 {
73-
break;
74-
}
75-
}
81+
std::io::copy(&mut md5_reader, &mut std::io::sink())?;
7682

7783
// Get the hashes
7884
let (sha256_reader, md5) = md5_reader.finalize();

crates/rattler_package_streaming/tests/extract.rs

+41
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use rattler_conda_types::package::IndexJson;
22
use rattler_package_streaming::read::{extract_conda, extract_tar_bz2};
3+
use rattler_package_streaming::ExtractError;
34
use rstest::rstest;
45
use rstest_reuse::{self, apply, template};
56
use std::fs::File;
7+
use std::io::Read;
68
use std::path::{Path, PathBuf};
79

810
fn test_data_dir() -> PathBuf {
@@ -235,3 +237,42 @@ async fn test_extract_url_async(#[case] url: &str, #[case] sha256: &str, #[case]
235237
assert_eq!(&format!("{:x}", result.sha256), sha256);
236238
assert_eq!(&format!("{:x}", result.md5), md5);
237239
}
240+
241+
#[rstest]
242+
fn test_extract_flaky_conda(#[values(0, 1, 13, 50, 74, 150, 8096, 16384, 20000)] cutoff: usize) {
243+
let input = "conda-22.11.1-py38haa244fe_1.conda";
244+
let temp_dir = Path::new(env!("CARGO_TARGET_TMPDIR"));
245+
println!("Target dir: {}", temp_dir.display());
246+
let file_path = Path::new(input);
247+
let target_dir = temp_dir.join(file_path.file_stem().unwrap());
248+
let result = extract_conda(
249+
FlakyReader {
250+
reader: File::open(test_data_dir().join(file_path)).unwrap(),
251+
total_read: 0,
252+
cutoff,
253+
},
254+
&target_dir,
255+
)
256+
.expect_err("this should error out and not panic");
257+
258+
assert_matches::assert_matches!(result, ExtractError::IoError(_));
259+
}
260+
261+
struct FlakyReader<R: Read> {
262+
reader: R,
263+
cutoff: usize,
264+
total_read: usize,
265+
}
266+
267+
impl<R: Read> Read for FlakyReader<R> {
268+
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
269+
let remaining = self.cutoff.saturating_sub(self.total_read);
270+
if remaining == 0 {
271+
return Err(std::io::Error::new(std::io::ErrorKind::Other, "flaky"));
272+
}
273+
let max_read = buf.len().min(remaining);
274+
let bytes_read = self.reader.read(&mut buf[..max_read])?;
275+
self.total_read += bytes_read;
276+
Ok(bytes_read)
277+
}
278+
}

0 commit comments

Comments
 (0)