forked from johnsmith0031/alpaca_lora_4bit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathautograd_4bit.py
294 lines (241 loc) · 11.6 KB
/
autograd_4bit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import matmul_utils_4bit as mm4b
import torch
import torch.nn as nn
import time
import math
from torch.cuda.amp import custom_bwd, custom_fwd
from colorama import init, Fore, Back, Style
# Set up colorama. autoreset=True is requested because the coloured print
# calls in this module never emit an explicit Style.RESET_ALL themselves.
init(autoreset=True)
class AutogradMatmul4bitCuda(torch.autograd.Function):
    """Autograd function for the 4-bit matmul built on ``matmul_utils_4bit``.

    Only the activation ``x`` can receive a gradient; every other forward
    argument gets ``None`` in ``backward``.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, x, qweight, scales, zeros, g_idx, bits, maxq):
        """Multiply ``x`` by the quantized weight.

        ``g_idx is None`` selects the v1 kernel, otherwise the v2 kernel is
        used. ``bits`` and ``maxq`` are accepted but not used by this
        implementation.
        """
        ctx.save_for_backward(qweight, scales, zeros, g_idx)
        if g_idx is None:
            output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
        else:
            output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, g_idx)
        output = output.clone()
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``x`` (or ``None``) plus six ``None`` placeholders."""
        qweight, scales, zeros, g_idx = ctx.saved_tensors
        # Must be bound even when x needs no gradient; otherwise the return
        # statement below raises UnboundLocalError.
        grad = None
        if ctx.needs_input_grad[0]:
            if g_idx is None:
                grad = mm4b._matmul4bit_v1_recons(
                    grad_output, qweight, scales, zeros, transpose=True
                )
            else:
                grad = mm4b._matmul4bit_v2_recons(
                    grad_output, qweight, scales, zeros, g_idx, transpose=True
                )
        return grad, None, None, None, None, None, None
try:
    import triton_utils as tu

    class AutogradMatmul4bitTriton(torch.autograd.Function):
        """Autograd function for the 4-bit matmul built on ``triton_utils``.

        Only the activation ``x`` can receive a gradient; every other forward
        argument gets ``None`` in ``backward``.
        """

        @staticmethod
        @custom_fwd(cast_inputs=torch.float16)
        def forward(ctx, x, qweight, scales, qzeros, g_idx, bits, maxq):
            """Run ``tu.triton_matmul`` and stash what ``backward`` needs on ``ctx``."""
            result = tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
            ctx.save_for_backward(qweight, scales, qzeros, g_idx)
            # Plain Python values are kept as ctx attributes, tensors via
            # save_for_backward above.
            ctx.bits = bits
            ctx.maxq = maxq
            return result.clone()

        @staticmethod
        @custom_bwd
        def backward(ctx, grad_output):
            """Return the gradient w.r.t. ``x`` (or ``None``) plus six ``None`` placeholders."""
            qweight, scales, qzeros, g_idx = ctx.saved_tensors
            bits = ctx.bits
            maxq = ctx.maxq
            if not ctx.needs_input_grad[0]:
                return None, None, None, None, None, None, None
            grad_input = tu.triton_matmul_transpose(
                grad_output, qweight, scales, qzeros, g_idx, bits, maxq
            )
            return grad_input, None, None, None, None, None, None

except ImportError:
    # Without triton_utils the Triton class is simply left undefined;
    # switch_backend_to() detects that through globals().
    print('Triton not found. Please run "pip install triton".')
# Autograd implementation applied by Autograd4bitQuantLinear.forward when
# gradients are enabled; starts as the CUDA one and is rebound by
# switch_backend_to().
AutogradMatmul4bit = AutogradMatmul4bitCuda
# Name of the active backend ('cuda' or 'triton'); read by
# matmul4bit_with_backend() to pick the no-grad code path.
backend = 'cuda'
def switch_backend_to(to_backend):
    """Select the matmul implementation used by the module.

    Rebinds the module globals ``AutogradMatmul4bit`` and ``backend`` and
    prints a confirmation line.

    Args:
        to_backend: ``'cuda'`` or ``'triton'``.

    Raises:
        ValueError: if ``to_backend`` is neither of the two names, or if
            ``'triton'`` is requested while ``AutogradMatmul4bitTriton`` is
            not defined in this module.
    """
    global AutogradMatmul4bit, backend

    if to_backend == 'cuda':
        implementation = AutogradMatmul4bitCuda
        selected = 'cuda'
        banner = 'Using CUDA implementation.'
    elif to_backend == 'triton':
        # The Triton class only exists when its import succeeded at load time.
        if 'AutogradMatmul4bitTriton' not in globals():
            raise ValueError('Triton not found. Please install triton')
        implementation = AutogradMatmul4bitTriton
        selected = 'triton'
        banner = 'Using Triton implementation.'
    else:
        raise ValueError('Backend not supported.')

    AutogradMatmul4bit = implementation
    backend = selected
    print(Style.BRIGHT + Fore.GREEN + banner)
