forked from huggingface/optimum-habana
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_video_mae.py
135 lines (111 loc) · 4.44 KB
/
test_video_mae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from unittest import TestCase
import habana_frameworks.torch as ht
import numpy as np
import pytest
import torch
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
# Latency baseline (ms per iteration) for the bf16 HPU-graph latency test.
# GAUDI2_CI="1" selects the Gaudi2 CI baseline; any other value, or an unset
# variable, selects the Gaudi1 CI baseline.
LATENCY_VIDEOMAE_BF16_GRAPH_BASELINE = (
    17.544198036193848  # Gaudi2 CI baseline
    if os.environ.get("GAUDI2_CI", "0") == "1"
    else 61.953186988830566  # Gaudi1 CI baseline
)

# Hub checkpoint used for every model and processor loaded in this module.
MODEL_NAME = "MCG-NJU/videomae-base-finetuned-kinetics"
@pytest.fixture(scope="module")
def frame_buf():
    """
    Build a deterministic random video clip, shared by the whole module.

    Returns a list of 16 float arrays drawn uniformly from [0, 1) with seed 123.
    Each array has shape (3, 224, 224), presumably channels-first (TODO confirm
    against the processor's expected layout).
    """
    rng = np.random.default_rng(123)
    clip = rng.random((16, 3, 224, 224))
    # Split the leading axis into one array per frame.
    return [frame for frame in clip]
@pytest.fixture(scope="module")
def processor():
    """Load the VideoMAE image processor for MODEL_NAME once per test module."""
    image_processor = VideoMAEImageProcessor.from_pretrained(MODEL_NAME)
    return image_processor
@pytest.fixture(autouse=True, scope="class")
def inputs(request, frame_buf, processor):
    """
    Preprocess the random clip and attach the results to the test class.

    Sets on ``request.cls``:
      - ``inputs``: the processor output for ``frame_buf`` (``return_tensors="pt"``).
      - ``inputs_hpu``: a copy of that output moved to the "hpu" device.
    """
    pixel_batch = processor(frame_buf, return_tensors="pt")
    test_cls = request.cls
    test_cls.inputs = pixel_batch
    # The HPU batch is built from a copy. Presumably ``.to()`` relocates the
    # tensors of the object it is called on, so copying first keeps ``inputs``
    # on the CPU for the reference run (TODO confirm against BatchFeature.to).
    test_cls.inputs_hpu = pixel_batch.copy().to("hpu")
@pytest.fixture(autouse=True, scope="class")
def outputs_cpu(request, inputs):
    """
    Compute the CPU reference outputs and store them on the test class.

    Loads a fresh VideoMAE classifier on the CPU, runs it in eval mode without
    gradients on ``request.cls.inputs``, and sets ``request.cls.outputs_cpu``.

    ``inputs`` is requested explicitly (its fixture value is unused) so that
    ``request.cls.inputs`` is guaranteed to be populated before it is read here.
    Previously this depended on the implicit ordering of same-scope autouse
    fixtures.
    """
    model = VideoMAEForVideoClassification.from_pretrained(MODEL_NAME)
    model.eval()
    with torch.no_grad():
        output = model(**request.cls.inputs)
    request.cls.outputs_cpu = output
@pytest.fixture(autouse=True, scope="class")
def model_hpu(request):
    """
    Load the VideoMAE classifier onto the HPU and attach it to the test class.

    Sets on ``request.cls``:
      - ``model_hpu``: the classifier moved to the "hpu" device.
      - ``model_hpu_graph``: the result of ``ht.hpu.wrap_in_hpu_graph`` on it.
    """
    hpu_model = VideoMAEForVideoClassification.from_pretrained(MODEL_NAME).to("hpu")
    test_cls = request.cls
    test_cls.model_hpu = hpu_model
    # NOTE(review): wrap_in_hpu_graph may wrap the module in place and return the
    # same object. In that case ``model_hpu`` and ``model_hpu_graph`` alias each
    # other, and the non-graph tests also go through HPU graphs. Confirm against
    # the habana_frameworks documentation.
    test_cls.model_hpu_graph = ht.hpu.wrap_in_hpu_graph(hpu_model)
