Skip to content

Commit 0ba1e28

Browse files
Convert authenticated client to reqwest middleware (#488)
While initially I wanted to create a number of traits for reqwest types, as it turns out, there are better solutions. `reqwest_middleware` provides a flexible way to insert middleware that is activated either when the request is being built or just before it's sent. `AuthenticatedClient` is converted to a middleware type that handles auth. This enables authentication with auth in rattler itself as well as in `async_http_range_reader` through the same interface. One consequence is removal of the blocking client and features, as blocking comms are not supported in the middleware crate. I would argue that in most places where blocking communication is preferred -- e. g. simple programs performing a limited number of network requests -- it's often enough to just spawn a tokio runtime to wait on the future, because compared to network latency the cost will be negligible. Applications doing intensive network i/o are either async already anyway or likely should be. This is not yet ready for merge -- I'm not sure what to do with Python binding. Technically we could replicate this type using the middleware client, but it's a bit weird that it's removed from Rust API but stays in Python API. I will update the docs once there's a view on how to proceed.
1 parent 69b05d6 commit 0ba1e28

29 files changed

+439
-537
lines changed

crates/async_http_range_reader/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ itertools = "0.11.0"
1414
bisection = "0.1.0"
1515
memmap2 = "0.9.0"
1616
reqwest = { version = "0.11.22", default-features = false, features = ["stream"] }
17+
reqwest-middleware = "0.2.4"
1718
tokio = { version = "1.33.0", default-features = false }
1819
tokio-stream = { version = "0.1.14", features = ["sync"] }
1920
tokio-util = "0.7.9"

crates/async_http_range_reader/src/async_http_range_reader.rs

+20-14
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use futures::{FutureExt, Stream, StreamExt};
2121
use http_content_range::{ContentRange, ContentRangeBytes};
2222
use memmap2::MmapMut;
2323
use reqwest::header::HeaderMap;
24-
use reqwest::{Client, Response, Url};
24+
use reqwest::{Response, Url};
2525
use std::{
2626
io::{self, ErrorKind, SeekFrom},
2727
ops::Range,
@@ -68,7 +68,7 @@ pub use crate::async_http_range_reader_error::AsyncHttpRangeReaderError;
6868
/// if response.status() == reqwest::StatusCode::NOT_MODIFIED {
6969
/// Ok(None)
7070
/// } else {
71-
/// let reader = AsyncHttpRangeReader::from_head_response(client, response).await?;
71+
/// let reader = AsyncHttpRangeReader::from_head_response(client.into(), response).await?;
7272
/// Ok(Some(reader))
7373
/// }
7474
/// }
@@ -127,10 +127,16 @@ pub enum CheckSupportMethod {
127127
Head,
128128
}
129129

130+
fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
131+
response
132+
.error_for_status()
133+
.map_err(reqwest_middleware::Error::Reqwest)
134+
}
135+
130136
impl AsyncHttpRangeReader {
131137
/// Construct a new `AsyncHttpRangeReader`.
132138
pub async fn new(
133-
client: reqwest::Client,
139+
client: reqwest_middleware::ClientWithMiddleware,
134140
url: reqwest::Url,
135141
check_method: CheckSupportMethod,
136142
) -> Result<(Self, HeaderMap), AsyncHttpRangeReaderError> {
@@ -162,7 +168,7 @@ impl AsyncHttpRangeReader {
162168
/// requests. This will return a number of bytes from the end of the stream. Use the
163169
/// `initial_chunk_size` parameter to define how many bytes should be requested from the end.
164170
pub async fn initial_tail_request(
165-
client: reqwest::Client,
171+
client: reqwest_middleware::ClientWithMiddleware,
166172
url: reqwest::Url,
167173
initial_chunk_size: u64,
168174
extra_headers: HeaderMap,
@@ -176,7 +182,7 @@ impl AsyncHttpRangeReader {
176182
.headers(extra_headers)
177183
.send()
178184
.await
179-
.and_then(Response::error_for_status)
185+
.and_then(error_for_status)
180186
.map_err(Arc::new)
181187
.map_err(AsyncHttpRangeReaderError::HttpError)?;
182188
Ok(tail_response)
@@ -185,7 +191,7 @@ impl AsyncHttpRangeReader {
185191
/// Initialize the reader from [`AsyncHttpRangeReader::initial_tail_request`] (or a user
186192
/// provided response that also has a range of bytes from the end as body)
187193
pub async fn from_tail_response(
188-
client: reqwest::Client,
194+
client: reqwest_middleware::ClientWithMiddleware,
189195
tail_request_response: Response,
190196
) -> Result<Self, AsyncHttpRangeReaderError> {
191197
// Get the size of the file from this initial request
@@ -260,7 +266,7 @@ impl AsyncHttpRangeReader {
260266
/// Send an initial range request to determine if the remote accepts range
261267
/// requests and get the content length
262268
pub async fn initial_head_request(
263-
client: reqwest::Client,
269+
client: reqwest_middleware::ClientWithMiddleware,
264270
url: reqwest::Url,
265271
extra_headers: HeaderMap,
266272
) -> Result<Response, AsyncHttpRangeReaderError> {
@@ -270,7 +276,7 @@ impl AsyncHttpRangeReader {
270276
.headers(extra_headers)
271277
.send()
272278
.await
273-
.and_then(Response::error_for_status)
279+
.and_then(error_for_status)
274280
.map_err(Arc::new)
275281
.map_err(AsyncHttpRangeReaderError::HttpError)?;
276282
Ok(head_response)
@@ -279,7 +285,7 @@ impl AsyncHttpRangeReader {
279285
/// Initialize the reader from [`AsyncHttpRangeReader::initial_head_request`] (or a user
280286
/// provided response the)
281287
pub async fn from_head_response(
282-
client: reqwest::Client,
288+
client: reqwest_middleware::ClientWithMiddleware,
283289
head_response: Response,
284290
) -> Result<Self, AsyncHttpRangeReaderError> {
285291
// Are range requests supported?
@@ -387,7 +393,7 @@ impl AsyncHttpRangeReader {
387393
/// become available.
388394
#[tracing::instrument(name = "fetch_ranges", skip_all, fields(url))]
389395
async fn run_streamer(
390-
client: Client,
396+
client: reqwest_middleware::ClientWithMiddleware,
391397
url: Url,
392398
initial_tail_response: Option<(Response, u64)>,
393399
mut memory_map: MmapMut,
@@ -447,7 +453,7 @@ async fn run_streamer(
447453
.send()
448454
.instrument(span)
449455
.await
450-
.and_then(Response::error_for_status)
456+
.and_then(error_for_status)
451457
.map_err(|e| std::io::Error::new(ErrorKind::Other, e))
452458
{
453459
Err(e) => {
@@ -649,7 +655,7 @@ mod test {
649655

650656
// Construct an AsyncRangeReader
651657
let (mut range, _) = AsyncHttpRangeReader::new(
652-
Client::new(),
658+
Client::new().into(),
653659
server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
654660
check_method,
655661
)
@@ -743,7 +749,7 @@ mod test {
743749

744750
// Construct an AsyncRangeReader
745751
let (mut range, _) = AsyncHttpRangeReader::new(
746-
Client::new(),
752+
Client::new().into(),
747753
server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
748754
check_method,
749755
)
@@ -784,7 +790,7 @@ mod test {
784790
async fn test_not_found() {
785791
let server = StaticDirectoryServer::new(Path::new(env!("CARGO_MANIFEST_DIR")));
786792
let err = AsyncHttpRangeReader::new(
787-
Client::new(),
793+
Client::new().into(),
788794
server.url().join("not-found").unwrap(),
789795
CheckSupportMethod::Head,
790796
)

crates/async_http_range_reader/src/async_http_range_reader_error.rs

+9-3
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ pub enum AsyncHttpRangeReaderError {
1010

1111
/// Other HTTP error
1212
#[error(transparent)]
13-
HttpError(#[from] Arc<reqwest::Error>),
13+
HttpError(#[from] Arc<reqwest_middleware::Error>),
1414

1515
/// An error occurred during transport
1616
#[error("an error occurred during transport: {0}")]
17-
TransportError(#[source] Arc<reqwest::Error>),
17+
TransportError(#[source] Arc<reqwest_middleware::Error>),
1818

1919
/// An IO error occurred
2020
#[error("io error occurred: {0}")]
@@ -39,8 +39,14 @@ impl From<std::io::Error> for AsyncHttpRangeReaderError {
3939
}
4040
}
4141

42+
impl From<reqwest_middleware::Error> for AsyncHttpRangeReaderError {
43+
fn from(err: reqwest_middleware::Error) -> Self {
44+
AsyncHttpRangeReaderError::TransportError(Arc::new(err))
45+
}
46+
}
47+
4248
impl From<reqwest::Error> for AsyncHttpRangeReaderError {
4349
fn from(err: reqwest::Error) -> Self {
44-
AsyncHttpRangeReaderError::TransportError(Arc::new(err))
50+
AsyncHttpRangeReaderError::TransportError(Arc::new(err.into()))
4551
}
4652
}

crates/rattler-bin/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ rattler_repodata_gateway = { version = "0.16.2", path = "../rattler_repodata_gat
3636
rattler_solve = { version = "0.16.2", path = "../rattler_solve", features = ["resolvo", "libsolv_c"] }
3737
rattler_virtual_packages = { version = "0.16.2", path = "../rattler_virtual_packages" }
3838
reqwest = { version = "0.11.22", default-features = false }
39+
reqwest-middleware = { version = "0.2.4" }
3940
tokio = { version = "1.32.0", features = ["rt-multi-thread", "macros"] }
4041
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
4142

crates/rattler-bin/src/commands/create.rs

+10-5
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@ use rattler_conda_types::{
1515
PrefixRecord, RepoDataRecord, Version,
1616
};
1717
use rattler_networking::{
18-
retry_policies::default_retry_policy, AuthenticatedClient, AuthenticationStorage,
18+
retry_policies::default_retry_policy, AuthenticationMiddleware, AuthenticationStorage,
1919
};
2020
use rattler_repodata_gateway::fetch::{
2121
CacheResult, DownloadProgress, FetchRepoDataError, FetchRepoDataOptions,
2222
};
2323
use rattler_repodata_gateway::sparse::SparseRepoData;
2424
use rattler_solve::{libsolv_c, resolvo, SolverImpl, SolverTask};
2525
use reqwest::Client;
26+
use std::sync::Arc;
2627
use std::{
2728
borrow::Cow,
2829
env,
@@ -117,8 +118,12 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> {
117118
.expect("failed to create client");
118119

119120
let authentication_storage = AuthenticationStorage::default();
121+
let download_client = reqwest_middleware::ClientBuilder::new(download_client)
122+
.with_arc(Arc::new(AuthenticationMiddleware::new(
123+
authentication_storage,
124+
)))
125+
.build();
120126

121-
let download_client = AuthenticatedClient::from_client(download_client, authentication_storage);
122127
let multi_progress = global_multi_progress();
123128

124129
let repodata_cache_path = cache_dir.join("repodata");
@@ -297,7 +302,7 @@ async fn execute_transaction(
297302
transaction: Transaction<PrefixRecord, RepoDataRecord>,
298303
target_prefix: PathBuf,
299304
cache_dir: PathBuf,
300-
download_client: AuthenticatedClient,
305+
download_client: reqwest_middleware::ClientWithMiddleware,
301306
) -> anyhow::Result<()> {
302307
// Open the package cache
303308
let package_cache = PackageCache::new(cache_dir.join("pkgs"));
@@ -385,7 +390,7 @@ async fn execute_transaction(
385390
#[allow(clippy::too_many_arguments)]
386391
async fn execute_operation(
387392
target_prefix: &Path,
388-
download_client: AuthenticatedClient,
393+
download_client: reqwest_middleware::ClientWithMiddleware,
389394
package_cache: &PackageCache,
390395
install_driver: &InstallDriver,
391396
download_pb: Option<&ProgressBar>,
@@ -551,7 +556,7 @@ async fn fetch_repo_data_records_with_progress(
551556
channel: Channel,
552557
platform: Platform,
553558
repodata_cache: &Path,
554-
client: AuthenticatedClient,
559+
client: reqwest_middleware::ClientWithMiddleware,
555560
multi_progress: indicatif::MultiProgress,
556561
) -> Result<Option<SparseRepoData>, anyhow::Error> {
557562
// Create a progress bar

crates/rattler/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@ pin-project-lite = "0.2.13"
3636
rattler_conda_types = { version = "0.16.2", path = "../rattler_conda_types" }
3737
rattler_digest = { version = "0.16.2", path = "../rattler_digest" }
3838
rattler_networking = { version = "0.16.2", path = "../rattler_networking", default-features = false }
39-
rattler_package_streaming = { version = "0.16.2", path = "../rattler_package_streaming", features = ["reqwest", "tokio"], default-features = false }
39+
rattler_package_streaming = { version = "0.16.2", path = "../rattler_package_streaming", features = ["reqwest"], default-features = false }
4040
reflink-copy = "0.1.14"
4141
regex = "1.9.6"
4242
reqwest = { version = "0.11.22", default-features = false, features = ["stream", "json", "gzip"] }
43+
reqwest-middleware = "0.2.4"
4344
serde = { version = "1.0.188", features = ["derive"] }
4445
serde_json = { version = "1.0.107", features = ["raw_value"] }
4546
serde_with = "3.3.0"

crates/rattler/src/install/clobber_registry.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ mod tests {
278278
package::IndexJson, PackageRecord, Platform, PrefixRecord, RepoDataRecord,
279279
};
280280
use rattler_digest::{Md5, Sha256};
281-
use rattler_networking::{retry_policies::default_retry_policy, AuthenticatedClient};
281+
use rattler_networking::retry_policies::default_retry_policy;
282282
use rattler_package_streaming::seek::read_package_file;
283283
use transaction::{Transaction, TransactionOperation};
284284

@@ -368,7 +368,7 @@ mod tests {
368368

369369
async fn execute_operation(
370370
target_prefix: &Path,
371-
download_client: &AuthenticatedClient,
371+
download_client: &reqwest_middleware::ClientWithMiddleware,
372372
package_cache: &PackageCache,
373373
install_driver: &InstallDriver,
374374
op: TransactionOperation<PrefixRecord, RepoDataRecord>,
@@ -416,7 +416,7 @@ mod tests {
416416
async fn execute_transaction(
417417
transaction: Transaction<PrefixRecord, RepoDataRecord>,
418418
target_prefix: &Path,
419-
download_client: &AuthenticatedClient,
419+
download_client: &reqwest_middleware::ClientWithMiddleware,
420420
package_cache: &PackageCache,
421421
install_driver: &InstallDriver,
422422
install_options: &InstallOptions,
@@ -515,7 +515,7 @@ mod tests {
515515
execute_transaction(
516516
transaction,
517517
target_prefix.path(),
518-
&AuthenticatedClient::default(),
518+
&reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new()),
519519
&cache,
520520
&InstallDriver::default(),
521521
&InstallOptions::default(),
@@ -572,7 +572,7 @@ mod tests {
572572
execute_transaction(
573573
transaction,
574574
target_prefix.path(),
575-
&AuthenticatedClient::default(),
575+
&reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new()),
576576
&cache,
577577
&install_driver,
578578
&InstallOptions::default(),
@@ -637,7 +637,7 @@ mod tests {
637637
execute_transaction(
638638
transaction,
639639
target_prefix.path(),
640-
&AuthenticatedClient::default(),
640+
&reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new()),
641641
&cache,
642642
&InstallDriver::default(),
643643
&InstallOptions::default(),
@@ -684,7 +684,7 @@ mod tests {
684684
execute_transaction(
685685
transaction,
686686
target_prefix.path(),
687-
&AuthenticatedClient::default(),
687+
&reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new()),
688688
&cache,
689689
&InstallDriver::default(),
690690
&InstallOptions::default(),
@@ -735,7 +735,7 @@ mod tests {
735735
execute_transaction(
736736
transaction,
737737
target_prefix.path(),
738-
&AuthenticatedClient::default(),
738+
&reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new()),
739739
&cache,
740740
&install_driver,
741741
&InstallOptions::default(),

crates/rattler/src/install/mod.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,6 @@ mod test {
621621
use rattler_conda_types::package::ArchiveIdentifier;
622622
use rattler_conda_types::{ExplicitEnvironmentSpec, Platform, Version};
623623
use rattler_lock::LockFile;
624-
use rattler_networking::AuthenticatedClient;
625624

626625
use std::env::temp_dir;
627626
use std::process::Command;
@@ -683,7 +682,7 @@ mod test {
683682
let package_cache = PackageCache::new(temp_dir().join("rattler").join(cache_name));
684683

685684
// Create an HTTP client we can use to download packages
686-
let client = AuthenticatedClient::default();
685+
let client = reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new());
687686

688687
// Specify python version
689688
let python_version =

0 commit comments

Comments
 (0)