def matmul4bit_with_backend(x, qweight, scales, qzeros, g_idx, bits, maxq):
    """Run the 4-bit matmul through whichever backend is currently active.

    The CUDA path ignores ``bits`` and ``maxq``; the Triton path asserts that
    ``qzeros`` is an int32 tensor before calling the kernel.

    Raises:
        ValueError: if the module-level ``backend`` is neither ``'cuda'`` nor
            ``'triton'``.
    """
    if backend == 'triton':
        assert qzeros.dtype == torch.int32
        return tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
    if backend == 'cuda':
        return mm4b.matmul4bit(x, qweight, scales, qzeros, g_idx)
    raise ValueError('Backend not supported.')
# Assumes layer is perfectly divisible into 256 * 256 blocks
class Autograd4bitQuantLinear(nn.Module):
    """Linear layer whose weight is stored as packed 4-bit values.

    All tensors are registered as buffers (created uninitialised with
    ``torch.empty``), in this order:

    * v1 models: ``zeros``, ``scales``, ``bias``, ``qweight``; ``g_idx`` is a
      plain attribute set to ``None``.
    * other models: ``qzeros``, ``scales``, ``g_idx``, ``bias``, ``qweight``.

    Args:
        in_features: size of the input dimension.
        out_features: size of the output dimension.
        groupsize: quantization group size; ``-1`` means one group spanning
            all of ``in_features``.
        is_v1_model: selects the v1 buffer layout described above.
    """

    def __init__(self, in_features, out_features, groupsize=-1, is_v1_model=False):
        super().__init__()
        bits = 4
        if groupsize == -1:
            groupsize = in_features

        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits
        self.maxq = 2 ** self.bits - 1
        self.groupsize = groupsize
        self.is_v1_model = is_v1_model
        # With this flag left at True, forward() never adds the bias buffer.
        self.disable_bias = True

        if is_v1_model:
            per_output_shape = (out_features, 1)
            self.register_buffer('zeros', torch.empty(per_output_shape))
            self.register_buffer('scales', torch.empty(per_output_shape))
            self.g_idx = None
        else:
            group_count = math.ceil(in_features / groupsize)
            packed_out = out_features // 256 * (bits * 8)
            self.register_buffer(
                'qzeros', torch.empty((group_count, packed_out), dtype=torch.int32)
            )
            self.register_buffer('scales', torch.empty((group_count, out_features)))
            # Entry i holds the index of the group input feature i belongs to.
            group_of_feature = [i // self.groupsize for i in range(in_features)]
            self.register_buffer(
                'g_idx', torch.tensor(group_of_feature, dtype=torch.int32)
            )

        self.register_buffer('bias', torch.empty(out_features))
        packed_in = in_features // 256 * (bits * 8)
        self.register_buffer(
            'qweight', torch.empty((packed_in, out_features), dtype=torch.int32)
        )

    def forward(self, x):
        """Apply the quantized matmul, using the autograd path when grad is enabled."""
        zero_points = self.zeros if self.is_v1_model else self.qzeros
        call_args = (
            x, self.qweight, self.scales, zero_points,
            self.g_idx, self.bits, self.maxq,
        )
        if torch.is_grad_enabled():
            out = AutogradMatmul4bit.apply(*call_args)
        else:
            out = matmul4bit_with_backend(*call_args)
        if not self.disable_bias:
            out += self.bias
        return out
def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1, is_v1_model=False):
    """Recursively swap the named sub-layers for ``Autograd4bitQuantLinear``.

    Args:
        module: root module to modify in place.
        names: container of dotted attribute paths to replace; only
            membership (``in``) is used.
        name: dotted path of ``module`` itself (``''`` for the root).
        groupsize: forwarded to every ``Autograd4bitQuantLinear`` created.
        is_v1_model: forwarded to every ``Autograd4bitQuantLinear`` created.

    Replacements take ``in_features`` / ``out_features`` from the layer they
    replace. Modules that already are ``Autograd4bitQuantLinear`` are left
    untouched.
    """
    if isinstance(module, Autograd4bitQuantLinear):
        return

    prefix = name + '.' if name != '' else ''

    for attr in dir(module):
        current = getattr(module, attr)
        if prefix + attr in names:
            replacement = Autograd4bitQuantLinear(
                current.in_features,
                current.out_features,
                groupsize=groupsize,
                is_v1_model=is_v1_model,
            )
            setattr(module, attr, replacement)

    for child_name, child in module.named_children():
        make_quant_for_4bit_autograd(
            child,
            names,
            prefix + child_name,
            groupsize=groupsize,
            is_v1_model=is_v1_model,
        )
