We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d7d54bc · commit 25db060 — Copy full SHA for 25db060
demo_trt_llm/build_visual_engine.py
@@ -240,10 +240,10 @@ def build_vila_engine(args):
240
vision_tower = model.get_vision_tower()
241
image_processor = vision_tower.image_processor
242
raw_image = Image.new('RGB', [10, 10]) # dummy image
243
- image = image_processor(images=raw_image,
244
- return_tensors="pt")['pixel_values'].to(
245
- args.device, torch.float16)
246
-
+ image = image_processor(images=raw_image,return_tensors="pt")['pixel_values']
+ if isinstance(image, list):
+ image = image[0].unsqueeze(0)
+ image = image.to(args.device, torch.float16)
247
class VilaVisionWrapper(torch.nn.Module):
248
249
def __init__(self, tower, projector):
0 commit comments