Skip to content

Commit

Permalink
Merge branch 'main' into almaz/reward-weights
Browse files (browse the repository at this point in the history)
  • Loading branch information
qgallouedec authored Feb 11, 2025
2 parents d4ff980 + fa9b621 commit f330dd6
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions scripts/generate_reasoning.py
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
import argparse
import asyncio
+import hashlib
import json
import os
import random
Expand Down Expand Up @@ -87,14 +88,14 @@ async def process_example(example, session, args, output_file, pbar):
return None


-async def load_processed_uuids(output_file):
+async def load_processed_uuids(output_file, uuid_column):
processed_uuids = set()
if os.path.exists(output_file):
async with aiofiles.open(output_file, mode="r") as f:
async for line in f:
try:
data = json.loads(line)
-processed_uuids.add(data["uuid"])
+processed_uuids.add(hashlib.md5(str(data[uuid_column]).encode()).hexdigest())
except json.JSONDecodeError:
continue
return processed_uuids
Expand All @@ -120,7 +121,9 @@ async def main():
args = parser.parse_args()

dataset = load_dataset(args.dataset_name, split="train").shuffle()
-processed_uuids = await load_processed_uuids(args.output_file)
+processed_uuids = await load_processed_uuids(args.output_file, args.uuid_column)
+if processed_uuids:
+print(f"Found {len(processed_uuids)} already processed examples, resuming from there...")

if not os.path.exists(args.output_file):
async with aiofiles.open(args.output_file, mode="w") as f:
Expand All @@ -129,7 +132,7 @@ async def main():
active_tasks: Set[asyncio.Task] = set()

pbar = tqdm(
-total=len(dataset),
+total=len(dataset) - len(processed_uuids),
desc="Generating responses",
unit="row",
mininterval=2,
Expand All @@ -142,7 +145,8 @@ async def main():
connector=aiohttp.TCPConnector(limit=args.max_concurrent, ttl_dns_cache=300, keepalive_timeout=60 * 60),
) as session:
for example in dataset:
-if example["uuid"] not in processed_uuids:
+uuid = hashlib.md5(str(example[args.uuid_column]).encode()).hexdigest()
+if uuid not in processed_uuids:
# Wait if we've hit the concurrency limit
while len(active_tasks) >= args.max_concurrent:
done, active_tasks = await asyncio.wait(active_tasks, return_when=asyncio.FIRST_COMPLETED)
Expand Down

0 comments on commit f330dd6

Please sign in to comment.