Skip to content

Commit 25db060

Browse files
author
weimingc
committed
fix
1 parent d7d54bc commit 25db060

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

demo_trt_llm/build_visual_engine.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,10 @@ def build_vila_engine(args):
240240
vision_tower = model.get_vision_tower()
241241
image_processor = vision_tower.image_processor
242242
raw_image = Image.new('RGB', [10, 10]) # dummy image
243-
image = image_processor(images=raw_image,
244-
return_tensors="pt")['pixel_values'].to(
245-
args.device, torch.float16)
246-
243+
image = image_processor(images=raw_image,return_tensors="pt")['pixel_values']
244+
if isinstance(image, list):
245+
image = image[0].unsqueeze(0)
246+
image = image.to(args.device, torch.float16)
247247
class VilaVisionWrapper(torch.nn.Module):
248248

249249
def __init__(self, tower, projector):

0 commit comments

Comments
 (0)