Skip to content

Commit

Permalink
Make Nf4 a NF4 Tensor subclass (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg authored Feb 14, 2024
1 parent e2a3869 commit 7fae31b
Show file tree
Hide file tree
Showing 6 changed files with 454 additions and 341 deletions.
103 changes: 50 additions & 53 deletions benchmarks/qlora.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import csv
import itertools

Expand All @@ -12,6 +11,8 @@

import transformer_nuggets as nugs
import transformer_nuggets.quant.qlora as qlora
from jsonargparse import CLI
from tabulate import tabulate
from tqdm import tqdm
from transformer_nuggets.quant import NF4Tensor

Expand Down Expand Up @@ -57,15 +58,17 @@ def linear_experiment(config: ExperimentConfig) -> ExperimentResult:
config.device,
)
qlora_weight = NF4Tensor.from_tensor(input_weight.clone())
bnb_linear = qlora.build_bitsandbytes_linear(input_weight, config.device)
compiled_qlora_linear = torch.compile(qlora.linear_nf4, fullgraph=True, dynamic=config.dynamic)
if bnb_available:
bnb_linear = qlora.build_bitsandbytes_linear(input_weight, config.device)

# warmup
for _ in range(3):
F.linear(sample_input, input_weight)
qlora.linear_nf4(sample_input, qlora_weight)
compiled_qlora_linear(sample_input, qlora_weight)
bnb_linear(sample_input)
if bnb_available:
bnb_linear(sample_input)

