From 7dbfff59c9f1a153ea85f85983179bdcf296184a Mon Sep 17 00:00:00 2001 From: jemmyshin Date: Tue, 18 Jul 2023 17:16:06 +0800 Subject: [PATCH] fix: fix benchmark --- open_gpt/profile.py | 13 ++++++++----- scripts/benchmark.py | 9 ++------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/open_gpt/profile.py b/open_gpt/profile.py index 2db6990..2dedf84 100644 --- a/open_gpt/profile.py +++ b/open_gpt/profile.py @@ -117,11 +117,13 @@ def end_measure(start_measures): # GPU mem for i in range(torch.cuda.device_count()): measures[str(i)] = ( - torch.cuda.memory_allocated(i) - start_measures[str(i)] - ) / GB + torch.cuda.memory_allocated(i) - start_measures[ + str(i)] + ) / GB measures[f"{i}-peak"] = ( - torch.cuda.max_memory_allocated(i) - start_measures[str(i)] - ) / GB + torch.cuda.max_memory_allocated(i) - + start_measures[str(i)] + ) / GB return measures @@ -165,7 +167,8 @@ def end_record(self, generation_outputs: Union[str, List[str]]): ) else: num_tokens = sum( - list(map(lambda x: len(self._tokenizer(x)) - 2, generation_outputs)) + list(map(lambda x: len(self._tokenizer(x)['input_ids']) - 2, + generation_outputs)) ) self._generation_length.append(num_tokens) self._time_stamp = None diff --git a/scripts/benchmark.py b/scripts/benchmark.py index 76657c4..ce33fa7 100644 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -10,7 +10,7 @@ def run_benchmark(model, max_new_tokens, llm_measure): generated_text = model.generate( PROMPT, max_new_tokens=max_new_tokens, do_sample=args.do_sample ) - llm_measure.end_record(generated_text) + llm_measure.end_record(generated_text['choices'][0]['text']) llm_measure.stats(stage='prefill' if max_new_tokens == 1 else 'decoding') llm_measure.clear() @@ -71,7 +71,7 @@ def main(args): '--precision', type=str, default='fp16', help='precision used for inference' ) parser.add_argument( - '--repeat-time', type=int, default=10, help='repeat time for benchmark' + '--repeat-time', type=int, default=100, help='repeat time for benchmark' ) parser.add_argument( '--do-sample', @@ -90,10 +90,5 @@ def main(args): 'bit4', 'bit8', ], 'precision must be fp16 or bit4 or bit8' - if args.adapter_name is not None: - assert args.precision in [ - 'bit4', - 'bit8', - ], 'precision must be bit4 or bit8 when using adapter' main(args)