main.py
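"""Entry point for the prune / fine-tune / evaluate pipeline.

Illustrative invocations (paths and output locations are examples,
not requirements; flags are defined in main() below):

    # prune a model to 50% sparsity
    python main.py --action prune --sparsity 0.5 \
        --model_path baffo32/decapoda-research-llama-7B-hf \
        --save_path ./llama-7b-pruned

    # fine-tune the pruned model for a fraction of an epoch
    python main.py --action finetune --epochs 0.1 --ft_iter 1 \
        --model_path ./llama-7b-pruned --save_path ./llama-7b-ft
"""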
import argparse
import os
import warnings

from constant import CACHE_PATH

# HF_HOME must be set before transformers is imported: the Hugging Face
# cache location is resolved at import time, so setting it afterwards
# (as the original import order did) has no effect.
os.environ['HF_HOME'] = CACHE_PATH

from transformers import AutoTokenizer

from wanda.lib.prune import check_sparsity
from utils import get_llm, write_results
from evaluation import eval_model
from process import finetune, prune


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--sparsity', type=float, default=0.0, help='target sparsity ratio for pruning')
    parser.add_argument('--action', type=str, default="base", help='action to perform: finetune, prune, or base')
    parser.add_argument('--out_type', type=str, default="base", help='pipeline label for this run in the results file')
    parser.add_argument('--model_path', type=str, default="baffo32/decapoda-research-llama-7B-hf", help='path or hub id of the model to load')
    parser.add_argument('--save_path', type=str, default="baffo32/decapoda-research-llama-7B-hf", help='path where the processed model is saved and evaluated from')
    parser.add_argument('--auto_tokenizer_model_name', type=str, default="baffo32/decapoda-research-llama-7B-hf", help='model name or path from which AutoTokenizer loads the tokenizer')
    parser.add_argument('--not_eval', action="store_true", help='skip model evaluation')
    parser.add_argument('--epochs', type=float, default=0.1, help='number of fine-tuning epochs (may be fractional)')
    parser.add_argument('--ft_iter', type=int, default=1, help='i-th iteration of this fine-tuning if action = finetune; number of times fine-tuning has been performed if action = prune')
    parser.add_argument('--results_path', type=str, default='results.json', help='path to save the results as a JSON file')
    args = parser.parse_args()
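
    # LLaMA-style checkpoints generally require the slow SentencePiece
    # tokenizer, hence use_fast=False below.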
    tokenizer = AutoTokenizer.from_pretrained(args.auto_tokenizer_model_name, use_fast=False)
    os.makedirs(args.save_path, exist_ok=True)
    if args.action == "finetune":
        finetune(tokenizer=tokenizer, model=args.model_path, save_path=args.save_path, seed=args.ft_iter, epochs=args.epochs)
    elif args.action == "prune":
        prune(model=args.model_path, save_path=args.save_path, sparsity=args.sparsity)
    elif args.action == "base":
        if args.model_path != args.save_path:
            warnings.warn('--model_path and --save_path differ, which may cause unexpected behaviour: evaluation loads the model from --save_path.')
    else:
        raise ValueError('Unsupported --action: {}'.format(args.action))
    if args.not_eval:
        return
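
    # Everything below evaluates the model saved at --save_path and
    # records the metrics.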
    metrics = {"finetune_iterations": args.ft_iter}
    saved_model = get_llm(args.save_path)
    # Measure the actual sparsity of the saved model so it can be
    # compared against the requested pruning sparsity.
    sparsity_latest = check_sparsity(saved_model)
    accuracy_mmlu = eval_model(saved_model, tokenizer, ds_name="cais/mmlu")
    accuracy_bbh = eval_model(saved_model, tokenizer, ds_name="lukaemon/bbh")
    accuracy_belebele = eval_model(saved_model, tokenizer, ds_name="facebook/belebele")
    accuracy_factoid_qa = eval_model(saved_model, tokenizer, ds_name="kelvin-jiang/factoid-qa")
    ppl = eval_model(saved_model, tokenizer, ds_name='wikitext2')
    metrics["ppl"] = ppl
    metrics["bbh"] = accuracy_bbh
    metrics["mmlu"] = accuracy_mmlu
    metrics["belebele"] = accuracy_belebele
    metrics["factoid_qa"] = accuracy_factoid_qa
    metrics["sparsity_prune"] = round(args.sparsity, 2)
    metrics["sparsity_latest"] = round(sparsity_latest, 2)
    write_results(pipeline=args.out_type, metrics=metrics, results_path=args.results_path)

if __name__ == "__main__":
    main()