fix: draft commits for flamingo

numb3r3 committed Apr 21, 2023
1 parent f57a8e3 commit 63a6647
Showing 5 changed files with 62 additions and 38 deletions.
4 changes: 4 additions & 0 deletions open_gpts/models/flamingo/flamingo_lm.py
@@ -76,6 +76,7 @@ def init_flamingo(
vis_hidden_size,
cross_attn_every_n_layers,
use_media_placement_augmentation,
dtype=None,
):
"""
Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
@@ -105,6 +106,9 @@ def init_flamingo(
self.use_media_placement_augmentation = use_media_placement_augmentation
self.initialized_flamingo = True

if dtype is not None and str(dtype) == 'torch.float16':
self.gated_cross_attn_layers.half()

def forward(self, *input, **kwargs):
"""Condition the Flamingo layers on the media locations before forward()"""

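Note on this hunk: init_flamingo now takes an optional dtype and casts only the newly added gated cross-attention layers to half precision. A minimal, self-contained sketch of that pattern; the DecoderWithFlamingo class, layer count, and dimensions below are hypothetical stand-ins, not the repository's code:

```python
import torch
from torch import nn


class DecoderWithFlamingo(nn.Module):
    """Toy stand-in for the decoder that init_flamingo extends (hypothetical)."""

    def init_flamingo(self, num_layers: int = 2, dim: int = 64, dtype=None):
        # The commit threads an optional dtype through init_flamingo and
        # half()s only the newly added gated cross-attention layers.
        self.gated_cross_attn_layers = nn.ModuleList(
            [nn.Linear(dim, dim) for _ in range(num_layers)]
        )
        if dtype is not None and str(dtype) == 'torch.float16':
            self.gated_cross_attn_layers.half()


model = DecoderWithFlamingo()
model.init_flamingo(dtype='torch.float16')
print(next(model.gated_cross_attn_layers.parameters()).dtype)  # torch.float16
```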
33 changes: 17 additions & 16 deletions open_gpts/models/flamingo/loading.py
@@ -36,15 +36,15 @@ def load_model_and_transforms(
AlignDevicesHook,
add_hook_to_module,
attach_align_device_hook_on_blocks,
remove_hook_from_module,
)

# load the vision model
model_name, *pretrained = clip_model_name.split("::")
pretrained = pretrained[0] if len(pretrained) == 1 else 'openai'
clip_model, _, image_processor = open_clip.create_model_and_transforms(
model_name, pretrained=pretrained, device=device, precision='fp16'
model_name, pretrained=pretrained, device='cuda', precision='fp16'
)
clip_model.to('cuda')
# set the vision encoder to output the visual features
clip_model.visual.output_tokens = True

@@ -54,17 +54,17 @@ def load_model_and_transforms(
elif hasattr(clip_model, 'transformer'):
del clip_model.transformer

execution_device = next(iter(clip_model.parameters())).device
add_hook_to_module(clip_model, AlignDevicesHook(io_same_device=True), append=True)

attach_align_device_hook_on_blocks(
clip_model,
execution_device=execution_device,
offload=None,
offload_buffers=False,
weights_map=None,
preload_module_classes=None,
)
# execution_device = next(iter(clip_model.parameters())).device
# add_hook_to_module(clip_model, AlignDevicesHook(io_same_device=True), append=True)
#
# attach_align_device_hook_on_blocks(
# clip_model,
# execution_device=execution_device,
# offload=None,
# offload_buffers=False,
# weights_map=None,
# preload_module_classes=None,
# )

# load the language model
lang_model, tokenizer = load_model_and_tokenizer(
@@ -97,7 +97,7 @@ def load_model_and_transforms(

flamingo_config = {
"image_size": open_clip.get_model_config(model_name)["vision_cfg"]["width"],
"cross_attn_every_n_layers": 1,
"cross_attn_every_n_layers": 4,
"end_chunk_token_id": tokenizer.encode("<|endofchunk|>")[-1],
"media_token_id": tokenizer.encode("<image>")[-1],
}
@@ -119,6 +119,7 @@ def load_model_and_transforms(
lang_model,
model_config=flamingo_config,
device=device,
dtype='torch.float16',
)

# Freeze all parameters
@@ -139,8 +140,8 @@ def load_model_and_transforms(
import torch
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(model_name_or_path, "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path), strict=False)
# checkpoint_path = hf_hub_download(model_name_or_path, "checkpoint.pt")
# model.load_state_dict(torch.load(checkpoint_path), strict=False)

return model, tokenizer, image_processor

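For context, this file now loads the CLIP vision tower directly on CUDA in fp16 and comments out the accelerate device-alignment hooks and the Hub checkpoint download. A minimal sketch of the new loading path, assuming open_clip is installed and a CUDA device is available; the model and pretrained tags below are placeholders, not values from the repository:

```python
import open_clip

# Sketch of the fp16 vision-encoder loading path (model/pretrained tags are
# placeholders; the commit hard-codes device='cuda' and precision='fp16').
model_name, pretrained = "ViT-L-14", "openai"
clip_model, _, image_processor = open_clip.create_model_and_transforms(
    model_name, pretrained=pretrained, device="cuda", precision="fp16"
)

# Expose patch tokens instead of only the pooled embedding, as in the diff.
clip_model.visual.output_tokens = True

# The text tower is unused for the Flamingo-style conditioning, so drop it.
if hasattr(clip_model, "transformer"):
    del clip_model.transformer
```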
27 changes: 22 additions & 5 deletions open_gpts/models/flamingo/modeling.py
@@ -13,6 +13,7 @@ def __init__(
language_model: 'nn.Module',
model_config: dict = {},
device: Optional[Union[str, 'torch.device']] = 'cpu',
dtype: Optional[Union[str, 'torch.dtype']] = 'torch.float32',
**kwargs,
):
"""An open source version of DeepMind's Flamingo model!
@@ -22,6 +23,7 @@ def __init__(
:param language_model: the language model to extract textual features and generate the output texts, e.g., LLaMa model
:param model_config: a dictionary of model configuration
:param device: the device to run the model on
:param dtype: the data type to run the model on
:param kwargs: other arguments
"""
super().__init__()
@@ -32,6 +34,9 @@ def __init__(
self.language_model = language_model

self.perceiver = PerceiverResampler(dim=self.model_config['image_size'])
if str(dtype) == 'torch.float16':
self.perceiver.half()

self.perceiver.to(device)

self.media_token_id = model_config['media_token_id']
@@ -42,6 +47,7 @@ def __init__(
vis_hidden_size=model_config['image_size'],
cross_attn_every_n_layers=model_config['cross_attn_every_n_layers'],
use_media_placement_augmentation=False,
dtype=dtype,
)
self.language_model.gated_cross_attn_layers.to(device)

@@ -117,9 +123,18 @@ def generate(
:return: text_inputs with generated tokens appended to it (batch_size, sequence_length)
"""

vision_inputs = vision_inputs.to(dtype=torch.float16)
vision_inputs = vision_inputs.cuda()
text_inputs = text_inputs.cuda()
if attention_mask is not None:
attention_mask = attention_mask.cuda()

if num_beams > 1:
vision_inputs = vision_inputs.repeat_interleave(num_beams, dim=0)

print(f'===> vision inputs device: {vision_inputs.device}')

vision_x = self._vision_encode(vision_inputs=vision_inputs)

print(f'====> encode vision done {vision_x.device}...')
@@ -137,7 +152,8 @@ def generate(
layer.condition_attend_previous(attend_previous)

print(f'===> start generation ...')

print(f'===> text device: {text_inputs.device}')
print(f'===> attention_mask device: {attention_mask.device}')
output = self.language_model.generate(
text_inputs,
attention_mask=attention_mask,
@@ -176,8 +192,9 @@ def _vision_encode(self, vision_inputs: 'torch.Tensor') -> 'torch.Tensor':
B, T, F = vision_inputs.shape[:3]
assert F == 1, "Only single frame supported"

device = next(iter(self.vision_encoder.parameters())).device
vision_inputs = vision_inputs.to(device)
# device = next(iter(self.vision_encoder.parameters())).device
# print(f'===> encoder device: {device}')
vision_inputs = vision_inputs.to('cuda')

vision_x = rearrange(vision_inputs, "B T F c h w -> (B T F) c h w")

@@ -188,9 +205,9 @@ def _vision_encode(self, vision_inputs: 'torch.Tensor') -> 'torch.Tensor':

vision_x = self.perceiver(vision_x) # reshapes to (B, T, n, d)

device = next(iter(self.language_model.parameters())).device
# device = next(iter(self.language_model.parameters())).device

vision_x = vision_x.to(device)
vision_x = vision_x.to('cuda')

for layer in self.language_model._get_decoder_layers():
layer.condition_vis_x(vision_x)
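The generate() path now casts the vision inputs to fp16, moves all inputs onto the GPU, and duplicates the vision batch once per beam before encoding. A small sketch of that input preparation, assuming a CUDA device is available; the function name and tensor shapes are illustrative and not part of the repository:

```python
import torch


def prepare_generation_inputs(vision_inputs, text_inputs, attention_mask=None, num_beams=1):
    """Mirror of the device/dtype handling added to generate() (sketch only)."""
    # The vision features flow through an fp16 CLIP tower, so cast before encoding.
    vision_inputs = vision_inputs.to(dtype=torch.float16).cuda()
    text_inputs = text_inputs.cuda()
    if attention_mask is not None:
        attention_mask = attention_mask.cuda()
    # Beam search expands the text batch internally; the vision batch has to be
    # expanded up front so cross-attention sees one copy per beam.
    if num_beams > 1:
        vision_inputs = vision_inputs.repeat_interleave(num_beams, dim=0)
    return vision_inputs, text_inputs, attention_mask


# Illustrative shapes: (batch, time, frames, channels, height, width) for vision.
vision = torch.randn(1, 1, 1, 3, 224, 224)
text = torch.ones(1, 8, dtype=torch.long)
vision, text, _ = prepare_generation_inputs(vision, text, num_beams=3)
print(vision.shape)  # torch.Size([3, 1, 1, 3, 224, 224])
```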
34 changes: 18 additions & 16 deletions open_gpts/models/llama/loading.py
@@ -1,12 +1,9 @@
from typing import TYPE_CHECKING, Union

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

if TYPE_CHECKING:
import torch

from loguru import logger
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


def load_model_and_tokenizer(
@@ -20,18 +17,23 @@ def load_model_and_tokenizer(
tokenizer_name_or_path, local_files_only=True
)

# Create a model and initialize it with empty weights
config = AutoConfig.from_pretrained(model_name_or_path, local_files_only=True)

with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
# # Create a model and initialize it with empty weights
# config = AutoConfig.from_pretrained(model_name_or_path, local_files_only=True)
#
# with init_empty_weights():
# model = AutoModelForCausalLM.from_config(config)
#
# # Load the checkpoint and dispatch it to the right devices
# model = load_checkpoint_and_dispatch(
# model, model_name_or_path, device_map="auto", dtype=dtype, **kwargs
# )

# Load the checkpoint and dispatch it to the right devices
model = load_checkpoint_and_dispatch(
model, model_name_or_path, device_map="auto", dtype=dtype, **kwargs
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
torch_dtype=torch.float16,
# device_map="auto",
local_files_only=False,
)
model.to(torch.device('cuda:0'))

# model = AutoModelForCausalLM.from_pretrained(
# model_name_or_path, local_files_only=False
# )
return model, tokenizer
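This hunk replaces the meta-device path (init_empty_weights plus load_checkpoint_and_dispatch) with a plain from_pretrained call in fp16 followed by an explicit move to a single GPU. A minimal sketch of the new path, assuming transformers is installed and a CUDA device is available; the model path below is a hypothetical placeholder:

```python
import torch
from transformers import AutoModelForCausalLM

# Sketch of the switched-to loading path: fp16 from_pretrained, then an
# explicit move to a single GPU. The model path is a hypothetical placeholder.
model_name_or_path = "path/to/llama-checkpoint"  # placeholder
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.float16,
    local_files_only=False,
)
model.to(torch.device("cuda:0"))
```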
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -36,7 +36,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry.dependencies]
# Compatible Python versions
python = ">=3.8"
torch = ">=1.9.0,<2.0.0" # a meta device requires torch >= 1.9.0
# torch = "^1.9,<2.0.0" # a meta device requires torch >= 1.9.0
loguru = "^0.5"
click = "^8.1.3"
numpy = "^1.21.2"
