
Commit 7f98923

cfgfung and regisss authored
Add an example of Segment Anything Model [Inference] (huggingface#814)
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
1 parent 2a0da6b commit 7f98923

6 files changed: +257 −4 lines changed


Makefile (+5)

@@ -41,6 +41,11 @@ fast_tests_diffusers:
 	python -m pip install .[tests]
 	python -m pytest tests/test_diffusers.py
 
+# Run unit and integration tests related to Image segmentation
+fast_tests_image_segmentation:
+	python -m pip install .[tests]
+	python -m pytest tests/test_image_segmentation.py
+
 # Run single-card non-regression tests
 slow_tests_1x: test_installs
 	python -m pytest tests/test_examples.py -v -s -k "single_card"

README.md (+1 −1)

@@ -214,7 +214,7 @@ The following model architectures, tasks and device distributions have been validated
 | OWLViT | | <div style="text-align:left"><li>Single card</li></div> | <li>[zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)</li> |
 | ClipSeg | | <div style="text-align:left"><li>Single card</li></div> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
 | Llava / Llava-next | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
-
+| Segment Anything Model | | <div style="text-align:left"><li>Single card</li></div> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
 </div>
 
 - Diffusers:

docs/source/index.mdx (+1 −1)

@@ -72,7 +72,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all been validated.
 | OWLViT | | <div style="text-align:left"><li>Single card</li></div> | <li>[zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)</li> |
 | ClipSeg | | <div style="text-align:left"><li>Single card</li></div> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
 | Llava / Llava-next | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
-
+| SAM | | <div style="text-align:left"><li>Single card</li></div> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
 
 - Diffusers

examples/object-segementation/README.md (+21 −2)

@@ -13,10 +13,12 @@ limitations under the License.
 
 # Object Segmentation Examples
 
-This directory contains an example script that demonstrates how to perform object segmentation on Gaudi with graph mode.
+This directory contains two example scripts that demonstrate how to perform object segmentation on Gaudi with graph mode.
 
 ## Single-HPU inference
 
+### ClipSeg Model
+
 ```bash
 python3 run_example.py \
     --model_name_or_path "CIDAS/clipseg-rd64-refined" \
@@ -29,4 +31,21 @@ python3 run_example.py \
     --print_result
 ```
 Models that have been validated:
-  - [clipseg-rd64-refined ](https://huggingface.co/CIDAS/clipseg-rd64-refined)
+  - [clipseg-rd64-refined ](https://huggingface.co/CIDAS/clipseg-rd64-refined)
+
+### Segment Anything Model
+
+```bash
+python3 run_example_sam.py \
+    --model_name_or_path "facebook/sam-vit-huge" \
+    --image_path "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" \
+    --point_prompt "450,600" \
+    --warmup 3 \
+    --n_iterations 20 \
+    --use_hpu_graphs \
+    --bf16 \
+    --print_result
+```
+Models that have been validated:
+  - [facebook/sam-vit-base](https://huggingface.co/facebook/sam-vit-base)
+  - [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge)
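Editorial aside (not part of the diff): the SAM example above only reports IoU scores for the predicted masks. If you also want the masks themselves, the sketch below shows one way to post-process the model outputs. It assumes the `processor`, `inputs`, and `outputs` variables created in `run_example_sam.py` and the standard `post_process_masks` API of the Transformers SAM image processor; treat it as an illustration, not as code from this commit.

```python
# Illustrative sketch, not part of this commit: recover full-resolution masks from SAM outputs.
# Assumes `processor`, `inputs` and `outputs` were produced as in run_example_sam.py.
import numpy as np
import torch
from PIL import Image

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.to(torch.float32).cpu(),  # low-resolution mask logits from the model
    inputs["original_sizes"].cpu(),              # original (height, width) of the input image
    inputs["reshaped_input_sizes"].cpu(),        # (height, width) after the processor's resizing
)

# masks[0] holds the binarized masks for the first image; keep the one with the highest IoU score.
best = outputs.iou_scores[0, 0].argmax().item()
mask = masks[0][0, best].numpy()
Image.fromarray((mask * 255).astype(np.uint8)).save("sam_mask.png")
```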
examples/object-segementation/run_example_sam.py (new file, +110)

@@ -0,0 +1,110 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copied from https://huggingface.co/facebook/sam-vit-base

import argparse
import time

import habana_frameworks.torch as ht
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name_or_path",
        default="facebook/sam-vit-huge",
        type=str,
        help="Path of the pre-trained model",
    )
    parser.add_argument(
        "--image_path",
        default="https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png",
        type=str,
        help='Path of the input image. Should be a single string (eg: --image_path "URL")',
    )
    parser.add_argument(
        "--point_prompt",
        default="450, 600",
        type=str,
        help='Prompt for segmentation. It should be a comma-separated string (eg: --point_prompt "450, 600")',
    )
    parser.add_argument(
        "--use_hpu_graphs",
        action="store_true",
        help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.",
    )
    parser.add_argument(
        "--bf16",
        action="store_true",
        help="Whether to use bf16 precision for segmentation.",
    )
    parser.add_argument(
        "--print_result",
        action="store_true",
        help="Whether to print the segmentation result.",
    )
    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.")
    parser.add_argument("--n_iterations", type=int, default=5, help="Number of inference iterations for benchmarking.")

    args = parser.parse_args()

    adapt_transformers_to_gaudi()

    processor = AutoProcessor.from_pretrained(args.model_name_or_path)
    model = AutoModel.from_pretrained(args.model_name_or_path)

    image = Image.open(requests.get(args.image_path, stream=True).raw).convert("RGB")
    # Parse the comma-separated prompt into the nested input_points format expected by the processor.
    points = []
    for text in args.point_prompt.split(","):
        points.append(int(text))
    points = [[points]]

    if args.use_hpu_graphs:
        model = ht.hpu.wrap_in_hpu_graph(model)

    autocast = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16)
    model.to("hpu")

    with torch.no_grad(), autocast:
        # Warmup iterations (not timed)
        for i in range(args.warmup):
            inputs = processor(image, input_points=points, return_tensors="pt").to("hpu")
            outputs = model(**inputs)
            torch.hpu.synchronize()

        # Timed iterations
        total_model_time = 0
        for i in range(args.n_iterations):
            inputs = processor(image, input_points=points, return_tensors="pt").to("hpu")
            model_start_time = time.time()
            outputs = model(**inputs)
            torch.hpu.synchronize()
            model_end_time = time.time()
            total_model_time = total_model_time + (model_end_time - model_start_time)

            if args.print_result:
                if i == 0:  # print the result once only
                    iou = outputs.iou_scores
                    print("iou score: " + str(iou))

    print("n_iterations: " + str(args.n_iterations))
    print("Total latency (ms): " + str(total_model_time * 1000))
    print("Average latency (ms): " + str(total_model_time * 1000 / args.n_iterations))

tests/test_image_segmentation.py (new file, +119)

@@ -0,0 +1,119 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from unittest import TestCase

import habana_frameworks.torch as ht
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


adapt_transformers_to_gaudi()

# For Gaudi 2
LATENCY_OWLVIT_BF16_GRAPH_BASELINE = 3.7109851837158203
LATENCY_SAM_BF16_GRAPH_BASELINE = 98.92215728759766


class GaudiSAMTester(TestCase):
    """
    Tests for Segment Anything Model - SAM
    """

    def prepare_model_and_processor(self):
        model = AutoModel.from_pretrained("facebook/sam-vit-huge").to("hpu")
        processor = AutoProcessor.from_pretrained("facebook/sam-vit-huge")
        model = model.eval()
        return model, processor

    def prepare_data(self):
        image = Image.open(
            requests.get(
                "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png", stream=True
            ).raw
        ).convert("RGB")
        input_points = [[[450, 600]]]
        return input_points, image

    def test_inference_default(self):
        model, processor = self.prepare_model_and_processor()
        input_points, image = self.prepare_data()
        inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
        outputs = model(**inputs)
        scores = outputs.iou_scores
        scores = scores[0][0]
        expected_scores = np.array([0.9912, 0.9818, 0.9666])
        self.assertEqual(len(scores), 3)
        self.assertLess(np.abs(scores.cpu().detach().numpy() - expected_scores).max(), 0.02)

    def test_inference_bf16(self):
        model, processor = self.prepare_model_and_processor()
        input_points, image = self.prepare_data()
        inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")

        with torch.autocast(device_type="hpu", dtype=torch.bfloat16):  # Autocast BF16
            outputs = model(**inputs)
            scores = outputs.iou_scores
            scores = scores[0][0]
            expected_scores = np.array([0.9912, 0.9818, 0.9666])
            self.assertEqual(len(scores), 3)
            self.assertLess(np.abs(scores.to(torch.float32).cpu().detach().numpy() - expected_scores).max(), 0.02)

    def test_inference_hpu_graphs(self):
        model, processor = self.prepare_model_and_processor()
        input_points, image = self.prepare_data()
        inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")

        model = ht.hpu.wrap_in_hpu_graph(model)  # Apply graph

        outputs = model(**inputs)
        scores = outputs.iou_scores
        scores = scores[0][0]
        expected_scores = np.array([0.9912, 0.9818, 0.9666])
        self.assertEqual(len(scores), 3)
        self.assertLess(np.abs(scores.to(torch.float32).cpu().detach().numpy() - expected_scores).max(), 0.02)

    def test_no_latency_regression_bf16(self):
        warmup = 3
        iterations = 10

        model, processor = self.prepare_model_and_processor()
        input_points, image = self.prepare_data()

        model = ht.hpu.wrap_in_hpu_graph(model)

        with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True):
            for i in range(warmup):
                inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
                _ = model(**inputs)
                torch.hpu.synchronize()

            total_model_time = 0
            for i in range(iterations):
                inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
                model_start_time = time.time()
                _ = model(**inputs)
                torch.hpu.synchronize()
                model_end_time = time.time()
                total_model_time = total_model_time + (model_end_time - model_start_time)

        latency = total_model_time * 1000 / iterations  # in terms of ms
        self.assertLessEqual(latency, 1.05 * LATENCY_SAM_BF16_GRAPH_BASELINE)