linear_time = nugs.utils.benchmark_torch_function_in_microseconds(
F.linear, sample_input, input_weight
Expand All @@ -76,7 +79,12 @@ def linear_experiment(config: ExperimentConfig) -> ExperimentResult:
compiled_qlora_linear_time = nugs.utils.benchmark_torch_function_in_microseconds(
compiled_qlora_linear, sample_input, qlora_weight
)
bnb_linear_time = nugs.utils.benchmark_torch_function_in_microseconds(bnb_linear, sample_input)
if bnb_available:
bnb_linear_time = nugs.utils.benchmark_torch_function_in_microseconds(
bnb_linear, sample_input
)
else:
bnb_linear_time = -1.0

return ExperimentResult(
linear_time, qlora_linear_time, compiled_qlora_linear_time, bnb_linear_time
Expand All @@ -94,21 +102,26 @@ def mlp_experiment(config: ExperimentConfig) -> ExperimentResult:
mlp = qlora.MLP(*weights)
nf4_mlp = qlora.NF4MLP(*weights)
compiled_qlora_mlp = torch.compile(nf4_mlp, fullgraph=True, dynamic=config.dynamic)
bnb_mlp = qlora.BnbQloraMLP(*weights, config.device)
if bnb_available:
bnb_mlp = qlora.BnbQloraMLP(*weights, config.device)

# warmup
for _ in range(3):
mlp(sample_input)
nf4_mlp(sample_input)
compiled_qlora_mlp(sample_input)
bnb_mlp(sample_input)
if bnb_available:
bnb_mlp(sample_input)

mlp_time = nugs.utils.benchmark_torch_function_in_microseconds(mlp, sample_input)
qlora_mlp_time = nugs.utils.benchmark_torch_function_in_microseconds(nf4_mlp, sample_input)
compiled_qlora_mlp_time = nugs.utils.benchmark_torch_function_in_microseconds(
compiled_qlora_mlp, sample_input
)
bnb_mlp_time = nugs.utils.benchmark_torch_function_in_microseconds(bnb_mlp, sample_input)
if bnb_available:
bnb_mlp_time = nugs.utils.benchmark_torch_function_in_microseconds(bnb_mlp, sample_input)
else:
bnb_mlp_time = -1.0

return ExperimentResult(mlp_time, qlora_mlp_time, compiled_qlora_mlp_time, bnb_mlp_time)

Expand Down Expand Up @@ -137,22 +150,34 @@ def gen_configs() -> List[ExperimentConfig]:


def main(output_path: Optional[Path], profile_path: Optional[Path], dynamic: bool):
"""Run experiments and output results to file
Args:
output_path (Optional[Path]): Path to write out CSV file for experiment results.
profile_path (Optional[Path]): Path to write out json chrome trace file for an experiment.
dynamic (bool): Compile with Dynamic shapes
"""

results = []
for experiment_config in tqdm(gen_configs()):
# Since we are changing between dynamic and not
import torch._dynamo # noqa: F402

torch._dynamo.reset()
experiment = experiment_types[experiment_config.op]
experiment_result = experiment(experiment_config)
merged = asdict(experiment_config) | asdict(experiment_result)
results.append(merged)

if output_path is not None:
results = []
for experiment_config in tqdm(gen_configs()):
# Since we are changing between dynamic and not
import torch._dynamo # noqa: F402

torch._dynamo.reset()
experiment = experiment_types[experiment_config.op]
experiment_result = experiment(experiment_config)
merged = asdict(experiment_config) | asdict(experiment_result)
results.append(merged)

with open(output_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=results[0].keys())
writer.writeheader()
writer.writerows(results)
with open(output_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=results[0].keys())
writer.writeheader()
writer.writerows(results)
else:
headers = results[0].keys()
rows = [list(r.values()) for r in results]
print(tabulate(rows, headers=headers))

if profile_path is not None:
profile_experiment = ExperimentConfig(4096, 8, 128, torch.device("cuda:0"), "mlp", dynamic)
Expand All @@ -169,7 +194,7 @@ def main(output_path: Optional[Path], profile_path: Optional[Path], dynamic: boo

qlora_mlp = qlora.NF4MLP(*weights)
compiled_qlora_mlp = torch.compile(qlora_mlp, fullgraph=True, dynamic=dynamic)
print("dynamic = ", dynamic)
logging.info("Running torch.compile with dynamic = %s", dynamic)
profile_config = nugs.utils.ProfileConfig(
str(profile_path), "qlora_mlp", iters=5, warmup_iters=3, sync=True
)
Expand All @@ -183,34 +208,6 @@ def main(output_path: Optional[Path], profile_path: Optional[Path], dynamic: boo
if __name__ == "__main__":
"""Sample usage:
# Running sweep
python benchmarks/qlora.py -o benchmarks/data/qlora_sweep.csv
python benchmarks/qlora.py -p benchmarks/data/4096_8_128_qlora.json
python benchmarks/qlora.py false --output_path benchmarks/data/qlora_sweep.csv
"""
parser = argparse.ArgumentParser(description="Run experiments and output results to file")
parser.add_argument(
"-o",
"--output_file",
type=str,
help="Path to write out CSV file for experiment results.",
default=None,
)
parser.add_argument(
"-p",
"--profile_path",
type=str,
help="Path to write out json chrome trace file for an experiment.",
default=None,
)
parser.add_argument(
"--dynamic_shapes", action="store_true", help="Compile with Dynamic shapes"
)

args = parser.parse_args()
output_path = None
profile_path = None
if args.output_file is not None:
output_path = Path(args.output_file)
if args.profile_path is not None:
profile_path = Path(args.profile_path)

main(output_path, profile_path, args.dynamic_shapes)
CLI(main)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ dev = [
"pytest",
"flake8==6.1.0",
"flake8-pyproject",
"jsonargparse",
"docstring-parser"
]

qlora = ['bitsandbytes']
Expand Down
6 changes: 4 additions & 2 deletions test/test_qlora.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import torch.nn.functional as F

import transformer_nuggets.quant.qlora as qlora
from transformer_nuggets.quant import linear_nf4, NF4Tensor
from transformer_nuggets.quant import linear_nf4
from transformer_nuggets.quant.nf4_tensor import NF4Tensor
from transformer_nuggets.quant.qlora_debug import NF4TensorDebug

bnb_available = False
Expand Down Expand Up @@ -91,8 +92,9 @@ def test_binning_distribution(embed_dim: int):
@pytest.mark.parametrize("embed_dim", [256, 4096, 5120, 6656, 8192])
@pytest.mark.parametrize("compile", [True, False])
@pytest.mark.parametrize("requires_grad", [True, False])
@pytest.mark.xfail(reason="TORCH COMPILE No longer works here")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_autograd_func_to_eager(embed_dim: int, compile: bool, requires_grad: bool):
torch._dynamo.reset()
torch.manual_seed(0)
device = "cuda"
input_weight = qlora.build_input_weight(embed_dim, device)
Expand Down
3 changes: 2 additions & 1 deletion transformer_nuggets/quant/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from transformer_nuggets.quant.qlora import get_block_absmax, linear_nf4, NF4Tensor
from transformer_nuggets.quant.nf4_tensor import get_block_absmax, NF4Tensor
from transformer_nuggets.quant.qlora import linear_nf4
from transformer_nuggets.quant.qlora_debug import NF4TensorDebug
Loading

0 comments on commit 7fae31b

Please sign in to comment.