
Commit

Merge branch 'refs/heads/dev'
turboderp committed Feb 7, 2025
2 parents b9c025b + 6e4a84a commit 3486f9e
Showing 15 changed files with 332 additions and 71 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/build-wheels-release.yml
@@ -131,9 +131,9 @@ jobs:
- { artname: 'wheel', os: ubuntu-20.04, pyver: '3.12', cuda: '', rocm: '6.1', torch: '2.4.0', cudaarch: '' }

# ROCm 6.2
- { artname: 'wheel', os: ubuntu-20.04, pyver: '3.10', cuda: '', rocm: '6.2', torch: '2.5.0', cudaarch: '' }
- { artname: 'wheel', os: ubuntu-20.04, pyver: '3.11', cuda: '', rocm: '6.2', torch: '2.5.0', cudaarch: '' }
- { artname: 'wheel', os: ubuntu-20.04, pyver: '3.12', cuda: '', rocm: '6.2', torch: '2.5.0', cudaarch: '' }
- { artname: 'wheel', os: ubuntu-20.04-l, pyver: '3.10', cuda: '', rocm: '6.2', torch: '2.5.0', cudaarch: '' }
- { artname: 'wheel', os: ubuntu-20.04-l, pyver: '3.11', cuda: '', rocm: '6.2', torch: '2.5.0', cudaarch: '' }
- { artname: 'wheel', os: ubuntu-20.04-l, pyver: '3.12', cuda: '', rocm: '6.2', torch: '2.5.0', cudaarch: '' }

# sdist
- { artname: 'sdist', os: ubuntu-20.04, pyver: '3.11', cuda: '', rocm: '', torch: '2.3.1', cudaarch: '' }
80 changes: 73 additions & 7 deletions examples/multimodal_grounding_qwen.py
@@ -51,7 +51,7 @@ class Model:
current_image: Image or None = None
current_description: str

def __init__(self, model_directory):
def __init__(self, model_directory, bbox_mode: str):
self.model_directory = model_directory
self.config = None
self.vision_model = None
@@ -61,17 +61,22 @@ def __init__(self, model_directory):
self.current_image = None
self.current_emb = None
self.current_description = ""
bbox_funcs = {
"qwen2": self.get_grounding_bb_qwen2,
"qwen25": self.get_grounding_bb_qwen25,
}
self.bbox_func = bbox_funcs[bbox_mode]

def load(self):
"""Load and initialize the things"""
self.config = ExLlamaV2Config(self.model_directory)
self.config.max_seq_len = 16384
self.config.max_seq_len = 8192

self.vision_model = ExLlamaV2VisionTower(self.config)
self.vision_model.load(progress = True)

self.model = ExLlamaV2(self.config)
self.cache = ExLlamaV2Cache(self.model, lazy = True, max_seq_len = 16384)
self.cache = ExLlamaV2Cache(self.model, lazy = True, max_seq_len = 32768)
self.model.load_autosplit(self.cache, progress = True)
self.tokenizer = ExLlamaV2Tokenizer(self.config)

@@ -148,14 +153,21 @@ def inference(self, settext_fn, update_fn):
lastupdate = time.time()
settext_fn(text)
update_fn()
#
# text = \
# """And you may find yourself living in a shotgun shack
# And you may find yourself in another part of the world
# And you may find yourself behind the wheel of a large automobile
# And you may find yourself in a beautiful house, with a beautiful wife
# And you may ask yourself, "Well, how did I get here?\""""

settext_fn(text)
update_fn()
self.current_description = text
print("Image description from model:")
print(text)

def get_grounding_bb(self, start, end) -> tuple:
def get_grounding_bb_qwen2(self, start, end) -> tuple:
"""
        Prompt the model again and try to extract the bounding box of the image details indicated by the selected portion
of the description. We do this by repeating the exact same prompt up to and including the selected text, but
@@ -209,6 +221,55 @@ def get_grounding_bb(self, start, end) -> tuple:

return a, b

def get_grounding_bb_qwen25(self, start, end) -> tuple:
"""
Qwen2.5 works the same way, except the coordinates are no longer normalized and the format is:
"(x0,y0,x1,y1)"
"""

if start >= end:
return None, None

# Including leading space
if start > 0 and self.current_description[start - 1] == " ":
start -= 1

# Repeat the same prompt up to the selection, with grounding tokens added
prompt = self.get_prompt()
prompt += self.current_description[:start]
prompt += "<|object_ref_start|>"
prompt += self.current_description[start:end]
prompt += "<|object_ref_end|><|box_start|>("

bb_string, res = self.generator.generate(
prompt = prompt,
add_bos = True,
max_new_tokens = 28,
stop_conditions = [self.tokenizer.single_id("<|box_end|>")],
gen_settings = ExLlamaV2Sampler.Settings.greedy(),
embeddings = [self.current_emb],
completion_only = True,
return_last_results = True, # debug purposes
)
bb_string = "(" + bb_string

print(f"Generation: {bb_string}")
pprint.pprint(res, indent = 4)

# BB string is in the format "(x0,y0,x1,y1)" with integer coordinates

s = self.current_image.size
try:
d = tuple(map(int, bb_string.strip("()").split(",")))
a = (d[0] / s[0], d[1] / s[1])
b = (d[2] / s[0], d[3] / s[1])
except:
print("No bounding box could be determined")
a, b = None, None

return a, b
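
For reference, a minimal standalone sketch (not part of this diff) of what this Qwen2.5 branch does with the generated string: parse the absolute-pixel "(x0,y0,x1,y1)" box and normalize it by the image size, so both model variants return coordinates in the 0..1 range. The image size and box values below are made-up examples.

def parse_qwen25_bbox(bb_string: str, image_size: tuple[int, int]):
    # bb_string like "(14,25,300,410)", integer coordinates in absolute pixels
    w, h = image_size
    try:
        x0, y0, x1, y1 = map(int, bb_string.strip("()").split(","))
        return (x0 / w, y0 / h), (x1 / w, y1 / h)
    except ValueError:
        # Malformed string or wrong number of values -> no bounding box
        return None, None

# parse_qwen25_bbox("(14,25,300,410)", (640, 480))
# -> ((0.021875, 0.0520833...), (0.46875, 0.8541666...))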



class GroundingDemo(QMainWindow):

@@ -472,7 +533,7 @@ def on_selection_made(self, pos):

print(f"Selected span: {start}, {end}")
print(f"Selected text: {repr(self.model.current_description[start:end])}")
a, b = self.model.get_grounding_bb(start, end)
a, b = self.model.bbox_func(start, end)
self.image_label.set_bounding_box(a, b)


@@ -481,9 +542,14 @@ def on_selection_made(self, pos):
# https://huggingface.co/turboderp/Qwen2-VL-7B-Instruct-exl2

def main():
model_dir = "/mnt/str/models/qwen2-vl-7b-instruct-exl2/6.0bpw"

# model_dir = "/mnt/str/models/qwen2-vl-7b-instruct-exl2/6.0bpw"
# bbox_mode = "qwen25"
model_dir = "/mnt/str/models/qwen2.5-vl-7b-instruct-exl2/6.0bpw"
bbox_mode = "qwen25"

app = QApplication(sys.argv)
model = Model(model_dir)
model = Model(model_dir, bbox_mode)
model.load()
window = GroundingDemo(model, model_dir)
window.show()
49 changes: 33 additions & 16 deletions exllamav2/architecture.py
@@ -356,7 +356,7 @@ class Params:

# Qwen2-VL (2, 2.5)

if arch_string == "Qwen2VLForConditionalGeneration":
if arch_string in ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]:
arch_recognized = True
self.lm.layer_keys += \
layer_keys_llama_norms + \
@@ -368,27 +368,44 @@ class Params:
self.lm.mrope = True
self.lm.rope_freq_half = True

read_config["vision_config"].update({"model_type": "qwen2"})
self.vt_prefix = "visual."
self.vt.keys.update({
"fused_qkv": ".attn.qkv",
"attn_o": ".attn.proj",
"mlp_gate": None,
"mlp_up": ".mlp.fc1",
"mlp_down": ".mlp.fc2",
"norm_1": ".norm1",
"norm_2": ".norm2",
"layers": "blocks",
"patch_conv": "patch_embed.proj",
})
self.vt.mlp_gate = False
if arch_string == "Qwen2VLForConditionalGeneration":
read_config["vision_config"].update({"model_type": "qwen2"})
self.vt.keys.update({
"fused_qkv": ".attn.qkv",
"attn_o": ".attn.proj",
"mlp_gate": None,
"mlp_up": ".mlp.fc1",
"mlp_down": ".mlp.fc2",
"norm_1": ".norm1",
"norm_2": ".norm2",
"layers": "blocks",
"patch_conv": "patch_embed.proj",
})
self.vt.mlp_gate = False
self.vt.mlp_act_func = "quickgelu"
self.vt.norm = "layernorm"
elif arch_string == "Qwen2_5_VLForConditionalGeneration":
read_config["vision_config"].update({"model_type": "qwen2.5"})
self.vt.keys.update({
"fused_qkv": ".attn.qkv",
"attn_o": ".attn.proj",
"mlp_gate": ".mlp.gate_proj",
"mlp_up": ".mlp.up_proj",
"mlp_down": ".mlp.down_proj",
"norm_1": ".norm1",
"norm_2": ".norm2",
"layers": "blocks",
"patch_conv": "patch_embed.proj",
})
self.vt.mlp_gate = True
self.vt.mlp_act_func = "silu"
self.vt.norm = "rmsnorm"
self.vt.mlp_bias = True
self.vt.attention_bias_qkv = True
self.vt.attention_bias_o = True
self.vt.vision_input_norm = False
self.vt.vision_conv3d = True
self.vt.mlp_act_func = "quickgelu"
self.vt.norm = "layernorm"

self.mmp_prefix = "visual.merger."
self.mmp.keys.update({
44 changes: 32 additions & 12 deletions exllamav2/attn.py
@@ -41,11 +41,11 @@
print(" ## Warning: Flash Attention is installed but unsupported GPUs were detected.")

if [2, 2, 1] <= flash_attn_ver < [2, 5, 7]:
from flash_attn import flash_attn_func
from flash_attn import flash_attn_func, flash_attn_varlen_func
has_flash_attn = True

if [2, 5, 7] <= flash_attn_ver:
from flash_attn import flash_attn_func, flash_attn_with_kvcache
from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache
# import flash_attn_2_cuda as flash_attn_cuda

signature = list(inspect.signature(flash_attn_func).parameters)
@@ -882,7 +882,9 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
k_states = k_states[:, :, -self.sliding_window:, :]
v_states = v_states[:, :, -self.sliding_window:, :]

if attn_params.is_causal():
if self.layer_idx in attn_params.block_diag_layers:
attn_mask_lr = attn_params.get_block_diag_mask(q_states.device)
elif attn_params.is_causal():
attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
else:
attn_mask_lr = attn_params.get_attn_mask(q_states.device)
@@ -904,7 +906,9 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
attn_weights = torch.matmul(q_states, k_states)

attn_weights *= self.scaling
if causal:
if self.layer_idx in attn_params.block_diag_layers:
attn_mask = attn_params.get_block_diag_mask(attn_weights.device)
elif causal:
attn_mask = attn_params.get_attn_mask(attn_weights.device)

if cfg.attn_logit_softcapping:
@@ -939,14 +943,30 @@ def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_para
if has_flash_attn_with_softcap:
flash_kwargs["softcap"] = cfg.attn_logit_softcapping

attn_output = flash_attn_func(
q_states,
k_states,
v_states,
causal = causal,
softmax_scale = self.scaling,
**flash_kwargs
)
if self.layer_idx in attn_params.block_diag_layers:
q_states = q_states.flatten(start_dim = 0, end_dim = 1)
k_states = k_states.flatten(start_dim = 0, end_dim = 1)
v_states = v_states.flatten(start_dim = 0, end_dim = 1)
max_seqlen = attn_params.get_cu_seqlens_max()
cu_seqlens = attn_params.get_cu_seqlens(self.device_idx)
attn_output = flash_attn_varlen_func(
q_states,
k_states,
v_states,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen
)
else:
attn_output = flash_attn_func(
q_states,
k_states,
v_states,
causal = causal,
softmax_scale = self.scaling,
**flash_kwargs
)
attn_output = attn_output.reshape((batch_size, q_len, self.num_attention_heads * self.head_dim))
return attn_output
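
For context, a small CPU-only sketch (an assumed usage pattern, not code from this commit) of how cu_seqlens describe sequences packed along one dimension for flash_attn_varlen_func: cumulative token counts mark the segment boundaries, so attention stays within each segment without padding.

import torch

# Three packed segments of 4, 2 and 3 tokens
seq_lens = torch.tensor([4, 2, 3], dtype = torch.int32)
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0, dtype = torch.int32), (1, 0))
# cu_seqlens == tensor([0, 4, 6, 9], dtype=torch.int32)
max_seqlen = int(seq_lens.max())  # 4

# With flash-attn installed and q, k, v packed to shape (cu_seqlens[-1], heads, head_dim):
# out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)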

34 changes: 34 additions & 0 deletions exllamav2/attn_params.py
@@ -21,6 +21,10 @@ class Params:
alt_rope_embed_dict: dict | None
rope_offsets: torch.Tensor | None
non_causal_attn: bool
block_diag_layers: set
block_diag_mask: torch.Tensor | None
cu_seqlens: torch.Tensor | None
cu_seqlens_max: int | None

def __init__(
self,
@@ -66,6 +70,11 @@ def __init__(
self.past_len_tp = None
self.paged = paged

self.block_diag_layers = set()
self.block_diag_mask = None
self.cu_seqlens = None
self.cu_seqlens_max = None

def is_causal(self) -> bool:
return self.input_mask is None

@@ -164,6 +173,31 @@ def get_rope_offsets(self, device_idx: int) -> torch.Tensor | None:
self.rope_offsets = safe_move_tensor(self.rope_offsets, device_idx, non_blocking = True)
return self.rope_offsets

def get_cu_seqlens(self, device: int) -> torch.Tensor | None:
if self.cu_seqlens is None:
return None
if self.cu_seqlens.device.index != device:
self.cu_seqlens = safe_move_tensor(self.cu_seqlens, device, non_blocking = True)
return self.cu_seqlens

def get_cu_seqlens_max(self) -> torch.Tensor | None:
assert self.cu_seqlens is not None
if self.cu_seqlens_max is not None:
return self.cu_seqlens_max
self.cu_seqlens_max = (self.cu_seqlens[1:] - self.cu_seqlens[:-1]).max().item()
return self.cu_seqlens_max

def get_block_diag_mask(self, device: int) -> torch.Tensor | None:
if self.block_diag_mask is None:
csl = self.get_cu_seqlens(device)
if csl is None:
return None
positions = torch.arange(csl[-1], device = csl.device)
labels = torch.searchsorted(csl[1:], positions, right = True)
self.block_diag_mask = labels.unsqueeze(0) == labels.unsqueeze(1).repeat(self.batch_size)
if self.block_diag_mask.device.index != device:
self.block_diag_mask = safe_move_tensor(self.block_diag_mask, device, non_blocking = True)
return self.block_diag_mask
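
A toy illustration (not from this commit) of the searchsorted trick used above: every token position gets the index of the segment it falls in, and comparing those labels pairwise yields a block-diagonal boolean mask.

import torch

cu_seqlens = torch.tensor([0, 4, 6, 9])
positions = torch.arange(int(cu_seqlens[-1]))  # 0 .. 8
labels = torch.searchsorted(cu_seqlens[1:], positions, right = True)
# labels == tensor([0, 0, 0, 0, 1, 1, 2, 2, 2])
mask = labels.unsqueeze(0) == labels.unsqueeze(1)
# mask is a 9x9 block-diagonal matrix with 4x4, 2x2 and 3x3 True blocks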


class PagedParams(Params):
