Commit
Signed-off-by: wangli <wangli858794774@gmail.com>
Showing 9 changed files with 2,474 additions and 0 deletions.
@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0

import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional, Union

import aiohttp
import huggingface_hub.constants
from tqdm.asyncio import tqdm
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str
    model_name: Optional[str] = None
    best_of: int = 1
    logprobs: Optional[int] = None
    extra_body: Optional[dict] = None
    multi_modal_content: Optional[dict] = None
    ignore_eos: bool = False


@dataclass
class RequestFuncOutput:
    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    output_tokens: int = 0
    ttft: float = 0.0  # Time to first token
    itl: List[float] = field(
        default_factory=list)  # List of inter-token latencies
    tpot: float = 0.0  # avg next-token latencies
    prompt_len: int = 0
    error: str = ""


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model_name \
                if request_func_input.model_name else request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk != "[DONE]":
                            data = json.loads(chunk)

                            # NOTE: Some completion APIs send a final usage
                            # summary response without a token, so we check
                            # that a token was actually generated.
                            if choices := data.get("choices"):
                                # Note that text could be empty here
                                # e.g. for special tokens
                                text = choices[0].get("text")
                                timestamp = time.perf_counter()
                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")
                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!")
                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


def get_model(pretrained_model_name_or_path: str) -> str:
    if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
        from modelscope import snapshot_download

        model_path = snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])

        return model_path
    return pretrained_model_name_or_path


def get_tokenizer(
    pretrained_model_name_or_path: str,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    if pretrained_model_name_or_path is not None and not os.path.exists(
            pretrained_model_name_or_path):
        pretrained_model_name_or_path = get_model(
            pretrained_model_name_or_path)
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False
    if tokenizer_mode == "mistral":
        try:
            from vllm.transformers_utils.tokenizer import MistralTokenizer
        except ImportError as e:
            raise ImportError("MistralTokenizer requires vllm package.\n"
                              "Please install it with `pip install vllm` "
                              "to use mistral tokenizer mode.") from e
        return MistralTokenizer.from_pretrained(
            str(pretrained_model_name_or_path))
    else:
        return AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )


ASYNC_REQUEST_FUNCS = {
    "vllm": async_request_openai_completions,
}
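The helper above is meant to be driven by a benchmark harness that builds one RequestFuncInput per prompt and awaits the function registered in ASYNC_REQUEST_FUNCS. A minimal driver sketch follows; the server URL, model name, and token counts are illustrative assumptions, not values taken from this commit:

import asyncio

async def run_single_request():
    # Placeholder endpoint and model; point these at the server actually under test.
    request_input = RequestFuncInput(
        prompt="Hello, my name is",
        api_url="http://localhost:8000/v1/completions",  # assumed OpenAI-compatible vLLM server
        prompt_len=5,
        output_len=32,
        model="facebook/opt-125m",  # placeholder model name
    )
    output = await ASYNC_REQUEST_FUNCS["vllm"](request_input)
    print(f"success={output.success}, ttft={output.ttft:.3f}s, "
          f"latency={output.latency:.3f}s")
    print(f"generated: {output.generated_text!r}")

asyncio.run(run_single_request())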
@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import dataclasses
import json
import time
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser


def main(args: argparse.Namespace):
    print(args)

    engine_args = EngineArgs.from_cli_args(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(**dataclasses.asdict(engine_args))

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: List[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    def llm_generate():
        if not args.use_beam_search:
            llm.generate(dummy_prompts,
                         sampling_params=sampling_params,
                         use_tqdm=False)
        else:
            llm.beam_search(
                dummy_prompts,
                BeamSearchParams(
                    beam_width=args.n,
                    max_tokens=args.output_len,
                    ignore_eos=True,
                ))

    def run_to_completion(profile_dir: Optional[str] = None):
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm_generate()
            print(p.key_averages().table(sort_by="self_cuda_time_total"))
        else:
            start_time = time.perf_counter()
            llm_generate()
            end_time = time.perf_counter()
            latency = end_time - start_time
            return latency

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = Path(
                "."
            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = []

    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile_dir=None))
    latencies = np.array(latencies)
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    print(f'Avg latency: {np.mean(latencies)} seconds')
    for percentage, percentile in zip(percentages, percentiles):
        print(f'{percentage}% percentile latency: {percentile} seconds')

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_latency": np.mean(latencies),
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


if __name__ == '__main__':
    parser = FlexibleArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search', action='store_true')
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=10,
                        help='Number of iterations to run for warmup.')
    parser.add_argument('--num-iters',
                        type=int,
                        default=30,
                        help='Number of iterations to run.')
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the latency results in JSON format.')

    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    main(args)
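Assuming this script is saved as benchmark_latency.py (the file name is not visible in this diff view), a typical invocation might look like the sketch below. The model name is a placeholder; --model and the other engine options come from EngineArgs.add_cli_args:

# Hypothetical invocation; script name and model are assumptions.
python benchmark_latency.py \
    --model facebook/opt-125m \
    --input-len 32 \
    --output-len 128 \
    --batch-size 8 \
    --num-iters 30 \
    --output-json latency_results.json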