Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the actual async generation script #273

Merged
merged 6 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions scripts/generate_reason_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import argparse
import asyncio
import json
import os
import random
from asyncio import Lock
from typing import Set

from datasets import load_dataset
from tqdm.asyncio import tqdm

import aiofiles
import aiohttp
import uvloop


# Serializes appends to the shared JSONL output file across the many
# concurrently running process_example tasks (see process_example below).
file_lock = Lock()


async def generate_completion(session, prompt, args):
    """Request one chat completion from the OpenAI-compatible server.

    Retries up to 10 times on any error, sleeping 10s between attempts.
    Returns the parsed JSON response dict, or None once the retry budget
    is exhausted.
    """
    max_attempts = 10
    for _ in range(max_attempts):
        try:
            # Small random delay to de-synchronize bursts of concurrent requests.
            await asyncio.sleep(random.uniform(0.0, 0.1))
            payload = {
                "model": "default",
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": args.max_tokens,
                "temperature": args.temperature,
                "top_p": args.top_p,
            }
            request = session.post(
                f"http://{args.api_addr}/v1/chat/completions",
                json=payload,
                headers={"Authorization": "Bearer EMPTY"},
            )
            async with request as response:
                # content_type=None: accept whatever content-type the server sends.
                return await response.json(content_type=None)
        except Exception as e:
            print(f"API error (will retry): {e}")
            await asyncio.sleep(10)
    return None


async def process_example(example, session, args, output_file, pbar):
    """Generate ``args.num_generations`` completions for one dataset row and
    append the merged record to ``output_file`` as a single JSON line.

    Returns the written record dict, or None if any completion failed
    (nothing is written in that case, so a resumed run will retry the row).
    """
    prompt = args.prompt_template.format(prompt=example[args.prompt_column])

    try:
        # Fire all generations for this row concurrently.
        tasks = [generate_completion(session, prompt, args) for _ in range(args.num_generations)]

        completions = await asyncio.gather(*tasks)

        # generate_completion returns None once its retry budget is spent;
        # drop the whole row rather than write a partial record.
        if any(completion is None for completion in completions):
            print("Error processing example: a completion failed after exhausting retries")
            pbar.update(1)
            return None

        generations = []
        finish_reasons = []
        api_metadata = []

        for completion in completions:
            generations.append(completion["choices"][0]["message"]["content"])
            finish_reasons.append(completion["choices"][0]["finish_reason"])
            api_metadata.append(completion["usage"])

        # Combine original dataset fields with generations
        result = {
            **example,  # Preserve all original dataset fields
            "generations": generations,
            "finish_reasons": finish_reasons,
            "api_metadata": api_metadata,
        }

        # Serialize writers so concurrent tasks cannot interleave JSON lines.
        async with file_lock:
            async with aiofiles.open(output_file, mode="a") as f:
                await f.write(json.dumps(result) + "\n")
                await f.flush()

        pbar.set_postfix(active=len(pbar.active_tasks), refresh=False)
        pbar.update(1)

        return result
    except Exception as e:
        # Malformed API payloads (missing "choices"/"usage" keys) land here too.
        print(f"Error processing example: {e}")
        pbar.update(1)
        return None


async def load_processed_uuids(output_file, uuid_column="uuid"):
    """Return the set of identifiers already present in ``output_file``.

    Used to resume an interrupted run: rows whose identifier appears in the
    output are skipped. Malformed JSON lines (e.g. from an interrupted write)
    and records missing ``uuid_column`` are skipped instead of aborting —
    the original ``data["uuid"]`` raised an uncaught KeyError on such records.

    ``uuid_column`` defaults to "uuid" for backward compatibility with the
    previous hard-coded behavior.
    """
    processed_uuids = set()
    if os.path.exists(output_file):
        async with aiofiles.open(output_file, mode="r") as f:
            async for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue  # skip partially-written / corrupt lines
                uuid = data.get(uuid_column)
                if uuid is not None:
                    processed_uuids.add(uuid)
    return processed_uuids


