-
Notifications
You must be signed in to change notification settings - Fork 102
/
Copy pathsvd_reinforce_hydra.py
426 lines (370 loc) · 15.1 KB
/
svd_reinforce_hydra.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
import gc
import json
import os
from datetime import datetime
from typing import Dict
import hydra
import numpy as np
import torch
from omegaconf import OmegaConf
from transformers import AutoModelForCausalLM, AutoTokenizer
from base_model import BaseModel
from logging_utils import Metrics, get_mean_std_max_min_dict
from optim_modules import OptimizationAlgorithm
from policy import Policy
from tasks import Task
from utils import (eval_model, eval_model_experts_prompt_based, forward,
load_hf_params_to_vllm)
def wandb_init(cfg, run_name: str, group_name: str, log_dir: str):
    """Start a Weights & Biases run for this experiment.

    The fully resolved OmegaConf config is recorded as the run config, with
    the log directory and the run/group names added to it.

    Args:
        cfg: The experiment config; ``cfg.wandb_project`` selects the project.
        run_name: Name of the wandb run (clipped to 127 characters).
        group_name: wandb group name (clipped to 127 characters).
        log_dir: Local output directory, stored in the run config.

    Returns:
        The imported ``wandb`` module, so the caller can use ``wandb.log``.
    """
    # Imported here so wandb is only needed when logging is enabled
    # (the caller only invokes this under ``cfg.wandb_log``).
    import wandb

    run_config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=False)
    run_config.update(
        log_dir=log_dir,
        wandb_run_name=run_name,
        wandb_group_name=group_name,
    )

    # wandb has a 128-size character limit on the group name; both names are
    # clipped to stay under it.
    max_name_len = 127
    wandb.init(
        project=cfg.wandb_project,
        group=group_name[:max_name_len],
        name=run_name[:max_name_len],
        config=run_config,
    )
    return wandb
