Commit 225adb3 (1 parent: 3b76d39)

add support for pubmed dataset and random query list (#223)

Signed-off-by: leslieluyu <leslie.luyu@gmail.com>

Showing 3 changed files with 274 additions and 0 deletions.
@@ -0,0 +1,59 @@
# How to benchmark PubMed datasets by sending queries randomly

This README outlines how to prepare the PubMed datasets for benchmarking ChatQnA and how to create a query list based on these datasets. It also explains how to randomly send queries from the list to the ChatQnA pipeline in order to obtain performance data that is closer to real user scenarios.

## 1. Prepare the PubMed datasets

- To simulate a practical user scenario, we have chosen to use industrial data from PubMed. The original PubMed data can be found here: [Hugging Face - MedRAG PubMed](https://huggingface.co/datasets/MedRAG/pubmed/tree/main/chunk).
- In order to observe and compare the performance of the ChatQnA pipeline with different sizes of ingested datasets, we create four files: pubmed_10.txt, pubmed_100.txt, pubmed_1000.txt, and pubmed_10000.txt. These files contain 10, 100, 1,000, and 10,000 records of data extracted from pubmed23n0001.jsonl.

### 1.1 Get the PubMed data

```
wget https://huggingface.co/datasets/MedRAG/pubmed/resolve/main/chunk/pubmed23n0001.jsonl
```
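
Each line of pubmed23n0001.jsonl is a standalone JSON record; the extraction script and the benchmark script below rely on its `id` and `title` fields. As a quick sanity check, you can inspect the first record (this assumes `jq` is installed, which the extraction script requires anyway):

```
# Print the id and title fields of the first record
head -n 1 pubmed23n0001.jsonl | jq '{id, title}'
```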

### 1.2 Use the script to extract data

A prepared script, extract_lines.sh, is available to extract lines from the original PubMed file into the dataset and query list files.

#### Usage:
```
$ cd dataset
$ ./extract_lines.sh input_file output_file begin_id end_id
```
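
For example, extracting the first ten records looks like this; the confirmation line at the end is printed by the script itself:

```
$ ./extract_lines.sh pubmed23n0001.jsonl pubmed_10.txt pubmed23n0001_0 pubmed23n0001_9
Records from ID pubmed23n0001_0 to pubmed23n0001_9 have been extracted to pubmed_10.txt.
```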

### 1.3 Prepare the 4 dataset files

The commands below generate the four PubMed dataset files, which will be ingested by dataprep before benchmarking:
```
./extract_lines.sh pubmed23n0001.jsonl pubmed_10.txt pubmed23n0001_0 pubmed23n0001_9
./extract_lines.sh pubmed23n0001.jsonl pubmed_100.txt pubmed23n0001_0 pubmed23n0001_99
./extract_lines.sh pubmed23n0001.jsonl pubmed_1000.txt pubmed23n0001_0 pubmed23n0001_999
./extract_lines.sh pubmed23n0001.jsonl pubmed_10000.txt pubmed23n0001_0 pubmed23n0001_9999
```
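
Assuming the record IDs in pubmed23n0001.jsonl are contiguous, each output file should contain exactly the requested number of JSON lines (one record per line), which you can verify with wc:

```
# Each file holds one JSON record per line
wc -l pubmed_10.txt pubmed_100.txt pubmed_1000.txt pubmed_10000.txt
```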

### 1.4 Prepare the query list

The random queries are drawn from roughly 10% of the largest ingested dataset, so a query list of at most 1,000 records is sufficient:
```
cp pubmed_1000.txt pubmed_q1000.txt
```

## 2. How to use the PubMed query list

> NOTE: Unlike chatqnafixed.py, which sends a fixed prompt each time, chatqna_qlist_pubmed.py is designed to benchmark the ChatQnA pipeline using the PubMed query list. On each request it randomly selects a query from the query list file and sends it to the ChatQnA pipeline.

- First, make sure the correct bench-target is set in run.yaml:

```
bench-target: "chatqna_qlist_pubmed"
```
- Ensure that the environment variables are set correctly (an example export block follows this list):
  - DATASET: The name of the query list file. Default: "pubmed_q1000.txt"
  - MAX_LINES: The maximum number of lines from the query list that will be used for random queries. Default: 1000
  - MAX_TOKENS: The maximum number of tokens the language model may generate, passed through to the ChatQnA pipeline. Default: 128
  - PROMPT: A user-defined prompt suffix that is appended to each query sent to the ChatQnA pipeline.
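
A minimal sketch of setting these variables before a run; the values shown are illustrative, and any variable left unset falls back to its default:

```
export DATASET="pubmed_q1000.txt"  # query list file, resolved as ../dataset/<DATASET> relative to the locust directory
export MAX_LINES=1000              # cap on the number of query list lines used for random queries
export MAX_TOKENS=128              # max tokens the LLM may generate per response
```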
- Then run the benchmark script, for example:
```
./stresscli.py load-test --profile run.yaml
```

For more information, please refer to the [stresscli](./README.md) documentation.
@@ -0,0 +1,49 @@
#!/bin/bash

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Check for the correct number of arguments
if [ "$#" -ne 4 ]; then
    echo "Usage: $0 input_file output_file begin_id end_id"
    exit 1
fi

input_file="$1"
output_file="$2"
begin_id="$3"
end_id="$4"

# Create or clear the output file
> "$output_file"

# Flag indicating whether we have started writing
writing=false

# Read through the input file line by line
while IFS= read -r line; do
    # Only process lines that are valid JSON
    if echo "$line" | jq -e . >/dev/null 2>&1; then
        # Extract the ID from the JSON object
        id=$(echo "$line" | jq -r .id)

        # Start writing once we reach the beginning ID
        if [[ "$id" == "$begin_id" ]]; then
            writing=true
        fi

        # While writing, copy the line to the output file
        if [[ "$writing" == true ]]; then
            echo "$line" >> "$output_file"
        fi

        # Stop after writing end_id (this also handles begin_id == end_id)
        if [[ "$id" == "$end_id" ]]; then
            break
        fi
    fi
done < "$input_file"

echo "Records from ID $begin_id to $end_id have been extracted to $output_file."
evals/benchmark/stresscli/locust/chatqna_qlist_pubmed.py (166 additions, 0 deletions)
@@ -0,0 +1,166 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import logging
import os
import random
import time

import tokenresponse as token

# Enable debug logging
logging.basicConfig(level=logging.DEBUG)


DATASET = os.getenv("DATASET", "pubmed_q1000.txt")
MAX_LINES = int(os.getenv("MAX_LINES", 1000))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 128))
MAX_WORDS = int(os.getenv("MAX_WORDS", 1024))
PROMPT = os.getenv(
    "PROMPT",
    f". Give me the content related to this title and please repeat the answer multiple times till the word count exceeds {MAX_WORDS}.",
)

# Initialize the data
cwd = os.path.dirname(os.path.abspath(__file__))
filename = os.path.join(cwd, "..", "dataset", DATASET)
logging.info(f"The dataset filename: {filename}")
logging.info(f"MAX_LINES: {MAX_LINES}")
logging.info(f"MAX_TOKENS: {MAX_TOKENS}")
logging.info(f"MAX_WORDS: {MAX_WORDS}")

prompt_suffix = PROMPT

# Global dictionary mapping line numbers to parsed JSON records
data_dict = {}
max_lines = 0


def load_pubmed_data(filename):
    """Load PubMed data into a dictionary and determine the maximum line number."""
    global data_dict, max_lines
    t1 = time.time()  # Record the start time; elapsed time is printed below
    try:
        with open(filename, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                try:
                    line = line.strip()
                    if not line:  # Skip empty lines
                        continue

                    data = json.loads(line)
                    data_dict[line_num] = data
                    max_lines = line_num

                except json.JSONDecodeError:
                    logging.warning(f"Invalid JSON at line {line_num}: {line[:100]}...")
                except (IndexError, ValueError) as e:
                    logging.warning(f"Invalid ID format at line {line_num}: {str(e)}")
                except Exception as e:
                    logging.warning(f"Error processing line {line_num}: {str(e)}")

        logging.info(f"Loaded {len(data_dict)} items. Max line number: {max_lines}")

        # Sanity check: a usable query list should contain at least a few items
        if len(data_dict) < 2:
            logging.error("Suspiciously few items loaded. Possible data loading issue.")
            return False
        t2 = time.time()
        print(f"load_pubmed_data time: {t2 - t1:.4f} seconds")
        return True
    except Exception as e:
        logging.error(f"Error loading file: {str(e)}")
        return False


if not load_pubmed_data(filename):
    exit()


def getDataByLine(line_num):
    """Get the document stored at the given line number, or None if absent."""
    return data_dict.get(line_num)


def getRandomDocument():
    """Get a random document by picking a random line number."""
    if not data_dict:
        logging.error("No data loaded")
        return None

    # Draw from the first min(max_lines, MAX_LINES) lines
    random_max = min(max_lines, MAX_LINES)
    logging.info(f"random_max={random_max}")

    random_line = random.randint(1, random_max)
    doc = getDataByLine(random_line)
    if doc:
        return doc

    logging.error("Failed to find a valid document")
    return None


def getUrl():
    return "/v1/chatqna"


def getReqData():
    doc = getRandomDocument()
    message = f"{doc['title']}{prompt_suffix}"
    logging.debug(f"Selected document: {message}")
    return {"messages": f"{message}", "max_tokens": MAX_TOKENS}


def respStatics(environment, resp):
    return token.respStatics(environment, resp)


def staticsOutput(environment, reqlist):
    token.staticsOutput(environment, reqlist)


def get_title(data_dict):
    """Print summary statistics for every document and return the list of titles."""
    titles = []
    length = len(data_dict)
    print(f"length={length}")
    for i in range(1, length + 1):
        doc = data_dict.get(i)
        # Sum the lengths of all string values in the document
        total_chars = 0
        for key, value in doc.items():
            if isinstance(value, str):  # Check if the value is a string
                total_chars += len(value)

        id = doc["id"]
        title = doc["title"]

        print(f"i={i}, id={id}, total_chars={total_chars}, doc length={len(doc)}, title length={len(title)}")
        titles.append(title)
    return titles


# Basic self-test when run directly
if __name__ == "__main__":
    logging.info("Starting the program")
    get_title(data_dict)
    # Test the random document retrieval
    for _ in range(3):  # Test 3 random retrievals
        doc = getRandomDocument()
        if doc:
            logging.info(f"Retrieved document: ID={doc['id']}, Title={doc['title'][:50]}...")
        else:
            logging.warning("No document found for generated ID")
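
To sanity-check data loading and random selection without Locust, the module can also be run directly; this assumes tokenresponse.py is present in the same locust directory (it is imported at module load) and that the dataset file exists under ../dataset:

```
cd evals/benchmark/stresscli/locust
python3 chatqna_qlist_pubmed.py
```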