async def main():
    """Generate completions for every dataset row via an OpenAI-compatible
    API, appending results as JSONL and resuming past already-done rows."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-name", type=str, required=True)
    parser.add_argument("--output-file", type=str, required=True)
    parser.add_argument("--prompt-column", type=str, required=True)
    parser.add_argument("--uuid-column", type=str, required=True)
    parser.add_argument("--api-addr", type=str, default="localhost:39876")
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument(
        "--prompt-template",
        type=str,
        default="You will be given a problem. Please reason step by step, and put your final answer within \\boxed{{}}:\n{prompt}",
    )
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--max-tokens", type=int, default=16384)
    parser.add_argument("--max-concurrent", type=int, default=1000)
    args = parser.parse_args()

    # Shuffle so multiple runs don't all start on the same leading rows.
    dataset = load_dataset(args.dataset_name, split="train").shuffle()
    # NOTE(review): load_processed_uuids keys resume records on the literal
    # "uuid" field, so dedup on resume assumes the rows carry a "uuid" value —
    # confirm when a different --uuid-column is used.
    processed_uuids = await load_processed_uuids(args.output_file)

    # Touch the output file so worker appends always have a target.
    if not os.path.exists(args.output_file):
        async with aiofiles.open(args.output_file, mode="w") as f:
            await f.write("")

    active_tasks: Set[asyncio.Task] = set()

    pbar = tqdm(
        total=len(dataset),
        desc="Generating responses",
        unit="row",
        mininterval=2,
        smoothing=0.0001,
    )
    pbar.active_tasks = active_tasks

    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=60 * 60),
        connector=aiohttp.TCPConnector(limit=args.max_concurrent, ttl_dns_cache=300, keepalive_timeout=60 * 60),
    ) as session:
        for example in dataset:
            # Use the configured identifier column (previously hard-coded to
            # example["uuid"], which crashed on datasets without that column
            # even though --uuid-column is a required argument).
            if example[args.uuid_column] not in processed_uuids:
                # Throttle: block until at least one in-flight task finishes.
                while len(active_tasks) >= args.max_concurrent:
                    # Do NOT rebind active_tasks to the pending set here: the
                    # done-callbacks registered below discard finished tasks
                    # from this exact set object (and pbar.active_tasks
                    # aliases it), so rebinding would leave both pointing at
                    # a stale set that never shrinks.
                    done, _ = await asyncio.wait(active_tasks, return_when=asyncio.FIRST_COMPLETED)
                    for task in done:
                        try:
                            await task  # surface and log any task exception
                        except Exception as e:
                            print(f"Task failed: {e}")

                task = asyncio.create_task(process_example(example, session, args, args.output_file, pbar))
                active_tasks.add(task)
                task.add_done_callback(active_tasks.discard)

            pbar.set_postfix(active=len(active_tasks), refresh=True)

        # Drain remaining tasks; return_exceptions keeps one failure from
        # cancelling the rest of the batch.
        if active_tasks:
            await asyncio.gather(*active_tasks, return_exceptions=True)

    pbar.close()


if __name__ == "__main__":
uvloop.install()
asyncio.run(main())
17 changes: 15 additions & 2 deletions slurm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,20 @@ pip install sgl-kernel --force-reinstall --no-deps
pip install "sglang[all]>=0.4.2.post4" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
```

2. Run the server:
2. Run the server and wait for the model to load:
```bash
sbatch serve_r1.slurm -m "/fsx/deepseek-r1-checkpoint" -e "sglang124"
sbatch slurm/serve_r1.slurm -m "/fsx/deepseek-r1-checkpoint" -e "sglang124"
```

3. Run the data generation script:
```bash
python scripts/generate_reason_data.py \
--dataset-name "AI-MO/NuminaMath-1.5" \
--output-file "numinamath_r1_generations.jsonl" \
--prompt-column "problem" \
--uuid-column "problem" \
--api-addr "<SGLANG_SERVER_ADDRESS>:39877" \
--num-generations 2 \
--max-tokens 16384 \
--max-concurrent 200
```
3 changes: 0 additions & 3 deletions slurm/experimental/serve_r1_vllm.slurm
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#SBATCH --error=./logs/%x_%j_%n.err
#SBATCH --time=7-00:00:00
#SBATCH --ntasks-per-node=1
#SBATCH --requeue

set -exuo pipefail

Expand All @@ -19,8 +18,6 @@ SERVER_PORT=8000
RAY_PORT=6379
RAY_DASHBOARD_PORT=8265

trap 'scontrol requeue ${SLURM_JOB_ID}; exit 15' SIGUSR1

while getopts "m:e:h" opt; do
case $opt in
m) MODEL_PATH="$OPTARG" ;;
Expand Down
3 changes: 0 additions & 3 deletions slurm/serve_r1.slurm
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#SBATCH --error=./logs/%x_%j_%n.err
#SBATCH --time=7-00:00:00
#SBATCH --ntasks-per-node=1
#SBATCH --requeue

set -exuo pipefail

Expand All @@ -19,8 +18,6 @@ ROUTER_ADDRESS=""
SERVER_PORT=39877
DIST_PORT=45000

trap 'scontrol requeue ${SLURM_JOB_ID}; exit 15' SIGUSR1

# TODO: Adjust these variables to your cluster configuration
export OUTLINES_CACHE_DIR=/scratch/serve_r1/ocache/
export TRITON_HOME=/scratch/serve_r1/triton/
Expand Down
1 change: 0 additions & 1 deletion slurm/serve_router.slurm
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ conda activate "$CONDA_ENV" || { echo "Failed to activate conda env $CONDA_ENV";
python -m sglang_router.launch_router \
--port "$ROUTER_PORT" \
--host 0.0.0.0 \
--policy "round_robin" \
--worker-startup-timeout-secs 300

# Keep the job running with health checks
Expand Down