@@ -311,11 +311,16 @@ mod test {
311
311
routing:: get_service,
312
312
Router ,
313
313
} ;
314
+ use bytes:: Bytes ;
315
+ use futures:: stream;
314
316
use rattler_conda_types:: package:: { ArchiveIdentifier , PackageFile , PathsJson } ;
315
317
use rattler_networking:: retry_policies:: { DoNotRetryPolicy , ExponentialBackoffBuilder } ;
316
- use std:: { fs:: File , future:: IntoFuture , net:: SocketAddr , path:: Path , sync:: Arc } ;
318
+ use std:: {
319
+ convert:: Infallible , fs:: File , future:: IntoFuture , net:: SocketAddr , path:: Path , sync:: Arc ,
320
+ } ;
317
321
use tempfile:: tempdir;
318
322
use tokio:: sync:: Mutex ;
323
+ use tokio_stream:: StreamExt ;
319
324
use tower_http:: services:: ServeDir ;
320
325
use url:: Url ;
321
326
@@ -382,22 +387,69 @@ mod test {
382
387
Ok ( next. run ( req) . await )
383
388
}
384
389
385
- #[ tokio:: test]
386
- pub async fn test_flaky_package_cache ( ) {
390
+ /// A helper middleware function that fails the first two requests.
391
+ async fn fail_with_half_package (
392
+ State ( ( count, bytes) ) : State < ( Arc < Mutex < i32 > > , Arc < Mutex < usize > > ) > ,
393
+ req : Request < Body > ,
394
+ next : Next ,
395
+ ) -> Result < Response , StatusCode > {
396
+ let count = {
397
+ let mut count = count. lock ( ) . await ;
398
+ * count += 1 ;
399
+ * count
400
+ } ;
401
+
402
+ println ! ( "Running middleware for request #{count} for {}" , req. uri( ) ) ;
403
+ let response = next. run ( req) . await ;
404
+
405
+ if count <= 2 {
406
+ // println!("Cutting response body in half");
407
+ let body = response. into_body ( ) ;
408
+ let mut body = body. into_data_stream ( ) ;
409
+ let mut buffer = Vec :: new ( ) ;
410
+ while let Some ( Ok ( chunk) ) = body. next ( ) . await {
411
+ buffer. extend ( chunk) ;
412
+ }
413
+
414
+ let byte_count = bytes. lock ( ) . await . clone ( ) ;
415
+ let bytes = buffer. into_iter ( ) . take ( byte_count) . collect :: < Vec < u8 > > ( ) ;
416
+ // Create a stream that ends prematurely
417
+ let stream = stream:: iter ( vec ! [
418
+ Ok :: <_, Infallible >( Bytes :: from_iter( bytes. into_iter( ) ) ) ,
419
+ // The stream ends after sending partial data, simulating a premature close
420
+ ] ) ;
421
+ let body = Body :: from_stream ( stream) ;
422
+ return Ok ( Response :: new ( body) ) ;
423
+ }
424
+
425
+ Ok ( response)
426
+ }
427
+
428
+ enum Middleware {
429
+ FailTheFirstTwoRequests ,
430
+ FailAfterBytes ( usize ) ,
431
+ }
432
+
433
+ async fn test_flaky_package_cache ( archive_name : & str , middleware : Middleware ) {
387
434
let static_dir = get_test_data_dir ( ) ;
388
435
println ! ( "Serving files from {}" , static_dir. display( ) ) ;
389
436
// Construct a service that serves raw files from the test directory
390
437
let service = get_service ( ServeDir :: new ( static_dir) ) ;
391
438
392
439
// Construct a router that returns data from the static dir but fails the first try.
393
440
let request_count = Arc :: new ( Mutex :: new ( 0 ) ) ;
394
- let router =
395
- Router :: new ( )
396
- . route_service ( "/*key" , service)
397
- . layer ( middleware:: from_fn_with_state (
398
- request_count. clone ( ) ,
399
- fail_the_first_two_requests,
400
- ) ) ;
441
+ let router = Router :: new ( ) . route_service ( "/*key" , service) ;
442
+
443
+ let router = match middleware {
444
+ Middleware :: FailTheFirstTwoRequests => router. layer ( middleware:: from_fn_with_state (
445
+ request_count. clone ( ) ,
446
+ fail_the_first_two_requests,
447
+ ) ) ,
448
+ Middleware :: FailAfterBytes ( size) => router. layer ( middleware:: from_fn_with_state (
449
+ ( request_count. clone ( ) , Arc :: new ( Mutex :: new ( size) ) ) ,
450
+ fail_with_half_package,
451
+ ) ) ,
452
+ } ;
401
453
402
454
// Construct the server that will listen on localhost but with a *random port*. The random
403
455
// port is very important because it enables creating multiple instances at the same time.
@@ -412,7 +464,6 @@ mod test {
412
464
let packages_dir = tempdir ( ) . unwrap ( ) ;
413
465
let cache = PackageCache :: new ( packages_dir. path ( ) ) ;
414
466
415
- let archive_name = "ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2" ;
416
467
let server_url = Url :: parse ( & format ! ( "http://localhost:{}" , addr. port( ) ) ) . unwrap ( ) ;
417
468
418
469
// Do the first request without
@@ -448,4 +499,17 @@ mod test {
448
499
assert_eq ! ( * request_count_lock, 3 , "Expected there to be 3 requests" ) ;
449
500
}
450
501
}
502
+
503
+ #[ tokio:: test]
504
+ async fn test_flaky ( ) {
505
+ let tar_bz2 = "ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2" ;
506
+ let conda = "conda-22.11.1-py38haa244fe_1.conda" ;
507
+
508
+ test_flaky_package_cache ( tar_bz2, Middleware :: FailTheFirstTwoRequests ) . await ;
509
+ test_flaky_package_cache ( conda, Middleware :: FailTheFirstTwoRequests ) . await ;
510
+
511
+ test_flaky_package_cache ( tar_bz2, Middleware :: FailAfterBytes ( 1000 ) ) . await ;
512
+ test_flaky_package_cache ( conda, Middleware :: FailAfterBytes ( 1000 ) ) . await ;
513
+ test_flaky_package_cache ( conda, Middleware :: FailAfterBytes ( 50 ) ) . await ;
514
+ }
451
515
}
0 commit comments