get_inference_speeds.py
import time
import warnings
from pathlib import Path

import pandas as pd
import torch
import transformers
from transformers import AutoConfig, AutoModelForSequenceClassification

warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()  # suppress the BigBird warning for seq_len < 256, where it falls back to full attention


def get_inference_speed(model_name, seq_len=256, bs=1, n_trials=1, device="cpu"):
    """Return mean seconds per forward pass for a randomly initialized model.

    The model is built from its config alone (no pretrained weights), so this
    measures architecture latency, not the speed of a trained checkpoint.
    """
    config = AutoConfig.from_pretrained(model_name, num_labels=4)
    if "gpt" in model_name:  # we don't use the tokenizer, but this prevents errors
        config.pad_token_id = config.eos_token_id
    model = AutoModelForSequenceClassification.from_config(config)
    model.eval()
    model.to(device)
    # random token ids are fine for timing; place them on the same device as the model
    input_ids = torch.randint(0, 10000, (bs, seq_len), device=device)
    with torch.no_grad():
        start = time.time()
        for _ in range(n_trials):
            model(input_ids)
        end = time.time()
    return (end - start) / n_trials
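
# Usage sketch (not part of the original script; the model name and settings
# below are illustrative): time forward passes of DistilRoBERTa on CPU.
# speed = get_inference_speed("distilroberta-base", seq_len=128, bs=4, n_trials=3)
# print(f"{speed:.4f} s per batch")
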

def gather_all_speeds(
    models,
    batch_sizes=(1, 2, 4, 8, 16, 32, 64, 128),
    seq_lens=(32, 64, 128, 256, 512),
    n_trials=1,
    device="cpu",
    threads=24,
):
    """Time every (batch size, sequence length, model) combination and return a long-form DataFrame."""
    torch.set_num_interop_threads(threads)  # inter-op parallelism
    torch.set_num_threads(threads)  # intra-op parallelism
    results = []
    for bs in batch_sizes:
        for seq_len in seq_lens:
            for model in models:
                speed = get_inference_speed(model, seq_len, bs, n_trials, device)
                results.append({"model": model, "bs": bs, "seq_len": seq_len, "speed": speed})
    return pd.DataFrame(results, columns=["model", "bs", "seq_len", "speed"])
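
# Sketch of a convenience helper (an assumption, not in the original script):
# pivot the long-form results into a model x seq_len table of latencies for a
# single batch size, which is handy for eyeballing relative speeds.
def pivot_speeds(results, bs=1):
    subset = results[results["bs"] == bs]
    return subset.pivot(index="model", columns="seq_len", values="speed")
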
models = [
"albert-base-v2",
"bert-base-uncased",
"distilgpt2",
"distilroberta-base",
"EleutherAI/gpt-neo-125M",
"facebook/muppet-roberta-base",
"funnel-transformer/medium-base",
"funnel-transformer/small-base",
"google/bigbird-roberta-base",
"google/electra-base-discriminator",
"google/mobilebert-uncased",
"gpt2",
"microsoft/deberta-v3-base",
"roberta-base",
"squeezebert/squeezebert-uncased",
]
device = "cpu"
threads = 24
results = gather_all_speeds(
models,
batch_sizes=[1, 2, 4, 8, 16, 32],
seq_lens=[16, 32, 64, 128, 256, 512],
n_trials=5,
device=device,
threads=threads,
)
Path("results").mkdir(exist_ok=True)  # ensure the output directory exists before writing
results.to_csv(f"results/inference_speed_{device}_{threads}_while_train.csv")
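
# Follow-up sketch (not part of the original script): summarize the
# measurements still in memory, printing mean latency per model across all
# (bs, seq_len) settings, slowest last.
summary = results.groupby("model")["speed"].mean().sort_values()
print(summary)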