def model_to_half(model):
    """Convert ``model`` to half precision in place and print a confirmation.

    After ``model.half()``, the ``scales`` and ``bias`` buffers of every
    ``Autograd4bitQuantLinear`` (plus ``zeros`` on v1 layers) are explicitly
    reassigned as half tensors.
    """
    model.half()
    for _, layer in model.named_modules():
        if not isinstance(layer, Autograd4bitQuantLinear):
            continue
        if layer.is_v1_model:
            layer.zeros = layer.zeros.half()
        layer.scales = layer.scales.half()
        layer.bias = layer.bias.half()
    print(Style.BRIGHT + Fore.YELLOW + 'Converted as Half.')
def model_to_float(model):
    """Convert ``model`` to single precision in place and print a confirmation.

    After ``model.float()``, the ``scales`` and ``bias`` buffers of every
    ``Autograd4bitQuantLinear`` (plus ``zeros`` on v1 layers) are explicitly
    reassigned as float tensors.
    """
    model.float()
    for _, layer in model.named_modules():
        if not isinstance(layer, Autograd4bitQuantLinear):
            continue
        if layer.is_v1_model:
            layer.zeros = layer.zeros.float()
        layer.scales = layer.scales.float()
        layer.bias = layer.bias.float()
    print(Style.BRIGHT + Fore.YELLOW + 'Converted as Float.')
def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=''):
    """Collect sub-modules whose exact type is listed in ``layers``.

    Args:
        module: module to search (it is itself a candidate).
        layers: container of module classes to look for. The match uses
            ``type(module) in layers``, so subclasses do not count. The
            default is an immutable tuple so it cannot be altered between
            calls.
        name: dotted path of ``module`` (``''`` for the root).

    Returns:
        Dict mapping dotted path to the matching module, in traversal order.
        A matching module is returned as-is and its children are not searched.
    """
    if type(module) in layers:
        return {name: module}

    found = {}
    for child_name, child in module.named_children():
        child_path = f'{name}.{child_name}' if name != '' else child_name
        found.update(find_layers(child, layers=layers, name=child_path))
    return found
def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048, is_v1_model=False):
    """Build a LLaMA model with 4-bit linear layers and load a checkpoint into it.

    Args:
        config_path: passed to ``LlamaConfig.from_pretrained`` and
            ``LlamaTokenizer.from_pretrained``.
        model_path: checkpoint handed to
            ``accelerate.load_checkpoint_and_dispatch``.
        groupsize: forwarded to ``make_quant_for_4bit_autograd``.
        half: when true, ``model_to_half`` is applied after loading.
        device_map: forwarded to ``accelerate.load_checkpoint_and_dispatch``.
        seqlen: stored on the model as ``model.seqlen``.
        is_v1_model: forwarded to ``make_quant_for_4bit_autograd``.

    Returns:
        ``(model, tokenizer)``; the tokenizer truncates from the left.
    """
    import accelerate
    from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

    print(Style.BRIGHT + Fore.CYAN + "Loading Model ...")
    started = time.time()

    with accelerate.init_empty_weights():
        llama_config = LlamaConfig.from_pretrained(config_path)
        model = LlamaForCausalLM(llama_config).eval()
        quant_targets = find_layers(model)
        # lm_head is left as a regular layer.
        quant_targets.pop('lm_head', None)
        make_quant_for_4bit_autograd(
            model, quant_targets, groupsize=groupsize, is_v1_model=is_v1_model
        )

    model = accelerate.load_checkpoint_and_dispatch(
        model=model,
        checkpoint=model_path,
        device_map=device_map,
        no_split_module_classes=["LlamaDecoderLayer"],
    )
    model.seqlen = seqlen

    if half:
        model_to_half(model)

    tokenizer = LlamaTokenizer.from_pretrained(config_path)
    tokenizer.truncation_side = 'left'

    elapsed = time.time() - started
    print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {elapsed:.2f} seconds.")
    return model, tokenizer
