Add VideoChat-Flash and InternVideo2.5 #563

Open
wants to merge 3 commits into main
2 changes: 2 additions & 0 deletions lmms_eval/models/__init__.py
@@ -63,6 +63,8 @@
"xcomposer2_4KHD": "XComposer2_4KHD",
"xcomposer2d5": "XComposer2D5",
"egogpt": "EgoGPT",
"internvideo2_5": "InternVideo2_5",
"videochat_flash": "VideoChat_Flash"
}


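With these two registry entries in place, the new backends can be selected by name from the lmms-eval launcher. A minimal invocation sketch (the flags follow the existing lmms-eval CLI, the checkpoint is the default from the model file below, and the task name is only illustrative):

python3 -m lmms_eval --model internvideo2_5 --model_args pretrained=OpenGVLab/InternVideo2_5_Chat_8B --tasks videomme --batch_size 1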
304 changes: 304 additions & 0 deletions lmms_eval/models/internvideo2_5.py
@@ -0,0 +1,304 @@
import logging
from datetime import timedelta
from typing import List, Tuple

import numpy as np
import torch
import torchvision.transforms as T
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from accelerate.utils import InitProcessGroupKwargs
from llava.video_utils import VIDEO_READER_FUNCS
from petrel_client.client import Client
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

client = Client('~/petreloss.conf')
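# `client` is handed to VIDEO_READER_FUNCS below so that videos can be read either from
# local paths or from the petrel/Ceph object store configured in ~/petreloss.conf.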

eval_logger = logging.getLogger("eval_logger")

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

DEFAULT_GEN_KWARGS = dict(
num_beams=1,
max_new_tokens=1024,
do_sample=False,
)
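# Greedy decoding by default. generate_until() fills in any of these keys that a task does
# not override and drops generation kwargs outside this set before calling model.chat().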


def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose(
    [
        T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD),
    ]
)
return transform
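# Example (sketch): build_transform(448)(img) maps a PIL image of any size and mode to a
# float tensor of shape (3, 448, 448), normalized with the ImageNet mean and std above.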


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
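# Worked example (illustrative): a 1920x1080 image has aspect ratio ~1.78; with max_num=6
# the admissible grids include (1, 1), (2, 1), (3, 2), ..., and the closest one is (2, 1),
# i.e. two image_size x image_size tiles side by side.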


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height

# calculate the existing image aspect ratio
target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)

# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
cols = target_width // image_size
for i in range(blocks):
box = (
    (i % cols) * image_size,
    (i // cols) * image_size,
    ((i % cols) + 1) * image_size,
    ((i // cols) + 1) * image_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
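# Example (sketch): dynamic_preprocess(img, image_size=448, max_num=6, use_thumbnail=True)
# on a 1920x1080 input resizes it to 896x448, splits it into two 448x448 tiles, and appends
# a 448x448 thumbnail of the full frame, returning three PIL images.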


def load_image(image, input_size=448, max_num=6):
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
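# Example (sketch): for the 1920x1080 input above, load_image returns a tensor of shape
# (3, 3, 448, 448) -- one entry per tile plus the thumbnail.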


def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
if bound:
start, end = bound[0], bound[1]
else:
start, end = -100000, 100000
start_idx = max(first_idx, round(start * fps))
end_idx = min(round(end * fps), max_frame)
seg_size = float(end_idx - start_idx) / num_segments
frame_indices = np.array([int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
return frame_indices
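# Example (sketch): get_index(None, fps=1, max_frame=100, num_segments=4) splits [0, 100]
# into four equal segments of width 25 and returns their midpoints: array([12, 37, 62, 87]).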


def load_video(video_path, max_frames_num, media_dict, input_size=448, max_num=1):
if not isinstance(video_path, str):
assert len(video_path) == 1, video_path
video_path = video_path[0]

if 'start' in media_dict:
clip = [media_dict['start'], media_dict['end']]
else:
clip = None
frames, frame_indices, fps, duration = VIDEO_READER_FUNCS[media_dict['video_read_type']](
    video_path=video_path,
    num_frames=max_frames_num,
    sample='middle',
    fix_start=None,
    min_num_frames=1,
    max_num_frames=-1,
    client=client,
    clip=clip,
    local_num_frames=-1,
)

sec = [str(round(f / fps, 1)) for f in frame_indices]

msg = f"\nThe video lasts for {duration:.2f} seconds, and {len(sec)} frames are uniformly sampled from it. "

pixel_values_list, num_patches_list = [], []
transform = build_transform(input_size=input_size)

for frame in frames:
img = [Image.fromarray(frame, mode='RGB')]
pixel_values = [transform(tile) for tile in img]
pixel_values = torch.stack(pixel_values)
num_patches_list.append(pixel_values.shape[0])
pixel_values_list.append(pixel_values)
pixel_values = torch.cat(pixel_values_list)
return pixel_values, num_patches_list
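# Note: each sampled frame is transformed into a single 448x448 tile (no dynamic tiling for
# video), so pixel_values has shape (num_frames, 3, 448, 448) and num_patches_list is
# [1] * num_frames.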



@register_model("internvideo2_5")
class InternVideo2_5(lmms):
def __init__(
self,
pretrained: str = "OpenGVLab/InternVideo2_5_Chat_8B",
modality: str = "video",
device: str = "cuda:0",
device_map: str = "cuda:0",
batch_size: str = "1",
max_frames_num: int = 32,
**kwargs,
):
super().__init__()
self.max_frames_num = max_frames_num
self.path = pretrained
self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True).eval().cuda()
self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
self._config = self._model.config

batch_size = int(batch_size)
assert batch_size == 1, f"Batch size should be 1 for InternVideo2.5, but got {batch_size}."
self.batch_size_per_gpu = batch_size

accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
self.accelerator = accelerator
if accelerator.num_processes > 1:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"
elif accelerator.num_processes == 1 and device_map == "auto":
self._device = torch.device(device)
self.device_map = device_map
else:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"

if accelerator.num_processes > 1:
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
# If you want to use DistributedType.DEEPSPEED, you have to run `accelerate config` before using the model.
# You also have to select ZeRO stage 0 (equivalent to DDP) so that accelerator.prepare() works;
# passing kwargs to make the default ZeRO stage 2 work did not succeed.
if accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs = {
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
"train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
}
AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")

if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
elif accelerator.num_processes == 1 and device_map == "auto":
eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
self._rank = 0
self._world_size = 1
else:
eval_logger.info(f"Using single device: {self._device}")
self.model.to(self._device)
self._rank = 0
self._world_size = 1

self.modality = modality

@property
def config(self):
# return the associated transformers.AutoConfig for the given pretrained model.
return self._config

@property
def tokenizer(self):
return self._tokenizer

@property
def model(self):
# returns the model, unwrapping it if using Accelerate
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self._model)
else:
return self._model

@property
def batch_size(self):
return self.batch_size_per_gpu

@property
def device(self):
return self._device

@property
def rank(self):
return self._rank

@property
def world_size(self):
return self._world_size

def flatten(self, input):
new_list = []
for i in input:
for j in i:
new_list.append(j)
return new_list

def generate_until(self, requests) -> List[str]:
res = []
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
if "until" in gen_kwargs:
gen_kwargs.pop("until")
for k, v in DEFAULT_GEN_KWARGS.items():
if k not in gen_kwargs:
gen_kwargs[k] = v

pop_keys = []
for k, v in gen_kwargs.items():
if k not in DEFAULT_GEN_KWARGS:
pop_keys.append(k)

for k in pop_keys:
gen_kwargs.pop(k)

visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
visuals = self.flatten(visuals)
if self.modality == "image":
if visuals:
visuals = [load_image(visual).to(torch.bfloat16).cuda() for visual in visuals]
pixel_values = torch.cat(visuals, dim=0)
num_patches_list = [visual.size(0) for visual in visuals]
image_tokens = ["<image>"] * len(visuals)
image_tokens = " ".join(image_tokens)
contexts = image_tokens + "\n" + contexts
else:
pixel_values = None
num_patches_list = None
response, history = self.model.chat(self.tokenizer, pixel_values, contexts, gen_kwargs, num_patches_list=num_patches_list, history=None, return_history=True)
elif self.modality == "video":
# assert len(visuals) == 1, f"Only one video is supported, but got {len(visuals)} videos. {visuals}"
video_path = visuals[0]
if len(visuals) > 1:
assert len(visuals) == 2, visuals
media_dict = visuals[1]
else:
media_dict = {'video_read_type': 'decord'}

pixel_values, num_patches_list = load_video(video_path, max_frames_num=self.max_frames_num, max_num=1, media_dict=media_dict)

pixel_values = pixel_values.to(torch.bfloat16).cuda()
video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
question = video_prefix + contexts
response, history = self.model.chat(self.tokenizer, pixel_values, question, gen_kwargs, num_patches_list=num_patches_list, history=None, return_history=True)
res.append(response)
pbar.update(1)
pbar.close()
return res

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
raise NotImplementedError("Loglikelihood is not implemented for InternVideo2_5.")