@hydra.main(version_base=None, config_path="cfgs", config_name="config")
def main(cfg):
    """Main function: SVD-decompose, evaluate, or train an SVD-based policy.

    Driven by the Hydra config ``cfg`` (``cfgs/config``). Depending on it:

    * If the base model's decomposed-parameter file does not exist, decompose
      every non-norm weight with ``torch.svd``, save the factors and return.
    * If ``cfg.extract_svd`` is set but that file already exists, report an
      error and return.
    * If ``test_only``, evaluate either with prompt-based expert dispatch
      (``prompt_based_eval``) or directly on the train/valid/test(/transfer)
      splits, write JSON results under the log dir, and return.
    * Otherwise run ``num_iters`` optimization steps of the policy. Every
      ``test_interval`` iterations it evaluates, checkpoints the latest and
      the best-by-validation policy parameters, and appends a log record.
    """
    num_iters = cfg.num_iters
    test_interval = cfg.test_interval
    batch_size = cfg.batch_size
    seed = cfg.seed
    policy_name = cfg.policy_name
    test_only = cfg.test_only
    save_legacy_params = cfg.save_legacy_params
    exp_name = cfg.exp_name
    run_name = cfg.run_name
    task_name = cfg.task_name
    load_ckpt = cfg.load_ckpt
    use_lora = cfg.use_lora
    prompt_based_eval = cfg.prompt_based_eval
    experts_path_dict = cfg.experts_path_dict

    # "scratch" and "base" are sentinels meaning "do not resume a checkpoint".
    resuming_from_ckpt = load_ckpt is not None and load_ckpt not in (
        "scratch",
        "base",
    )

    # Create task
    task_loader: Task = hydra.utils.instantiate(cfg.task_loader)
    base_model: BaseModel = hydra.utils.instantiate(cfg.base_model)
    model_id = base_model.get_model_id()
    decomposed_param_file = base_model.get_param_file(param_folder_path="")
    # Extraction is forced whenever the decomposed-parameter file is missing.
    extract_svd = cfg.extract_svd or (not os.path.exists(decomposed_param_file))

    has_training_split = task_loader.has_training_split
    has_transfer_split = task_loader.has_transfer_split
    if not has_training_split:
        assert test_only, "Cannot train on a task with no training split"

    if exp_name is None:
        exp_name = "temp"
    metrics_to_log = Metrics()

    # Create log dir.
    if run_name is None:
        now = datetime.now()
        run_name = now.strftime("%Y%m%d-%H%M%S")
    if test_only and (not resuming_from_ckpt):
        log_dir = f"{cfg.out_dir}/{task_name}/{cfg.base_model_name}_base"
        group_name = cfg.base_model_name
    else:
        log_dir = f"{cfg.out_dir}/{task_name}/{policy_name}/{exp_name}/{run_name}"
        group_name = cfg.wandb_group_name
    os.makedirs(log_dir, exist_ok=True)

    vllm_model = task_loader.get_vllm_model(model_id=model_id)

    # The evaluator tuple is (train, test) or (train, test, transfer).
    train_eval, *test_evals = task_loader.get_evaluator()
    if has_transfer_split:
        test_eval, transfer_eval = test_evals
    else:
        test_eval = test_evals[0]

    train_data, train_ix, valid_ix = task_loader.get_train_data()

    # NOTE(review): the HF model and all policy tensors are pinned to cuda:1,
    # presumably to keep them off the device used by the vLLM engine — confirm
    # before changing.
    gpu = torch.device("cuda:1")
    np_random = np.random.RandomState(seed)

    if extract_svd:
        # cpu + float32 for initial SVD decomposition
        model = AutoModelForCausalLM.from_pretrained(
            model_id, device_map="cpu", torch_dtype=torch.float32
        )
    else:
        # Load model and tokenizer.
        model = AutoModelForCausalLM.from_pretrained(
            model_id, device_map="cuda:1", torch_dtype=torch.bfloat16
        )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    base_params = model.state_dict()
    # Detached CPU copy of the original weights whose names contain "mlp".
    original_model_params = {
        k: v.clone().detach().cpu() for k, v in base_params.items() if "mlp" in k
    }

    # Load decomposed parameters.
    if not os.path.exists(decomposed_param_file):
        print("Decomposed params not found. Decomposing...")
        decomposed_params = {}
        for k, v in base_params.items():
            # Norm weights are skipped; every other tensor gets U, S, V factors.
            if "norm" not in k:
                print(k)
                U, S, V = torch.svd(v)
                decomposed_params[f"{k}.U"] = U
                decomposed_params[f"{k}.S"] = S
                decomposed_params[f"{k}.V"] = V
        torch.save(decomposed_params, decomposed_param_file)
        print("successfully decomposed model - returning")
        return
    if extract_svd:
        # Extraction was explicitly requested but the file already exists.
        # Previously execution fell through here with `decomposed_params`
        # unbound (and the model still on CPU in float32), crashing later
        # with an UnboundLocalError. Stop instead of reusing/overwriting it.
        print(f"ERROR: SVD file already exists at {decomposed_param_file}")
        return
    print("Decomposed params found. Loading...")
    # The file holds a plain dict of tensors (see torch.save above), so the
    # safer weights-only unpickler is sufficient.
    decomposed_params = torch.load(decomposed_param_file, weights_only=True)
    for k, v in decomposed_params.items():
        decomposed_params[k] = v.to(torch.bfloat16).to(gpu)

    if cfg.wandb_log:
        wandb = wandb_init(
            cfg=cfg, group_name=group_name, run_name=run_name, log_dir=log_dir
        )

    policy: Policy = hydra.utils.instantiate(
        cfg.shakeoff_policy,
        base_params=base_params,
        decomposed_params=decomposed_params,
        gpu=gpu,
    )
    optimization_algorithm: OptimizationAlgorithm = hydra.utils.instantiate(
        cfg.optimization_algorithm,
        policy=policy,
        gpu=gpu,
    )

    if resuming_from_ckpt and os.path.exists(load_ckpt):
        print(f"Starting from checkpoint at: {load_ckpt}")
        if use_lora:
            # load the lora weight and merge it into the base model
            assert os.path.isdir(load_ckpt), "ckpt for lora must be dir to lora adapter"
            from peft import PeftModel

            lora_model = PeftModel.from_pretrained(model, load_ckpt)
            merged_model = lora_model.merge_and_unload()
            new_params = merged_model.state_dict()
        elif "learnable_params" in load_ckpt:
            # load svd expert (a learnable_params*.pt dict of tensors saved by
            # this script); only supported for evaluation.
            learnable_params = torch.load(load_ckpt, weights_only=True)
            for k, v in learnable_params.items():
                learnable_params[k] = v.to(gpu)
            assert test_only
            new_params = forward(
                policy, model, base_params, decomposed_params, learnable_params
            )
        else:
            # A policy state_dict (policy_params*.pt).
            state_dict = torch.load(load_ckpt, weights_only=True)
            policy.load_state_dict(state_dict=state_dict)
            # Always refresh the learnable params from the restored policy.
            # This used to happen only when test_only, so resuming *training*
            # from a policy checkpoint raised UnboundLocalError here.
            learnable_params = policy.get_learnable_params()
            new_params = forward(
                policy, model, base_params, decomposed_params, learnable_params
            )
        load_hf_params_to_vllm(new_params, vllm_model.llm)
    else:
        print(f"Starting from the base model as load_ckpt=={load_ckpt}")

    model.eval()

    # Prompt-based and cls dispatcher evaluation.
    if test_only and prompt_based_eval:
        test_data_dict = eval_model_experts_prompt_based(
            vllm_model,
            test_eval,
            experts_path_dict,
            policy,
            model,
            base_params,
            decomposed_params,
            task_loader.target_metric_test,
        )
        test_data_dict["type"] = "test"
        # Log the results.
        if cfg.wandb_log:
            wandb.log(test_data_dict)
        with open(f"{log_dir}/eval_results.json", "w") as f:
            json.dump(test_data_dict, f, indent=4)
        print(f"Test evaluation results: {test_data_dict}")

        # Eval the transfer set if available
        if has_transfer_split:
            transfer_data_dict = eval_model_experts_prompt_based(
                vllm_model,
                transfer_eval,
                experts_path_dict,
                policy,
                model,
                base_params,
                decomposed_params,
                task_loader.target_metric_transfer,
            )
            transfer_data_dict["type"] = "transfer"
            # Log the results.
            if cfg.wandb_log:
                wandb.log(transfer_data_dict)
            # Written to a separate file: reusing eval_results.json here used
            # to overwrite the test results saved just above.
            with open(f"{log_dir}/eval_results_transfer.json", "w") as f:
                json.dump(transfer_data_dict, f, indent=4)
            print(f"Transfer evaluation results: {transfer_data_dict}")
        return

    # Non-adaptive evaluation on train, val, test set.
    if test_only and not prompt_based_eval:
        data_dict = {}
        details_dict = {}
        if has_training_split:
            train_res = eval_model(vllm_model, train_eval, train_ix)
            valid_res = eval_model(vllm_model, train_eval, valid_ix)
            data_dict["train_acc"] = train_res.aggregate_metrics[
                task_loader.target_metric_train
            ]
            data_dict["valid_acc"] = valid_res.aggregate_metrics[
                task_loader.target_metric_valid
            ]
            details_dict["train"] = train_res.sample_details
            details_dict["valid"] = valid_res.sample_details
        test_res = eval_model(vllm_model, test_eval)
        data_dict["test_acc"] = test_res.aggregate_metrics[
            task_loader.target_metric_test
        ]
        details_dict["test"] = test_res.sample_details
        if has_transfer_split:
            transfer_res = eval_model(vllm_model, transfer_eval)
            data_dict["transfer_acc"] = transfer_res.aggregate_metrics[
                task_loader.target_metric_transfer
            ]
            details_dict["transfer"] = transfer_res.sample_details
        if cfg.wandb_log:
            wandb.log(data_dict)
        with open(f"{log_dir}/eval_results.json", "w") as f:
            json.dump(data_dict, f, indent=4)
        print(f"Evaluation results: {data_dict}")
        return

    learnable_params = policy.get_learnable_params()
    for k in learnable_params:
        model.get_parameter(k).requires_grad_(True)

    # Training loop.
    num_train = len(list(train_ix))
    if batch_size is None:
        clipped_batch_size = num_train
    else:
        clipped_batch_size = min(batch_size, num_train)
    # NOTE: a strict ">" against 0.0 means the "best" checkpoint is only
    # written once some evaluation reaches a validation metric above zero.
    best_val_acc = 0.0
    test_at_best = 0.0
    transfer_at_best = 0.0
    for i in range(num_iters):
        batch_ix = np_random.choice(train_ix, size=clipped_batch_size, replace=False)
        optimization_algorithm.step_optimization(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            policy=policy,
            task_loader=task_loader,
            batch_ix=batch_ix,
            train_data=train_data,
            train_eval=train_eval,
            base_params=base_params,
            decomposed_params=decomposed_params,
            original_model_params=original_model_params,
            metrics_to_log=metrics_to_log,
            vllm_model=vllm_model,
        )

        # Gradient / parameter magnitude statistics (logging only).
        with torch.no_grad():
            lists_to_log = {}
            grads = [p.grad for p in policy.trainable_params]
            if grads[0] is not None:
                grad_mean = [g.mean().item() for g in grads]
                grad_mags = [torch.linalg.vector_norm(g).item() for g in grads]
                lists_to_log["grad_mean"] = grad_mean
                lists_to_log["grad_mags"] = grad_mags

            param_mags = [
                torch.linalg.vector_norm(p).item() for p in policy.trainable_params
            ]
            lists_to_log["policy_param_mag"] = param_mags

            generated_params_list = list(learnable_params.values())
            generated_param_mean = [p.mean().item() for p in generated_params_list]
            generated_param_mags = [
                torch.linalg.vector_norm(p).item() for p in generated_params_list
            ]
            lists_to_log["generated_param_mean"] = generated_param_mean
            lists_to_log["generated_param_mags"] = generated_param_mags

            list_stats = {}
            for k, v in lists_to_log.items():
                list_stats.update(get_mean_std_max_min_dict(array=v, prefix=k))
            metrics_to_log.update(**list_stats)

        optimization_algorithm.update(policy=policy)

        # Make sure old params are deleted and garbage-collected
        gc.collect()
        torch.cuda.empty_cache()
        model.zero_grad()

        # More accurate logging.
        value_mean = list_stats.get("generated_param_mean/mean", None)
        grad_mean_mag = list_stats.get("grad_mags/mean", None)
        print(
            f"Iter {i}: "
            + f"param_mean={value_mean}, "
            + f"grad_mean_mag={grad_mean_mag}"
        )
        optimization_algorithm.log_optim(metrics_to_log=metrics_to_log)

        # Test and save.
        if i % test_interval == 0:
            # Rebuild the model weights from the current policy and push them
            # into the vLLM engine before evaluating.
            learnable_params = policy.get_learnable_params()
            forward(policy, model, base_params, decomposed_params, learnable_params)
            load_hf_params_to_vllm(model.state_dict(), vllm_model.llm)

            train_res = eval_model(vllm_model, train_eval, train_ix)
            valid_res = eval_model(vllm_model, train_eval, valid_ix)
            test_res = eval_model(vllm_model, test_eval)
            if has_transfer_split:
                transfer_res = eval_model(vllm_model, transfer_eval)

            if (
                valid_res.aggregate_metrics[task_loader.target_metric_valid]
                > best_val_acc
            ):
                best_val_acc = valid_res.aggregate_metrics[
                    task_loader.target_metric_valid
                ]
                test_at_best = test_res.aggregate_metrics[
                    task_loader.target_metric_test
                ]
                if has_transfer_split:
                    transfer_at_best = transfer_res.aggregate_metrics[
                        task_loader.target_metric_transfer
                    ]
                print("best_val_acc updated")
                path = f"{log_dir}/policy_params.pt"
                torch.save(policy.state_dict(), path)
                if save_legacy_params:
                    torch.save(learnable_params, f"{log_dir}/learnable_params.pt")

            path = f"{log_dir}/policy_params_latest.pt"
            torch.save(policy.state_dict(), path)
            if save_legacy_params:
                torch.save(learnable_params, f"{log_dir}/learnable_params_latest.pt")

            policy.record_state(metrics_to_log=metrics_to_log)
            data_dict = {
                "iter": i,
                "best_val_acc": best_val_acc,
                "test_at_best_val": test_at_best,
                "train_acc": train_res.aggregate_metrics[
                    task_loader.target_metric_train
                ],
                "valid_acc": valid_res.aggregate_metrics[
                    task_loader.target_metric_valid
                ],
                "test_acc": test_res.aggregate_metrics[task_loader.target_metric_test],
                **metrics_to_log.get(),
            }
            if has_transfer_split:
                data_dict["transfer_acc"] = transfer_res.aggregate_metrics[
                    task_loader.target_metric_transfer
                ]
                data_dict["transfer_at_best_val"] = transfer_at_best
            if cfg.wandb_log:
                wandb.log(data_dict)
            # Appends one pretty-printed JSON object per evaluation.
            with open(f"{log_dir}/reinforce_log.json", "a") as f:
                json_data = json.dumps(data_dict, indent=4)
                f.write(json_data)
                f.write("\n")
            metrics_to_log.reset()
if __name__ == "__main__":
    # main() is called without arguments: the @hydra.main decorator builds
    # and passes in the `cfg` it expects.
    main()