Merge pull request #65 from jina-ai/fix-benchmark
fix: fix benchmark
jemmyshin authored Jul 19, 2023
2 parents 3125db8 + 7dbfff5 commit b6caad1
Showing 2 changed files with 10 additions and 12 deletions.
13 changes: 8 additions & 5 deletions open_gpt/profile.py
@@ -117,11 +117,13 @@ def end_measure(start_measures):
     # GPU mem
     for i in range(torch.cuda.device_count()):
         measures[str(i)] = (
-            torch.cuda.memory_allocated(i) - start_measures[str(i)]
-        ) / GB
+            torch.cuda.memory_allocated(i) - start_measures[
+                str(i)]
+        ) / GB
         measures[f"{i}-peak"] = (
-            torch.cuda.max_memory_allocated(i) - start_measures[str(i)]
-        ) / GB
+            torch.cuda.max_memory_allocated(i) -
+            start_measures[str(i)]
+        ) / GB
 
     return measures
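For context, end_measure reports per-GPU memory as a delta against a baseline recorded before the run. A minimal sketch of that pattern, assuming GB = 1024**3 (as the division suggests); start_measure here is a hypothetical stand-in for the module's own baseline helper:

import torch

GB = 1024**3  # assumption: the bytes-to-gigabytes divisor used in profile.py

def start_measure():
    # Hypothetical baseline helper: snapshot allocated memory per GPU.
    torch.cuda.empty_cache()
    for i in range(torch.cuda.device_count()):
        torch.cuda.reset_peak_memory_stats(i)
    return {str(i): torch.cuda.memory_allocated(i)
            for i in range(torch.cuda.device_count())}

start_measures = start_measure()
# ... run the model under test ...
for i in range(torch.cuda.device_count()):
    used = (torch.cuda.memory_allocated(i) - start_measures[str(i)]) / GB
    peak = (torch.cuda.max_memory_allocated(i) - start_measures[str(i)]) / GB
    print(f'GPU {i}: {used:.2f} GB allocated, {peak:.2f} GB peak')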

@@ -165,7 +167,8 @@ def end_record(self, generation_outputs: Union[str, List[str]]):
             )
         else:
             num_tokens = sum(
-                list(map(lambda x: len(self._tokenizer(x)) - 2, generation_outputs))
+                list(map(lambda x: len(self._tokenizer(x)['input_ids']) - 2,
+                         generation_outputs))
             )
         self._generation_length.append(num_tokens)
         self._time_stamp = None
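This is the substantive fix in profile.py: a Hugging Face tokenizer call returns a dict-like BatchEncoding, so len(self._tokenizer(x)) counts its keys (input_ids, attention_mask, ...), not tokens. Indexing ['input_ids'] yields the token sequence itself; the "- 2" presumably discounts the special tokens the tokenizer adds. A minimal sketch (model name illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
encoded = tokenizer('hello world')

print(len(encoded))               # number of keys in the BatchEncoding (e.g. 3)
print(len(encoded['input_ids']))  # token count: 4, i.e. [CLS] hello world [SEP]

Before the fix, the benchmark was effectively recording the number of dict keys per output instead of its token count.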
9 changes: 2 additions & 7 deletions scripts/benchmark.py
@@ -10,7 +10,7 @@ def run_benchmark(model, max_new_tokens, llm_measure):
     generated_text = model.generate(
         PROMPT, max_new_tokens=max_new_tokens, do_sample=args.do_sample
     )
-    llm_measure.end_record(generated_text)
+    llm_measure.end_record(generated_text['choices'][0]['text'])
     llm_measure.stats(stage='prefill' if max_new_tokens == 1 else 'decoding')
     llm_measure.clear()
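The unwrapping here implies that model.generate returns an OpenAI-style completion dict rather than a plain string. A sketch of the implied shape (field values illustrative; this is an assumption read off the fix, not documented API):

generated_text = {
    'choices': [
        {'text': 'Once upon a time ...'},
    ],
}
text_only = generated_text['choices'][0]['text']  # what end_record() expects

Passing the whole dict, as the benchmark did before, would have fed the wrong object into the tokenizer-based length accounting in profile.py.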

@@ -71,7 +71,7 @@ def main(args):
         '--precision', type=str, default='fp16', help='precision used for inference'
     )
     parser.add_argument(
-        '--repeat-time', type=int, default=10, help='repeat time for benchmark'
+        '--repeat-time', type=int, default=100, help='repeat time for benchmark'
     )
     parser.add_argument(
         '--do-sample',
@@ -90,10 +90,5 @@ def main(args):
         'bit4',
         'bit8',
     ], 'precision must be fp16 or bit4 or bit8'
-    if args.adapter_name is not None:
-        assert args.precision in [
-            'bit4',
-            'bit8',
-        ], 'precision must be bit4 or bit8 when using adapter'
 
     main(args)
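With the adapter-specific check removed, all three precisions now pass validation whether or not an adapter is set. A typical invocation exercising the flags touched by this commit (any other required arguments omitted):

python scripts/benchmark.py --precision fp16 --repeat-time 100 --do-sample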