@pytest.fixture(autouse=True, scope="class")
def outputs_hpu_default(request, inputs, model_hpu):
    """
    Run the HPU model once, without autocast, and store the result on the test class.

    Sets ``request.cls.outputs_hpu_default``. The CPU/HPU equivalence test and the
    bf16 tests compare against it.

    ``inputs`` and ``model_hpu`` are requested explicitly (their fixture values are
    unused) so that ``request.cls.inputs_hpu`` and ``request.cls.model_hpu`` are
    guaranteed to exist before this runs. Previously this depended on the implicit
    ordering of same-scope autouse fixtures.
    """
    with torch.no_grad():
        output = request.cls.model_hpu(**request.cls.inputs_hpu)
    request.cls.outputs_hpu_default = output
class GaudiVideoMAETester(TestCase):
    """
    Tests for VideoMAE on Gaudi.

    Relies on this module's class-scoped autouse fixtures to set ``inputs_hpu``,
    ``outputs_cpu``, ``outputs_hpu_default``, ``model_hpu`` and ``model_hpu_graph``
    as class attributes before any test runs.
    """

    def _assert_bf16_top5_matches_default(self, model):
        """
        Run ``model`` on the HPU inputs under bf16 autocast and assert that its
        top-5 predicted class indices equal those of the default HPU run.
        """
        with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16):
            outputs = model(**self.inputs_hpu)
        self.assertTrue(
            torch.equal(
                self.outputs_hpu_default.logits.topk(5).indices,
                outputs.logits.topk(5).indices,
            )
        )

    def test_inference_default(self):
        """
        Tests that CPU and HPU runs are equivalent.

        The top-10 predicted class indices must match exactly, and the logits
        must agree within an absolute tolerance of 5e-3.
        """
        # Move the HPU logits to the CPU once, so both comparisons operate on
        # tensors that live on the same device.
        hpu_logits = self.outputs_hpu_default.logits.cpu()
        self.assertTrue(
            torch.equal(
                self.outputs_cpu.logits.topk(10).indices,
                hpu_logits.topk(10).indices,
            )
        )
        self.assertTrue(torch.allclose(self.outputs_cpu.logits, hpu_logits, atol=5e-3))

    def test_inference_bf16(self):
        """
        Tests that bf16 autocast inference predicts the same top-5 classes as
        the default run.
        """
        self._assert_bf16_top5_matches_default(self.model_hpu)

    def test_inference_graph_bf16(self):
        """
        Tests that bf16 autocast inference in HPU-graph mode predicts the same
        top-5 classes as the default run.
        """
        self._assert_bf16_top5_matches_default(self.model_hpu_graph)

    def test_latency_graph_bf16(self):
        """
        Tests for performance degradations of more than 5%.

        Runs a few warm-up iterations of the HPU-graph model under bf16 autocast.
        It then times ``test_iters`` iterations and asserts that the mean time per
        iteration (ms) stays below 1.05x the CI baseline.
        """
        warm_up_iters = 5
        test_iters = 10
        with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16):
            for _ in range(warm_up_iters):
                self.model_hpu_graph(**self.inputs_hpu)
        # Synchronize so warm-up work does not spill into the timed section.
        torch.hpu.synchronize()
        # perf_counter is monotonic and high-resolution. time.time() is wall-clock
        # and can jump if the system clock is adjusted during the measurement.
        start_time = time.perf_counter()
        with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16):
            for _ in range(test_iters):
                self.model_hpu_graph(**self.inputs_hpu)
        torch.hpu.synchronize()
        time_per_iter = (time.perf_counter() - start_time) * 1000 / test_iters  # Time in ms
        self.assertLess(time_per_iter, 1.05 * LATENCY_VIDEOMAE_BF16_GRAPH_BASELINE)