def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None, is_v1_model=False):
    """Load a 4-bit LLaMA model on CPU, optionally add a LoRA, then dispatch it.

    The checkpoint is first loaded entirely on CPU, the quantized layers'
    floating-point buffers are converted to half, and the model is then
    spread over devices with ``accelerate.dispatch_model`` according to
    ``max_memory``.

    Args:
        config_path: passed to ``LlamaConfig.from_pretrained`` and
            ``LlamaTokenizer.from_pretrained``.
        model_path: checkpoint handed to ``accelerate.load_checkpoint_in_model``.
        lora_path: optional adapter loaded through ``PeftModel.from_pretrained``.
        groupsize: forwarded to ``make_quant_for_4bit_autograd``.
        seqlen: stored on the model as ``model.seqlen``.
        max_memory: mapping given to ``accelerate.infer_auto_device_map``;
            defaults to ``{0: '24Gib', 'cpu': '48Gib'}``.
        is_v1_model: forwarded to ``make_quant_for_4bit_autograd``.

    Returns:
        ``(model, tokenizer)``; the tokenizer truncates from the left.
    """
    import accelerate
    from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

    if max_memory is None:
        max_memory = {0: '24Gib', 'cpu': '48Gib'}

    print(Style.BRIGHT + Fore.CYAN + "Loading Model ...")
    t0 = time.time()

    with accelerate.init_empty_weights():
        config = LlamaConfig.from_pretrained(config_path)
        model = LlamaForCausalLM(config)
        model = model.eval()
        layers = find_layers(model)
        # lm_head is left as a regular layer.
        for name in ['lm_head']:
            if name in layers:
                del layers[name]
        make_quant_for_4bit_autograd(model, layers, groupsize=groupsize, is_v1_model=is_v1_model)
    accelerate.load_checkpoint_in_model(model, checkpoint=model_path, device_map={'': 'cpu'})

    # rotary_emb fix: keep CPU copies of the first rotary module's cached
    # tables so they can be re-inserted into the offload hooks after dispatch.
    for n, m in model.named_modules():
        if 'rotary_emb' in n:
            cos_cached = m.cos_cached.clone().cpu()
            sin_cached = m.sin_cached.clone().cpu()
            break

    if lora_path is not None:
        from peft import PeftModel
        from monkeypatch.peft_tuners_lora_monkey_patch import Linear4bitLt
        model = PeftModel.from_pretrained(model, lora_path, device_map={'': 'cpu'}, torch_dtype=torch.float32, is_trainable=True)
        print(Style.BRIGHT + Fore.GREEN + '{} Lora Applied.'.format(lora_path))

    model.seqlen = seqlen

    print('Apply half ...')
    for n, m in model.named_modules():
        # Linear4bitLt is only imported (and therefore only checked) when a
        # LoRA path was given; the `and` short-circuits otherwise.
        if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)):
            if m.is_v1_model:
                m.zeros = m.zeros.half()
            m.scales = m.scales.half()
            m.bias = m.bias.half()

    print('Dispatching model ...')
    device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
    model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True, main_device=0)
    torch.cuda.empty_cache()
    # memory_allocated() is in bytes: three divisions by 1024 are needed for
    # the GiB figure the message announces (two would yield MiB).
    vram_gib = torch.cuda.memory_allocated() / 1024 / 1024 / 1024
    print(Style.BRIGHT + Fore.YELLOW + 'Total {:.2f} Gib VRAM used.'.format(vram_gib))

    # rotary_emb fix: make sure every offloading hook attached to a rotary
    # module can find the cached sin/cos tables in its weights map.
    for n, m in model.named_modules():
        if 'rotary_emb' in n:
            if getattr(m, '_hf_hook', None):
                if isinstance(m._hf_hook, accelerate.hooks.SequentialHook):
                    hooks = m._hf_hook.hooks
                else:
                    hooks = [m._hf_hook]
                for hook in hooks:
                    if hook.offload:
                        if n + '.sin_cached' not in hook.weights_map.dataset.state_dict.keys():
                            hook.weights_map.dataset.state_dict[n + '.sin_cached'] = sin_cached.clone().cpu()
                            hook.weights_map.dataset.state_dict[n + '.cos_cached'] = cos_cached.clone().cpu()

    tokenizer = LlamaTokenizer.from_pretrained(config_path)
    tokenizer.truncation_side = 'left'

    print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.")

    return model, tokenizer
# Second public name bound to the very same function object.
load_llama_model_4bit_low_ram_and_offload_to_cpu = load_llama_model_4bit_low_ram_and_offload