diff --git a/Cargo.lock b/Cargo.lock index 4e335b2aa..5340a2649 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4941,13 +4941,16 @@ dependencies = [ "pixi_consts", "rattler_conda_types", "rattler_digest", + "rayon", "reqwest", "reqwest-middleware", "reqwest-retry", "serde", "serde_json", "tokio", + "tracing", "url", + "uv-configuration", ] [[package]] diff --git a/crates/pypi_mapping/Cargo.toml b/crates/pypi_mapping/Cargo.toml index b2bc69e14..2fc5e5e61 100644 --- a/crates/pypi_mapping/Cargo.toml +++ b/crates/pypi_mapping/Cargo.toml @@ -22,10 +22,13 @@ pixi_config = { workspace = true } pixi_consts = { workspace = true } rattler_conda_types = { workspace = true } rattler_digest = { workspace = true } +rayon = "1.10.0" reqwest = { workspace = true, features = ["json"] } reqwest-middleware = { workspace = true } reqwest-retry = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tokio = { workspace = true } +tracing.workspace = true url = { workspace = true } +uv-configuration = { workspace = true } diff --git a/crates/pypi_mapping/src/custom_pypi_mapping.rs b/crates/pypi_mapping/src/custom_pypi_mapping.rs index bed1531c3..ea26582c4 100644 --- a/crates/pypi_mapping/src/custom_pypi_mapping.rs +++ b/crates/pypi_mapping/src/custom_pypi_mapping.rs @@ -96,8 +96,7 @@ pub async fn amend_pypi_purls( client, &packages_for_prefix_mapping, reporter, - ) - .await?; + )?; let compressed_mapping = prefix_pypi_name_mapping::conda_pypi_name_compressed_mapping(client).await?; diff --git a/crates/pypi_mapping/src/prefix_pypi_name_mapping.rs b/crates/pypi_mapping/src/prefix_pypi_name_mapping.rs index 7a31c4414..4919ae748 100644 --- a/crates/pypi_mapping/src/prefix_pypi_name_mapping.rs +++ b/crates/pypi_mapping/src/prefix_pypi_name_mapping.rs @@ -1,9 +1,8 @@ use std::{ - collections::{BTreeSet, HashMap}, - sync::Arc, + cell::RefCell, collections::{BTreeSet, HashMap}, sync::{Arc, LazyLock, Mutex} }; +use rayon::prelude::*; -use futures::{stream::FuturesUnordered, StreamExt}; use itertools::Itertools; use miette::{IntoDiagnostic, WrapErr}; use rattler_conda_types::{PackageUrl, RepoDataRecord}; @@ -11,14 +10,19 @@ use rattler_digest::Sha256Hash; use reqwest::StatusCode; use reqwest_middleware::ClientWithMiddleware; use serde::{Deserialize, Serialize}; -use tokio::sync::Semaphore; +use tokio::runtime::Runtime; use url::Url; +use uv_configuration::RAYON_INITIALIZE; use super::{ build_pypi_purl_from_package_record, custom_pypi_mapping, is_conda_forge_record, PurlSource, Reporter, }; +thread_local! { + static TOKIO_RT: RefCell> = RefCell::new(None); +} + const STORAGE_URL: &str = "https://conda-mapping.prefix.dev"; const HASH_DIR: &str = "hash-v0"; const COMPRESSED_MAPPING: &str = @@ -63,11 +67,15 @@ async fn try_fetch_single_mapping( } /// Downloads and caches the conda-forge conda-to-pypi name mapping. -pub async fn conda_pypi_name_mapping<'r>( +pub fn conda_pypi_name_mapping<'r>( client: &ClientWithMiddleware, conda_packages: impl IntoIterator, reporter: Option>, ) -> miette::Result> { + // Force the initialization of the rayon thread pool to avoid implicit creation + // by the Installer. + LazyLock::force(&RAYON_INITIALIZE); + let filtered_packages = conda_packages .into_iter() // because we later skip adding purls for packages @@ -85,63 +93,72 @@ pub async fn conda_pypi_name_mapping<'r>( .collect_vec(); let total_records = filtered_packages.len(); - let mut pending_futures = FuturesUnordered::new(); - let concurrency_limit = Arc::new(Semaphore::new(100)); - for (record, hash) in filtered_packages { + let result_map = Arc::new(Mutex::new(HashMap::with_capacity(total_records))); + let error = Arc::new(Mutex::new(None)); + + tracing::info!("Downloading conda-pypi mapping for {} packages", total_records); + filtered_packages.par_iter().for_each(|(record, hash)| { + // Check if we've already encountered an error + if error.lock().unwrap().is_some() { + return; + } + if let Some(reporter) = &reporter { reporter.download_started(record, total_records); } let client = client.clone(); - let reporter = reporter.clone(); - let concurrency_limit = concurrency_limit.clone(); - - // Create a future that fetches the mapping for the record's hash concurrently - // with the rest of the requests. - pending_futures.push(async move { - // Acquire a permit to limit the number of concurrent requests - let _permit = concurrency_limit - .acquire_owned() - .await - .expect("semaphore error"); - - // Fetch the mapping by the hash of the record. - let result = try_fetch_single_mapping(&client, &hash).await; - - // Report the result to the reporter - if let Some(reporter) = reporter { - match &result { - Ok(_) => reporter.download_finished(record, total_records), - Err(_) => reporter.download_failed(record, total_records), - } - } - match result { - Ok(Some(package)) => Ok(Some((hash, package))), - Ok(None) => Ok(None), - Err(e) => Err(e), + // Get or create the thread-local Tokio runtime + let result = TOKIO_RT.with(|rt| { + let mut rt_ref = rt.borrow_mut(); + if rt_ref.is_none() { + *rt_ref = Some(Runtime::new().expect("Failed to create Tokio runtime")); } + + // Execute the async function within the Tokio runtime + rt_ref.as_ref().unwrap().block_on(try_fetch_single_mapping(&client, hash)) }); - } - let mut result_map = HashMap::with_capacity(total_records); - while let Some(result) = pending_futures.next().await { + // Report the result to the reporter + if let Some(reporter) = &reporter { + match &result { + Ok(_) => reporter.download_finished(record, total_records), + Err(_) => reporter.download_failed(record, total_records), + } + } + match result { - Ok(Some((hash, package))) => { + Ok(Some(package)) => { // Add the mapping to the result hashmap - result_map.insert(hash, package); + let mut map = result_map.lock().unwrap(); + map.insert(*hash, package); } Ok(None) => { // If no mapping was found, do nothing. } Err(e) => { - // If an error occurred, bail out,. - return Err(e); + // If an error occurred, store it + let mut err = error.lock().unwrap(); + *err = Some(e); } } + }); + + // Check if any errors occurred + let err = error.lock().unwrap(); + if let Some(e) = err.as_ref() { + // tracing::error!("Failed to download conda-pypi mapping: {:?}", e); + miette::bail!("Failed to download conda-pypi mapping: {:?}", e); } - Ok(result_map) + // Convert Arc> back to HashMap + let result = Arc::try_unwrap(result_map) + .expect("There should be no other references to result_map") + .into_inner() + .expect("Mutex should not be poisoned"); + + Ok(result) } /// Downloads and caches prefix.dev conda-pypi mapping. @@ -162,7 +179,7 @@ pub async fn amend_pypi_purls( ) -> miette::Result<()> { let conda_packages = conda_packages.into_iter().collect_vec(); let conda_mapping = - conda_pypi_name_mapping(client, conda_packages.iter().map(|p| *p as &_), reporter).await?; + conda_pypi_name_mapping(client, conda_packages.iter().map(|p| *p as &_), reporter)?; let compressed_mapping = conda_pypi_name_compressed_mapping(client).await?; for record in conda_packages {