
Commit 6b2b243

HolyFalafel, yafshar, hsubramony, and yeonsily authored and committed

Load INC GPTQ checkpoint & rename params (huggingface#1364)

Co-authored-by: Yaser Afshar <yaser.afshar@intel.com>
Co-authored-by: Harish Subramony <81822986+hsubramony@users.noreply.github.com>
Co-authored-by: Yeonsil Yoon <yyoon@habana.ai>
1 parent e35e970 commit 6b2b243

File tree: 5 files changed, +155 -15 lines


Makefile (+1)

@@ -105,6 +105,7 @@ slow_tests_diffusers: test_installs
 
 # Run text-generation non-regression tests
 slow_tests_text_generation_example: test_installs
+	BUILD_CUDA_EXT=0 python -m pip install -vvv --no-build-isolation git+https://github.com/HabanaAI/AutoGPTQ.git
 	python -m pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.18.0
 	python -m pytest tests/test_text_generation_example.py tests/test_encoder_decoder.py -v -s --token $(TOKEN)
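The new line installs the HabanaAI AutoGPTQ fork with `BUILD_CUDA_EXT=0`, so no CUDA extension is built for the Gaudi environment. A minimal sketch, assuming the fork keeps the upstream `auto_gptq` package name (an assumption, not stated in this diff), to check the dependency is present before launching the slow tests:

```python
# Sanity check: confirm the AutoGPTQ fork installed by the Makefile target is
# importable. Assumes the fork exposes the upstream package name "auto_gptq".
import importlib.util

if importlib.util.find_spec("auto_gptq") is None:
    raise SystemExit(
        "auto_gptq not found; run the pip install line from "
        "`slow_tests_text_generation_example` first."
    )
print("auto_gptq is available; tests can use --load_quantized_model_with_autogptq")
```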

examples/text-generation/README.md (+61 -2)

@@ -502,7 +502,7 @@ python run_generation.py \
 
 ### Loading 4 Bit Checkpoints from Hugging Face
 
-You can load pre-quantized 4bit models with the argument `--load_quantized_model`.
+You can load pre-quantized 4bit models with the argument `--load_quantized_model_with_inc`.
 Currently, uint4 checkpoints and single device are supported.
 More information on enabling 4 bit inference in SynapseAI is available here:
 https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_UINT4.html.
@@ -524,7 +524,35 @@ python run_lm_eval.py \
 --attn_softmax_bf16 \
 --bucket_size=128 \
 --bucket_internal \
---load_quantized_model
+--load_quantized_model_with_inc
+```
+
+### Loading 4 Bit Checkpoints from Neural Compressor (INC)
+
+You can load a pre-quantized 4-bit checkpoint with the argument `--local_quantized_inc_model_path`, supplied with the original model with the argument `--model_name_or_path`.
+Currently, only uint4 checkpoints and single-device configurations are supported.
+**Note:** In this process, you can load a checkpoint that has been quantized using INC.
+More information on enabling 4-bit inference in SynapseAI is available here:
+https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_UINT4.html?highlight=inference%20using%20int4#enabling-and-running-uint4-in-pytorch-models.
+
+Below is an example of loading a llama7b model with a 4bit checkpoint quantized in INC.
+Please note that the model checkpoint name is denoted as `<local_model_path_from_inc>`.
+Additionally, the following environment variables are used for performance optimizations and are planned to be removed in future versions:
+`SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1`
+```bash
+SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1 \
+python run_lm_eval.py \
+-o acc_load_uint4_model.txt \
+--model_name_or_path meta-llama/Llama-2-7b-hf \
+--use_hpu_graphs \
+--use_kv_cache \
+--trim_logits \
+--batch_size 1 \
+--bf16 \
+--attn_softmax_bf16 \
+--bucket_size=128 \
+--bucket_internal \
+--local_quantized_inc_model_path <local_model_path_from_inc> \
 ```
 
 ### Using Habana Flash Attention
@@ -555,6 +583,37 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
 
 For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa).
 
+### Running with UINT4 weight quantization using AutoGPTQ
+
+
+Llama2-7b in UINT4 weight only quantization is enabled using [AutoGPTQ Fork](https://github.com/HabanaAI/AutoGPTQ), which provides quantization capabilities in PyTorch.
+Currently, the support is for UINT4 inference of pre-quantized models only.
+
+You can run a *UINT4 weight quantized* model using AutoGPTQ by setting the following environment variables:
+`SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=true` before running the command,
+and by adding the argument `--load_quantized_model_with_autogptq`.
+
+***Note:***
+Setting the above environment variables improves performance. These variables will be removed in future releases.
+
+
+Here is an example to run a quantized model <quantized_gptq_model>:
+```bash
+SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false \
+ENABLE_EXPERIMENTAL_FLAGS=true python run_generation.py \
+--attn_softmax_bf16 \
+--model_name_or_path <quantized_gptq_model> \
+--use_hpu_graphs \
+--limit_hpu_graphs \
+--use_kv_cache \
+--bucket_size 128 \
+--bucket_internal \
+--trim_logits \
+--max_new_tokens 128 \
+--batch_size 1 \
+--bf16 \
+--load_quantized_model_with_autogptq
+```
 
 ## Language Model Evaluation Harness

examples/text-generation/run_generation.py (+32 -10)

@@ -293,21 +293,11 @@ def setup_parser(parser):
         type=str,
         help="Path to serialize const params. Const params will be held on disk memory instead of being allocated on host memory.",
     )
-    parser.add_argument(
-        "--disk_offload",
-        action="store_true",
-        help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.",
-    )
     parser.add_argument(
         "--trust_remote_code",
        action="store_true",
         help="Whether to trust the execution of code from datasets/models defined on the Hub. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.",
     )
-    parser.add_argument(
-        "--load_quantized_model",
-        action="store_true",
-        help="Whether to load model from hugging face checkpoint.",
-    )
     parser.add_argument(
         "--parallel_strategy",
         type=str,
@@ -326,6 +316,35 @@ def setup_parser(parser):
         help="Run the inference with dataset for specified --n_iterations(default:5)",
     )
 
+    parser.add_argument(
+        "--run_partial_dataset",
+        action="store_true",
+        help="Run the inference with dataset for specified --n_iterations(default:5)",
+    )
+
+    quant_parser_group = parser.add_mutually_exclusive_group()
+    quant_parser_group.add_argument(
+        "--load_quantized_model_with_autogptq",
+        action="store_true",
+        help="Load an AutoGPTQ quantized checkpoint using AutoGPTQ.",
+    )
+    quant_parser_group.add_argument(
+        "--disk_offload",
+        action="store_true",
+        help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.",
+    )
+    quant_parser_group.add_argument(
+        "--load_quantized_model_with_inc",
+        action="store_true",
+        help="Load a Huggingface quantized checkpoint using INC.",
+    )
+    quant_parser_group.add_argument(
+        "--local_quantized_inc_model_path",
+        type=str,
+        default=None,
+        help="Path to neural-compressor quantized model, if set, the checkpoint will be loaded.",
+    )
+
     args = parser.parse_args()
 
     if args.torch_compile:
@@ -338,6 +357,9 @@ def setup_parser(parser):
         args.flash_attention_fast_softmax = True
 
     args.quant_config = os.getenv("QUANT_CONFIG", "")
+    if args.quant_config and args.load_quantized_model_with_autogptq:
+        raise RuntimeError("Setting both quant_config and load_quantized_model_with_autogptq is unsupported. ")
+
     if args.quant_config == "" and args.disk_offload:
         logger.warning(
             "`--disk_offload` was tested only with fp8, it may not work with full precision. If error raises try to remove the --disk_offload flag."

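The four loading modes are now registered in a single mutually exclusive argparse group, so combining any two of them is rejected at parse time; only the `QUANT_CONFIG` environment variable needs the explicit `RuntimeError` check, since it is not an argparse option. A self-contained illustration of the group mechanics (not the real `setup_parser`, just the same pattern):

```python
# Stand-alone illustration of the mutually exclusive group used above.
# Any single flag parses fine; combining two of them makes argparse exit
# with "error: argument ... not allowed with argument ..." (SystemExit 2).
import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--load_quantized_model_with_autogptq", action="store_true")
group.add_argument("--disk_offload", action="store_true")
group.add_argument("--load_quantized_model_with_inc", action="store_true")
group.add_argument("--local_quantized_inc_model_path", type=str, default=None)

print(parser.parse_args(["--load_quantized_model_with_inc"]))
# parser.parse_args(["--disk_offload", "--load_quantized_model_with_inc"])  # would exit: not allowed together
```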
examples/text-generation/utils.py (+24 -3)

@@ -237,10 +237,32 @@ def setup_model(args, model_dtype, model_kwargs, logger):
             torch_dtype=model_dtype,
             **model_kwargs,
         )
-    elif args.load_quantized_model:
+    elif args.load_quantized_model_with_autogptq:
+        from transformers import GPTQConfig
+
+        quantization_config = GPTQConfig(bits=4, use_exllama=False)
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_name_or_path, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs
+        )
+    elif args.load_quantized_model_with_inc:
         from neural_compressor.torch.quantization import load
 
         model = load(model_name_or_path=args.model_name_or_path, format="huggingface", device="hpu", **model_kwargs)
+    elif args.local_quantized_inc_model_path:
+        org_model = AutoModelForCausalLM.from_pretrained(
+            args.model_name_or_path,
+            **model_kwargs,
+        )
+
+        from neural_compressor.torch.quantization import load
+
+        model = load(
+            model_name_or_path=args.local_quantized_inc_model_path,
+            format="default",
+            device="hpu",
+            original_model=org_model,
+            **model_kwargs,
+        )
     else:
         if args.assistant_model is not None:
             assistant_model = AutoModelForCausalLM.from_pretrained(
@@ -613,8 +635,7 @@ def initialize_model(args, logger):
         "token": args.token,
         "trust_remote_code": args.trust_remote_code,
     }
-
-    if args.load_quantized_model:
+    if args.load_quantized_model_with_inc or args.local_quantized_inc_model_path:
         model_kwargs["torch_dtype"] = torch.bfloat16
 
     if args.trust_remote_code:
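The two INC paths differ in what they expect on disk: `--load_quantized_model_with_inc` points INC at a quantized checkpoint in Hugging Face format, while `--local_quantized_inc_model_path` loads a locally saved INC checkpoint and therefore also instantiates the original model for INC to restore into. A condensed sketch of that branch logic, reusing the same `load` calls as the diff but wrapped in a hypothetical helper:

```python
# Condensed sketch (hypothetical helper, not part of the diff) of the two
# INC loading paths that setup_model() now distinguishes. Assumes `args`
# carries the new CLI flags and that transformers / neural-compressor are
# installed, with model_kwargs built as in utils.py.
from transformers import AutoModelForCausalLM
from neural_compressor.torch.quantization import load


def load_inc_model(args, model_kwargs):
    if args.load_quantized_model_with_inc:
        # Quantized checkpoint published in Hugging Face format: INC can
        # reconstruct the model directly from the hub name or local directory.
        return load(
            model_name_or_path=args.model_name_or_path,
            format="huggingface",
            device="hpu",
            **model_kwargs,
        )
    if args.local_quantized_inc_model_path:
        # Locally saved INC checkpoint: the original FP model is built first
        # and INC restores the quantized weights into it.
        org_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)
        return load(
            model_name_or_path=args.local_quantized_inc_model_path,
            format="default",
            device="hpu",
            original_model=org_model,
            **model_kwargs,
        )
    return None
```

Note that `initialize_model` forces `torch.bfloat16` for both of these paths, as shown in the second hunk above.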

tests/test_text_generation_example.py (+37)

@@ -67,6 +67,9 @@
         ("mistralai/Mixtral-8x7B-v0.1", 2, 48, True, 2048, 2048, 1147.50),
         ("microsoft/phi-2", 1, 1, True, 128, 128, 254.08932787178165),
     ],
+    "load_quantized_model_with_autogptq": [
+        ("TheBloke/Llama-2-7b-Chat-GPTQ", 1, 10, False, 128, 2048, 456.7),
+    ],
     "deepspeed": [
         ("bigscience/bloomz", 8, 1, 36.77314954096159),
         ("meta-llama/Llama-2-70b-hf", 8, 1, 64.10514998902435),
@@ -110,6 +113,7 @@
         ("state-spaces/mamba-130m-hf", 224, False, 794.542),
     ],
     "fp8": [],
+    "load_quantized_model_with_autogptq": [],
     "deepspeed": [
         ("bigscience/bloomz-7b1", 8, 1, 31.994268212011505),
     ],
@@ -132,6 +136,7 @@ def _test_text_generation(
     world_size: int = 8,
     torch_compile: bool = False,
     fp8: bool = False,
+    load_quantized_model_with_autogptq: bool = False,
     max_input_tokens: int = 0,
     max_output_tokens: int = 100,
     parallel_strategy: str = None,
@@ -243,6 +248,8 @@
             f"--max_input_tokens {max_input_tokens}",
             "--limit_hpu_graphs",
         ]
+    if load_quantized_model_with_autogptq:
+        command += ["--load_quantized_model_with_autogptq"]
     if parallel_strategy is not None:
         command += [
             f"--parallel_strategy={parallel_strategy}",
@@ -336,6 +343,36 @@ def test_text_generation_fp8(
     )
 
 
+@pytest.mark.parametrize(
+    "model_name, world_size, batch_size, reuse_cache, input_len, output_len, baseline",
+    MODELS_TO_TEST["load_quantized_model_with_autogptq"],
+)
+def test_text_generation_gptq(
+    model_name: str,
+    baseline: float,
+    world_size: int,
+    batch_size: int,
+    reuse_cache: bool,
+    input_len: int,
+    output_len: int,
+    token: str,
+):
+    deepspeed = True if world_size > 1 else False
+    _test_text_generation(
+        model_name,
+        baseline,
+        token,
+        deepspeed=deepspeed,
+        world_size=world_size,
+        fp8=False,
+        load_quantized_model_with_autogptq=True,
+        batch_size=batch_size,
+        reuse_cache=reuse_cache,
+        max_input_tokens=input_len,
+        max_output_tokens=output_len,
+    )
+
+
 @pytest.mark.parametrize("model_name, world_size, batch_size, baseline", MODELS_TO_TEST["deepspeed"])
 def test_text_generation_deepspeed(model_name: str, baseline: float, world_size: int, batch_size: int, token: str):
     _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, batch_size=batch_size)
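The new `test_text_generation_gptq` case reuses the existing `_test_text_generation` harness and only flips the new flag. A hypothetical convenience runner for just this test, assuming it is launched from the repository root with a valid Hugging Face token (the Makefile target passes one via `--token $(TOKEN)`):

```python
# Hypothetical runner for the new AutoGPTQ test only.
# "<hf_token>" is a placeholder; the test suite reads it via its --token option.
import sys

import pytest

sys.exit(
    pytest.main(
        [
            "tests/test_text_generation_example.py",
            "-v",
            "-s",
            "-k", "test_text_generation_gptq",
            "--token", "<hf_token>",
        ]
    )
)
```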
