From 4a914585d9b02ba9aafc69b5e4a3b6cde3ad37f7 Mon Sep 17 00:00:00 2001
From: anton
Date: Mon, 10 Feb 2025 14:17:27 +0100
Subject: [PATCH 1/5] sglang inference server

---
 slurm/serve_r1.slurm     | 112 +++++++++++++++++++++++++++++++++++++++
 slurm/serve_router.slurm |  46 ++++++++++++++++
 2 files changed, 158 insertions(+)
 create mode 100644 slurm/serve_r1.slurm
 create mode 100644 slurm/serve_router.slurm

diff --git a/slurm/serve_r1.slurm b/slurm/serve_r1.slurm
new file mode 100644
index 00000000..60a72aad
--- /dev/null
+++ b/slurm/serve_r1.slurm
@@ -0,0 +1,112 @@
+#!/bin/bash
+#SBATCH --job-name=r1-server
+#SBATCH --partition=hopper-prod
+#SBATCH --qos=normal
+#SBATCH --nodes=2
+#SBATCH --gpus-per-node=8
+#SBATCH --exclusive
+#SBATCH --output=./logs/%x_%j_%n.out
+#SBATCH --error=./logs/%x_%j_%n.err
+#SBATCH --time=7-00:00:00
+#SBATCH --ntasks-per-node=1
+#SBATCH --requeue
+
+set -exuo pipefail
+
+MODEL_PATH="deepseek-ai/DeepSeek-R1"
+CONDA_ENV="sglang124"
+ROUTER_ADDRESS=""
+SERVER_PORT=39877
+DIST_PORT=45000
+
+trap 'scontrol requeue ${SLURM_JOB_ID}; exit 15' SIGUSR1
+
+# TODO: Adjust these variables to your cluster configuration
+export OUTLINES_CACHE_DIR=/scratch/serve_r1/ocache/
+export TRITON_HOME=/scratch/serve_r1/triton/
+export GLOO_SOCKET_IFNAME="enp71s0"
+export NCCL_SOCKET_IFNAME="enp71s0"
+
+while getopts "m:e:r:h" opt; do
+    case $opt in
+        m) MODEL_PATH="$OPTARG" ;;
+        e) CONDA_ENV="$OPTARG" ;;
+        r) ROUTER_ADDRESS="$OPTARG" ;;
+        h|?) echo "Usage: sbatch $0 [-m MODEL_PATH] [-e CONDA_ENV] [-r ROUTER_ADDRESS]"; exit 1 ;;
+    esac
+done
+
+# TODO: Environment setup, adjust to your cluster configuration
+module load cuda/12.4
+source ~/.bashrc
+source "$CONDA_PREFIX/etc/profile.d/conda.sh"
+conda activate "$CONDA_ENV" || { echo "Failed to activate conda env $CONDA_ENV"; exit 1; }
+
+FIRST_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1)
+FIRST_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$FIRST_NODE" hostname --ip-address)
+
+# Launch servers synchronously across all nodes
+# (--max-running-requests=56 is a rough estimate to avoid too many evicted/preempted 16k-token requests)
+srun --nodes=2 --ntasks=2 --ntasks-per-node=1 \
+    bash -c "python -m sglang.launch_server \
+        --model-path '$MODEL_PATH' \
+        --tp 16 \
+        --dist-init-addr '$FIRST_NODE_IP:$DIST_PORT' \
+        --nnodes 2 \
+        --node-rank \$SLURM_PROCID \
+        --port '$SERVER_PORT' \
+        --host 0.0.0.0 \
+        --trust-remote-code \
+        --max-running-requests 56 \
+        --context-length 32768" &
+
+# Wait for the server to come up (the health endpoint must return HTTP 200), with a timeout
+TIMEOUT=3600  # 1h, but model loading should take ~30min
+START_TIME=$(date +%s)
+echo "Waiting for SGLang server (http://$FIRST_NODE_IP:$SERVER_PORT)..."
+
+while true; do
+    if [ "$(curl -s -o /dev/null -w "%{http_code}" "http://$FIRST_NODE_IP:$SERVER_PORT/health" 2>/dev/null)" = "200" ]; then
+        echo "Server is ready at http://$FIRST_NODE_IP:$SERVER_PORT"
+        break
+    fi
+
+    CURRENT_TIME=$(date +%s)
+    if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then
+        echo "Error: Server failed to start within $TIMEOUT seconds"
+        exit 1
+    fi
+
+    echo "Still waiting... ($((CURRENT_TIME - START_TIME)) seconds elapsed)"
+    sleep 60
+done
+
+# Register with the router only if an address was provided
+if [ -n "$ROUTER_ADDRESS" ]; then
+    echo "Registering with router at $ROUTER_ADDRESS..."
+    curl -X POST "http://$ROUTER_ADDRESS/add_worker?url=http://$FIRST_NODE_IP:$SERVER_PORT" || true
+    sleep 10
+fi
+
+echo "Checking available models..."
+curl "http://$FIRST_NODE_IP:$SERVER_PORT/v1/models"
+sleep 10
+
+echo "Executing sanity check..."
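+# Note: /v1/completions does not apply the model's chat template, so the prompt
+# below embeds DeepSeek-R1's special tokens (<|begin▁of▁sentence|>, <|User|>, <|Assistant|>) inline.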
+curl "http://$FIRST_NODE_IP:$SERVER_PORT/v1/completions" \ + -H "Content-Type: application/json" \ + -d "{ + \"model\": \"default\", + \"prompt\": \"<|begin▁of▁sentence|><|User|>hi, how are you?<|Assistant|>\", + \"max_tokens\": 2048, + \"temperature\": 0.6 + }" + +# Keep the job running with health checks +while true; do + if ! curl -s -o /dev/null "http://$FIRST_NODE_IP:$SERVER_PORT/health"; then + echo "Error: Server health check failed" + exit 1 + fi + sleep 300 +done \ No newline at end of file diff --git a/slurm/serve_router.slurm b/slurm/serve_router.slurm new file mode 100644 index 00000000..b39ca66a --- /dev/null +++ b/slurm/serve_router.slurm @@ -0,0 +1,46 @@ +#!/bin/bash +#SBATCH --job-name=r1-router +#SBATCH --partition=hopper-cpu +#SBATCH --qos=high +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem-per-cpu=1875m +#SBATCH --output=./logs/%x_%j_%n.out +#SBATCH --error=./logs/%x_%j_%n.err +#SBATCH --time=30-00:00:00 +#SBATCH --requeue + +set -exuo pipefail + +# TODO: Adjust these variables to your cluster configuration +CONDA_ENV="sglang124" +ROUTER_PORT=39876 + +trap 'scontrol requeue ${SLURM_JOB_ID}; exit 15' SIGUSR1 + +while getopts "e:h" opt; do + case $opt in + e) CONDA_ENV="$OPTARG" ;; + h|?) echo "Usage: sbatch $0 [-e CONDA_ENV]"; exit 1 ;; + esac +done + +# TODO: Environment setup, adjust to your cluster configuration +source ~/.bashrc +source "$CONDA_PREFIX/etc/profile.d/conda.sh" +conda activate "$CONDA_ENV" || { echo "Failed to activate conda env $CONDA_ENV"; exit 1; } + +python -m sglang_router.launch_router \ + --port "$ROUTER_PORT" \ + --host 0.0.0.0 \ + --policy "round_robin" \ + --worker-startup-timeout-secs 300 + +# Keep the job running with health checks +while true; do + if ! curl -s -o /dev/null "http://localhost:$ROUTER_PORT/health"; then + echo "Error: Router health check failed" + exit 1 + fi + sleep 300 +done \ No newline at end of file From 0de6d948a030fb2abeee1f380525ebcec2c70622 Mon Sep 17 00:00:00 2001 From: anton Date: Mon, 10 Feb 2025 14:22:23 +0100 Subject: [PATCH 2/5] add vllm --- slurm/experimental/serve_r1_vllm.slurm | 135 +++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 slurm/experimental/serve_r1_vllm.slurm diff --git a/slurm/experimental/serve_r1_vllm.slurm b/slurm/experimental/serve_r1_vllm.slurm new file mode 100644 index 00000000..7e0b0d59 --- /dev/null +++ b/slurm/experimental/serve_r1_vllm.slurm @@ -0,0 +1,135 @@ +#!/bin/bash +#SBATCH --job-name=r1-vllm +#SBATCH --partition=hopper-prod +#SBATCH --qos=normal +#SBATCH --nodes=4 +#SBATCH --gpus-per-node=8 +#SBATCH --exclusive +#SBATCH --output=./logs/%x_%j_%n.out +#SBATCH --error=./logs/%x_%j_%n.err +#SBATCH --time=7-00:00:00 +#SBATCH --ntasks-per-node=1 +#SBATCH --requeue + +set -exuo pipefail + +MODEL_PATH="deepseek-ai/DeepSeek-R1" +CONDA_ENV="vllm7" +SERVER_PORT=8000 +RAY_PORT=6379 +RAY_DASHBOARD_PORT=8265 + +trap 'scontrol requeue ${SLURM_JOB_ID}; exit 15' SIGUSR1 + +while getopts "m:e:h" opt; do + case $opt in + m) MODEL_PATH="$OPTARG" ;; + e) CONDA_ENV="$OPTARG" ;; + h|?) 
echo "Usage: sbatch $0 [-m MODEL_PATH] [-e CONDA_ENV]"; exit 1 ;; + esac +done + +# Environment setup +module load cuda/12.1 +source ~/.bashrc +source "$CONDA_PREFIX/etc/profile.d/conda.sh" +conda activate "$CONDA_ENV" || { echo "Failed to activate conda env $CONDA_ENV"; exit 1; } + +# Get nodes information +NODES=($(scontrol show hostnames "$SLURM_JOB_NODELIST")) +HEAD_NODE="${NODES[0]}" +HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) + +echo "SLURM_JOB_ID: $SLURM_JOB_ID" +echo "SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST" +echo "Head node: $HEAD_NODE ($HEAD_NODE_IP)" + +# Start Ray head node +echo "Starting Ray head node at $HEAD_NODE" +srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" \ + ray start --head \ + --node-ip-address="$HEAD_NODE_IP" \ + --port=$RAY_PORT \ + --dashboard-host=0.0.0.0 \ + --dashboard-port=$RAY_DASHBOARD_PORT \ + --block & + +sleep 10 + +# Start Ray worker nodes +WORKER_COUNT=$((SLURM_JOB_NUM_NODES - 1)) +for ((i = 1; i <= WORKER_COUNT; i++)); do + WORKER_NODE="${NODES[$i]}" + echo "Starting Ray worker $i at $WORKER_NODE" + srun --nodes=1 --ntasks=1 -w "$WORKER_NODE" \ + ray start --address "$HEAD_NODE_IP:$RAY_PORT" \ + --block & + sleep 5 +done + +echo "Waiting for Ray cluster to initialize..." +sleep 60 + +# Start vLLM server +echo "Starting vLLM server..." +RAY_ADDRESS="http://$HEAD_NODE_IP:$RAY_DASHBOARD_PORT" ray job submit \ + --working-dir src/open_r1 \ + --no-wait \ + --job-id vllm-server \ + -- vllm serve "$MODEL_PATH" \ + --tensor-parallel-size 8 \ + --pipeline-parallel-size 4 \ + --gpu-memory-utilization 0.90 \ + --max-model-len 32768 \ + --max-num-batched-tokens 262144 \ + --max-num-seqs 128 \ + --max-seq-len-to-capture 32768 \ + --enable-chunked-prefill true \ + --preemption-mode recompute \ + --swap-space 128 \ + --trust-remote-code \ + --distributed-executor-backend ray + +# Wait for server with timeout +TIMEOUT=3600 # 1h +START_TIME=$(date +%s) +echo "Waiting for vLLM server (http://$HEAD_NODE_IP:$SERVER_PORT)..." + +while true; do + if curl -s -o /dev/null -w "%{http_code}" "http://$HEAD_NODE_IP:$SERVER_PORT/health" >/dev/null 2>&1; then + echo "Server is ready at http://$HEAD_NODE_IP:$SERVER_PORT" + break + fi + + CURRENT_TIME=$(date +%s) + if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then + echo "Error: Server failed to start within $TIMEOUT seconds" + exit 1 + fi + + echo "Still waiting... ($(($CURRENT_TIME - $START_TIME)) seconds elapsed)" + sleep 60 +done + +echo "Checking available models..." +curl "http://$HEAD_NODE_IP:$SERVER_PORT/v1/models" +sleep 10 + +echo "Executing sanity check..." +curl "http://$HEAD_NODE_IP:$SERVER_PORT/v1/completions" \ + -H "Content-Type: application/json" \ + -d "{ + \"model\": \"default\", + \"prompt\": \"<|begin▁of▁sentence|><|User|>hi, how are you?<|Assistant|>\", + \"max_tokens\": 2048, + \"temperature\": 0.6 + }" + +# Keep the job running with health checks +while true; do + if ! 
curl -s -o /dev/null "http://$HEAD_NODE_IP:$SERVER_PORT/health"; then + echo "Error: Server health check failed" + exit 1 + fi + sleep 300 +done \ No newline at end of file From 5103abe99580e62ac92b8f84c6c457c16c2e7746 Mon Sep 17 00:00:00 2001 From: anton Date: Mon, 10 Feb 2025 14:29:58 +0100 Subject: [PATCH 3/5] readme --- slurm/README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 slurm/README.md diff --git a/slurm/README.md b/slurm/README.md new file mode 100644 index 00000000..f81c583c --- /dev/null +++ b/slurm/README.md @@ -0,0 +1,17 @@ +## Serving DeepSeek-R1 on 2x8 H100 SLURM nodes with SGLang + +1. Set up the environment (adjust for your cuda version): +```bash +conda create -n sglang124 python=3.11 +conda activate sglang124 + +pip install torch=2.5.1 --index-url https://download.pytorch.org/whl/cu124 + +pip install sgl-kernel --force-reinstall --no-deps +pip install "sglang[all]>=0.4.2.post4" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ +``` + +2. Run the server: +```bash +sbatch serve_r1.slurm -m "/fsx/deepseek-r1-checkpoint" -e "sglang124" +``` \ No newline at end of file From 600458b1e12d8e9392b47c6183726701439dd106 Mon Sep 17 00:00:00 2001 From: anton Date: Mon, 10 Feb 2025 16:11:34 +0100 Subject: [PATCH 4/5] add a generation script --- scripts/generate_reason_data.py | 179 +++++++++++++++++++++++++ slurm/README.md | 17 ++- slurm/experimental/serve_r1_vllm.slurm | 3 - slurm/serve_r1.slurm | 3 - slurm/serve_router.slurm | 1 - 5 files changed, 194 insertions(+), 9 deletions(-) create mode 100644 scripts/generate_reason_data.py diff --git a/scripts/generate_reason_data.py b/scripts/generate_reason_data.py new file mode 100644 index 00000000..17840640 --- /dev/null +++ b/scripts/generate_reason_data.py @@ -0,0 +1,179 @@ +import argparse +import asyncio +import json +import os +import random +from asyncio import Lock +from typing import Set + +from datasets import load_dataset +from tqdm.asyncio import tqdm + +import aiofiles +import aiohttp +import uvloop + + +file_lock = Lock() + + +async def generate_completion(session, prompt, args): + retry_budget = 10 + while retry_budget > 0: + try: + await asyncio.sleep(random.uniform(0.0, 0.1)) + async with session.post( + f"http://{args.api_addr}/v1/chat/completions", + json={ + "model": "default", + "messages": [{"role": "user", "content": prompt}], + "max_tokens": args.max_tokens, + "temperature": args.temperature, + "top_p": args.top_p, + }, + headers={"Authorization": "Bearer EMPTY"}, + ) as response: + return await response.json(content_type=None) + except Exception as e: + print(f"API error (will retry): {e}") + retry_budget -= 1 + await asyncio.sleep(10) + return None + + +async def process_example(example, session, args, output_file, pbar): + prompt = args.prompt_template.format(prompt=example[args.prompt_column]) + + try: + tasks = [ + generate_completion(session, prompt, args) + for _ in range(args.num_generations) + ] + + completions = await asyncio.gather(*tasks) + + if any(completion is None for completion in completions): + print(f"Error processing example") + pbar.update(1) + return None + + generations = [] + finish_reasons = [] + api_metadata = [] + + for completion in completions: + generations.append(completion["choices"][0]["message"]["content"]) + finish_reasons.append(completion["choices"][0]["finish_reason"]) + api_metadata.append(completion["usage"]) + + # Combine original dataset fields with generations + result = { + **example, # Preserve all original dataset 
fields + "generations": generations, + "finish_reasons": finish_reasons, + "api_metadata": api_metadata, + } + + # Write to file with lock + async with file_lock: + async with aiofiles.open(output_file, mode="a") as f: + await f.write(json.dumps(result) + "\n") + await f.flush() + + pbar.set_postfix(active=len(pbar.active_tasks), refresh=False) + pbar.update(1) + + return result + except Exception as e: + print(f"Error processing example: {e}") + pbar.update(1) + return None + + +async def load_processed_uuids(output_file): + processed_uuids = set() + if os.path.exists(output_file): + async with aiofiles.open(output_file, mode="r") as f: + async for line in f: + try: + data = json.loads(line) + processed_uuids.add(data["uuid"]) + except json.JSONDecodeError: + continue + return processed_uuids + + +async def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--dataset-name", type=str, required=True) + parser.add_argument("--output-file", type=str, required=True) + parser.add_argument("--prompt-column", type=str, required=True) + parser.add_argument("--uuid-column", type=str, required=True) + parser.add_argument("--api-addr", type=str, default="localhost:39876") + parser.add_argument("--num-generations", type=int, default=4) + parser.add_argument( + "--prompt-template", + type=str, + default="You will be given a problem. Please reason step by step, and put your final answer within \\boxed{{}}:\n{prompt}", + ) + parser.add_argument("--temperature", type=float, default=0.6) + parser.add_argument("--top-p", type=float, default=0.95) + parser.add_argument("--max-tokens", type=int, default=16384) + parser.add_argument("--max-concurrent", type=int, default=1000) + args = parser.parse_args() + + dataset = load_dataset(args.dataset_name, split="train").shuffle() + processed_uuids = await load_processed_uuids(args.output_file) + + if not os.path.exists(args.output_file): + async with aiofiles.open(args.output_file, mode="w") as f: + await f.write("") + + active_tasks: Set[asyncio.Task] = set() + + pbar = tqdm( + total=len(dataset), + desc="Generating responses", + unit="row", + mininterval=2, + smoothing=0.0001, + ) + pbar.active_tasks = active_tasks + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=60 * 60), + connector=aiohttp.TCPConnector( + limit=args.max_concurrent, ttl_dns_cache=300, keepalive_timeout=60 * 60 + ), + ) as session: + for example in dataset: + if example["uuid"] not in processed_uuids: + # Wait if we've hit the concurrency limit + while len(active_tasks) >= args.max_concurrent: + done, active_tasks = await asyncio.wait( + active_tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in done: + try: + await task + except Exception as e: + print(f"Task failed: {e}") + + task = asyncio.create_task( + process_example(example, session, args, args.output_file, pbar) + ) + active_tasks.add(task) + task.add_done_callback(active_tasks.discard) + + pbar.set_postfix(active=len(active_tasks), refresh=True) + + # Wait for remaining tasks + if active_tasks: + await asyncio.gather(*active_tasks, return_exceptions=True) + + pbar.close() + + +if __name__ == "__main__": + uvloop.install() + asyncio.run(main()) diff --git a/slurm/README.md b/slurm/README.md index f81c583c..10106459 100644 --- a/slurm/README.md +++ b/slurm/README.md @@ -11,7 +11,20 @@ pip install sgl-kernel --force-reinstall --no-deps pip install "sglang[all]>=0.4.2.post4" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ ``` -2. Run the server: +2. 
Run the server and wait for the model to load: ```bash -sbatch serve_r1.slurm -m "/fsx/deepseek-r1-checkpoint" -e "sglang124" +sbatch slurm/serve_r1.slurm -m "/fsx/deepseek-r1-checkpoint" -e "sglang124" +``` + +3. Run the data generation script: +```bash +python scripts/generate_reasoning.py \ + --dataset-name "AI-MO/NuminaMath-1.5" \ + --output-file "numinamath_r1_generations.jsonl" \ + --prompt-column "problem" \ + --uuid-column "problem" \ + --api-addr ":39877" \ + --num-generations 2 \ + --max-tokens 16384 \ + --max-concurrent 200 ``` \ No newline at end of file diff --git a/slurm/experimental/serve_r1_vllm.slurm b/slurm/experimental/serve_r1_vllm.slurm index 7e0b0d59..9f1ffd93 100644 --- a/slurm/experimental/serve_r1_vllm.slurm +++ b/slurm/experimental/serve_r1_vllm.slurm @@ -9,7 +9,6 @@ #SBATCH --error=./logs/%x_%j_%n.err #SBATCH --time=7-00:00:00 #SBATCH --ntasks-per-node=1 -#SBATCH --requeue set -exuo pipefail @@ -19,8 +18,6 @@ SERVER_PORT=8000 RAY_PORT=6379 RAY_DASHBOARD_PORT=8265 -trap 'scontrol requeue ${SLURM_JOB_ID}; exit 15' SIGUSR1 - while getopts "m:e:h" opt; do case $opt in m) MODEL_PATH="$OPTARG" ;; diff --git a/slurm/serve_r1.slurm b/slurm/serve_r1.slurm index 60a72aad..6cb3719d 100644 --- a/slurm/serve_r1.slurm +++ b/slurm/serve_r1.slurm @@ -9,7 +9,6 @@ #SBATCH --error=./logs/%x_%j_%n.err #SBATCH --time=7-00:00:00 #SBATCH --ntasks-per-node=1 -#SBATCH --requeue set -exuo pipefail @@ -19,8 +18,6 @@ ROUTER_ADDRESS="" SERVER_PORT=39877 DIST_PORT=45000 -trap 'scontrol requeue ${SLURM_JOB_ID}; exit 15' SIGUSR1 - # TODO: Adjust these variables to your cluster configuration export OUTLINES_CACHE_DIR=/scratch/serve_r1/ocache/ export TRITON_HOME=/scratch/serve_r1/triton/ diff --git a/slurm/serve_router.slurm b/slurm/serve_router.slurm index b39ca66a..0fe96177 100644 --- a/slurm/serve_router.slurm +++ b/slurm/serve_router.slurm @@ -33,7 +33,6 @@ conda activate "$CONDA_ENV" || { echo "Failed to activate conda env $CONDA_ENV"; python -m sglang_router.launch_router \ --port "$ROUTER_PORT" \ --host 0.0.0.0 \ - --policy "round_robin" \ --worker-startup-timeout-secs 300 # Keep the job running with health checks From d7208bb7400e05cab0c34b511bf7d1594d213f51 Mon Sep 17 00:00:00 2001 From: anton Date: Mon, 10 Feb 2025 16:15:05 +0100 Subject: [PATCH 5/5] ruff --- scripts/generate_reason_data.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/scripts/generate_reason_data.py b/scripts/generate_reason_data.py index 17840640..01e6e7a7 100644 --- a/scripts/generate_reason_data.py +++ b/scripts/generate_reason_data.py @@ -45,10 +45,7 @@ async def process_example(example, session, args, output_file, pbar): prompt = args.prompt_template.format(prompt=example[args.prompt_column]) try: - tasks = [ - generate_completion(session, prompt, args) - for _ in range(args.num_generations) - ] + tasks = [generate_completion(session, prompt, args) for _ in range(args.num_generations)] completions = await asyncio.gather(*tasks) @@ -142,26 +139,20 @@ async def main(): async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=60 * 60), - connector=aiohttp.TCPConnector( - limit=args.max_concurrent, ttl_dns_cache=300, keepalive_timeout=60 * 60 - ), + connector=aiohttp.TCPConnector(limit=args.max_concurrent, ttl_dns_cache=300, keepalive_timeout=60 * 60), ) as session: for example in dataset: if example["uuid"] not in processed_uuids: # Wait if we've hit the concurrency limit while len(active_tasks) >= args.max_concurrent: - done, active_tasks = await asyncio.wait( - 
-                        active_tasks, return_when=asyncio.FIRST_COMPLETED
-                    )
+                    done, active_tasks = await asyncio.wait(active_tasks, return_when=asyncio.FIRST_COMPLETED)
                     for task in done:
                         try:
                             await task
                         except Exception as e:
                             print(f"Task failed: {e}")
 
-                task = asyncio.create_task(
-                    process_example(example, session, args, args.output_file, pbar)
-                )
+                task = asyncio.create_task(process_example(example, session, args, args.output_file, pbar))
                 active_tasks.add(task)
                 task.add_done_callback(active_tasks.discard)