Skip to content

Commit

Permalink
Merge branch 'fix-server-url-slurm' of https://github.com/huggingface/open-r1 into fix-server-url-slurm
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Jan 25, 2025
2 parents c34e9fd + 484ff82 commit 973b0c5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
7 changes: 7 additions & 0 deletions slurm/generate.slurm
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ while [[ $# -gt 0 ]]; do
MAX_NEW_TOKENS="$2"
shift 2
;;
--num-generations)
NUM_GENERATIONS="$2"
shift 2
;;
--hf-output-dataset)
HF_OUTPUT_DATASET="$2"
shift 2
Expand All @@ -68,6 +72,7 @@ fi
HF_DATASET_SPLIT=${HF_DATASET_SPLIT:-"train"}
PROMPT_COLUMN=${PROMPT_COLUMN:-"prompt"}
MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-8192}
NUM_GENERATIONS=${NUM_GENERATIONS:-1}
PRIVATE=${PRIVATE:-"false"}

# Print all input arguments
Expand All @@ -80,6 +85,7 @@ echo "PROMPT_COLUMN: $PROMPT_COLUMN"
echo "TEMPERATURE: $TEMPERATURE"
echo "TOP_P: $TOP_P"
echo "MAX_NEW_TOKENS: $MAX_NEW_TOKENS"
echo "NUM_GENERATIONS: $NUM_GENERATIONS"
echo "HF_OUTPUT_DATASET: $HF_OUTPUT_DATASET"
echo "PRIVATE: $PRIVATE"
echo "-------------------"
Expand Down Expand Up @@ -188,6 +194,7 @@ RAY_ADDRESS="http://$head_node_ip:8265" ray job submit \
${TEMPERATURE:+--temperature "$TEMPERATURE"} \
${TOP_P:+--top-p "$TOP_P"} \
--max-new-tokens "$MAX_NEW_TOKENS" \
--num-generations "$NUM_GENERATIONS" \
${HF_OUTPUT_DATASET:+--hf-output-dataset "$HF_OUTPUT_DATASET"} \
${PRIVATE:+--private} \
--vllm-server-url "http://$head_node_ip:8000/v1"
9 changes: 9 additions & 0 deletions src/open_r1/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def build_distilabel_pipeline(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
max_new_tokens: int = 8192,
num_generations: int = 1,
) -> Pipeline:
generation_kwargs = {"max_new_tokens": max_new_tokens}

Expand All @@ -47,6 +48,7 @@ def build_distilabel_pipeline(
),
input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
input_batch_size=10,
num_generations=num_generations,
)

return pipeline
Expand Down Expand Up @@ -105,6 +107,12 @@ def build_distilabel_pipeline(
default=8192,
help="Maximum number of new tokens to generate",
)
parser.add_argument(
"--num-generations",
type=int,
default=1,
help="Number of generations per problem",
)
parser.add_argument(
"--hf-output-dataset",
type=str,
Expand Down Expand Up @@ -135,6 +143,7 @@ def build_distilabel_pipeline(
temperature=args.temperature,
top_p=args.top_p,
max_new_tokens=args.max_new_tokens,
num_generations=args.num_generations,
)

print("Running generation pipeline...")
Expand Down

0 comments on commit 973b0c5

Please sign in to comment.