Skip to content

Commit

Permalink
Update benchmarks
Browse files Browse the repository at this point in the history
stack-info: PR: #39, branch: drisspg/stack/2
  • Loading branch information
drisspg committed Jan 17, 2025
1 parent a4c66bb commit 2c7e59d
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 183 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
Expand Down
119 changes: 82 additions & 37 deletions benchmarks/fp8_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,26 @@
preprocess_data,
Float8MMConfig,
)
from transformer_nuggets.fp8.fp8_matmul import (
matmul_persistent,
matmul_tma_persistent,
matmul_device_tma_persistent,
)

try:
from transformer_nuggets.fp8.fp8_matmul import (
matmul_persistent,
matmul_tma_persistent,
matmul_device_tma_persistent,
)
except ModuleNotFoundError:
print("Triton version not new enough")
pass

from datetime import datetime
from enum import Enum
import csv
import logging

torch._dynamo.config.cache_size_limit = 1000
torch._dynamo.config.cache_size_limit = 10000
logging.getLogger("transformer_nuggets").setLevel(logging.INFO)
torch._inductor.config.max_autotune_gemm_backends = "TRITON"
CHECK = False


class FP8Kernel(Enum):
Expand Down Expand Up @@ -80,13 +90,14 @@ class ExperimentConfig:
scaling_strategy: ScalingStrategy
fp8_kernel: FP8Kernel
compile: bool
bf16: bool


@dataclass(frozen=True)
class ExperimentResult:
bf16_time: float
bf16_time: Optional[float]
fp8_time: float
bf16_tflops: float
bf16_tflops: Optional[float]
fp8_tflops: float


Expand All @@ -113,29 +124,34 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:

if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM:
bf16_matmul = torch.compile(bf16_matmul)
fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune")
fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs", dynamic=False)

# Warmup phase
warmup_iterations = 5
for _ in range(warmup_iterations):
_ = bf16_matmul(A, B)
if config.bf16:
_ = bf16_matmul(A, B)
_ = fp8_matmul()
torch.cuda.synchronize()

# Actual benchmarking
bf16_time = benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B))

bf16_time = (
benchmark_cuda_function_in_microseconds(lambda: bf16_matmul(A, B)) if config.bf16 else None
)
fp8_time = benchmark_cuda_function_in_microseconds(fp8_matmul)

# Calculate TFLOPS
bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time)
bf16_tflops = calculate_tflops(config.M, config.N, config.K, bf16_time) if bf16_time else None
fp8_tflops = calculate_tflops(config.M, config.N, config.K, fp8_time)

# Baseline fp8_matmul correctness
scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM)
out_base = scaled_mm_base()
out = fp8_matmul()
# Failing on one sample with large N
torch.testing.assert_close(out, out_base)
if CHECK:
scaled_mm_base = get_fp8_matmul(A, B, config.scaling_strategy, FP8Kernel.SCALED_MM)
out_base = scaled_mm_base()
out = fp8_matmul()
# Failing on one sample with large N
torch.testing.assert_close(out, out_base)

return ExperimentResult(
bf16_time=bf16_time, fp8_time=fp8_time, bf16_tflops=bf16_tflops, fp8_tflops=fp8_tflops
Expand All @@ -161,24 +177,38 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = Non
for experiment in experiments:
config = experiment.config
result = experiment.result
speedup = result.bf16_time / result.fp8_time
tflops_ratio = result.fp8_tflops / result.bf16_tflops

# Format values handling None cases
bf16_time = f"{result.bf16_time:.4f}" if result.bf16_time is not None else "N/A"
fp8_time = f"{result.fp8_time:.4f}"
bf16_tflops = f"{result.bf16_tflops:.2f}" if result.bf16_tflops is not None else "N/A"
fp8_tflops = f"{result.fp8_tflops:.2f}"

# Calculate ratios only if bf16 results exist
if result.bf16_time is not None:
speedup = f"{(result.bf16_time / result.fp8_time):.2f}x"
tflops_ratio = f"{(result.fp8_tflops / result.bf16_tflops):.2f}x"
else:
speedup = "N/A"
tflops_ratio = "N/A"

rows.append(
[
config.M,
config.K,
config.N,
config.scaling_strategy,
config.fp8_kernel,
config.scaling_strategy.value,
config.fp8_kernel.value,
config.compile,
f"{result.bf16_time:.4f}",
f"{result.fp8_time:.4f}",
f"{speedup:.2f}x",
f"{result.bf16_tflops:.2f}",
f"{result.fp8_tflops:.2f}",
f"{tflops_ratio:.2f}x",
bf16_time,
fp8_time,
speedup,
bf16_tflops,
fp8_tflops,
tflops_ratio,
]
)

print(tabulate(rows, headers=headers, floatfmt=".4f"))

if save_path is not None:
Expand All @@ -189,33 +219,41 @@ def print_results(experiments: List[Experiment], save_path: Optional[Path] = Non
print(f"💾 Results saved to: {save_path}")


def get_configs_varying_k(M: int = 8192, N: int = 8192) -> List[ExperimentConfig]:
shapes = [(M, K, N) for K in range(512, 8193, 512)]
def get_configs_varying_k(
M: int = 8192, N: int = 8192, bf16: bool = False
) -> List[ExperimentConfig]:
shapes = [(M, K, N) for K in range(1024, 16385, 1024)]
scaling_strategies = [ScalingStrategy.PER_ROW]
compile_options = [False]
compile_options = [True, False]
configs = []
fp8_kernels = [
FP8Kernel.SCALED_MM,
# FP8Kernel.PERSISTENT,
FP8Kernel.PERSISTENT_TMA,
FP8Kernel.DEVICE_TMA,
# FP8Kernel.PERSISTENT_TMA,
# FP8Kernel.DEVICE_TMA,
]

for (M, K, N), strategy, compile, kernel in itertools.product(
shapes, scaling_strategies, compile_options, fp8_kernels
):
configs.append(
ExperimentConfig(
M=M, K=K, N=N, scaling_strategy=strategy, compile=compile, fp8_kernel=kernel
M=M,
K=K,
N=N,
scaling_strategy=strategy,
compile=compile,
fp8_kernel=kernel,
bf16=bf16,
)
)
return configs


def load_and_process_data(file_path):
df = pd.read_csv(file_path)
df["Speedup"] = df["Speedup"].str.rstrip("x").astype(float)
df["TFLOPS Ratio"] = df["TFLOPS Ratio"].str.rstrip("x").astype(float)
# df["Speedup"] = df["Speedup"].str.rstrip("x").astype(float)
# df["TFLOPS Ratio"] = df["TFLOPS Ratio"].str.rstrip("x").astype(float)
return df


Expand Down Expand Up @@ -250,17 +288,24 @@ def plot_tflops_comparison(df, save_path: Path):
print(f"TFLOPS comparison plot saved as {graph_path}")


def main(save_path: Optional[str] = None, M: int = 8192, N: int = 8192, graph: bool = False):
def main(
save_path: Optional[str] = None,
M: int = 8192,
N: int = 8192,
graph: bool = False,
bf_16: bool = False,
):
"""Benchmark FP8 MatMul with different configurations and optionally graph results.
Args:
save_path (Optional[str], optional): Path to save the results. Defaults to None.
M (int, optional): Number of rows in the first matrix. Defaults to 8192.
N (int, optional): Number of columns in the second matrix. Defaults to 8192.
graph_results (bool, optional): Whether to create a graph of the results. Defaults to False.
bf_16 (bool, optional): Whether to use BF16 for the baseline. Defaults to False.
"""
torch.random.manual_seed(123)
configs = get_configs_varying_k(M, N)
configs = get_configs_varying_k(M, N, bf16=bf_16)
results = []
if save_path is not None:
save_path = Path(save_path)
Expand Down
18 changes: 9 additions & 9 deletions benchmarks/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,15 @@ def forward(
block_size = self.config.block_size
if max_seq_length is None:
max_seq_length = block_size
assert (
T <= max_seq_length
), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
assert (
max_seq_length <= block_size
), f"Cannot attend to {max_seq_length}, block size is only {block_size}"
assert (
T <= block_size
), f"Cannot forward sequence of length {T}, block size is only {block_size}"
assert T <= max_seq_length, (
f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
)
assert max_seq_length <= block_size, (
f"Cannot attend to {max_seq_length}, block size is only {block_size}"
)
assert T <= block_size, (
f"Cannot forward sequence of length {T}, block size is only {block_size}"
)

if self.rope_cache is None:
self.rope_cache = self.build_rope_cache(idx)
Expand Down
18 changes: 18 additions & 0 deletions transformer_nuggets/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,19 @@
from transformer_nuggets import quant as quant, utils as utils


def init_logging():
"""
Configure logging for transformer_nuggets library at INFO level.
Adds a StreamHandler if none exists.
"""
import logging

logger = logging.getLogger("transformer_nuggets")
logger.setLevel(logging.INFO)

if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
6 changes: 3 additions & 3 deletions transformer_nuggets/flex/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def visualize_attention_scores(
Returns:
None
"""
assert (
score_mod is not None or mask_mod is not None
), "Must provide either score_mod or mask_mod"
assert score_mod is not None or mask_mod is not None, (
"Must provide either score_mod or mask_mod"
)
query = query[batch_idx, head_idx, :, :]
key = key[batch_idx, head_idx, :, :]
scores_viz = create_score_mod(
Expand Down
24 changes: 12 additions & 12 deletions transformer_nuggets/fp8/fp8_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,21 @@ def validate_matmul_inputs(
if ROW_WISE_SCALING:
assert a_scale.dim() == 2, f"a_scale must be a 2D tensor but got {a_scale.dim()}"
assert b_scale.dim() == 2, f"b_scale must be a 2D tensor but got {b_scale.dim()}"
assert (
a_scale.shape[0] == a.shape[0]
), f"a_scale must have same number of rows as a, got {a_scale.shape[0]} vs {a.shape[0]}"
assert a_scale.shape[0] == a.shape[0], (
f"a_scale must have same number of rows as a, got {a_scale.shape[0]} vs {a.shape[0]}"
)
assert a_scale.shape[1] == 1, f"a_scale must have 1 column, got {a_scale.shape[1]}"
assert (
b_scale.shape[1] == b.shape[1]
), f"b_scale must have same number of columns as b, got {b_scale.shape[0]} vs {b.shape[1]}"
assert b_scale.shape[1] == b.shape[1], (
f"b_scale must have same number of columns as b, got {b_scale.shape[0]} vs {b.shape[1]}"
)
assert b_scale.shape[0] == 1, f"b_scale must have 1 column, got {b_scale.shape[1]}"
else:
assert (
a_scale.numel() == 1
), f"a_scale must be a scalar for per-tensor scaling, got shape {a_scale.shape}"
assert (
b_scale.numel() == 1
), f"b_scale must be a scalar for per-tensor scaling, got shape {b_scale.shape}"
assert a_scale.numel() == 1, (
f"a_scale must be a scalar for per-tensor scaling, got shape {a_scale.shape}"
)
assert b_scale.numel() == 1, (
f"b_scale must be a scalar for per-tensor scaling, got shape {b_scale.shape}"
)

return ROW_WISE_SCALING

Expand Down
6 changes: 3 additions & 3 deletions transformer_nuggets/llama/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,9 +401,9 @@ def entrypoint(
overfit: bool = False,
profile: bool = False,
):
assert (
isinstance(fp8_linear_type, str) or fp8_linear_type is None
), "fp8_linear_type must be str"
assert isinstance(fp8_linear_type, str) or fp8_linear_type is None, (
"fp8_linear_type must be str"
)
assert isinstance(compile, bool), "compile must be bool"
assert isinstance(overfit, bool), "overfit must be bool"
assert isinstance(profile, bool), "profile must be bool"
Expand Down
6 changes: 3 additions & 3 deletions transformer_nuggets/quant/dequant_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def dequant_nf4_tensor_kernel(
def dequant_nf4_tensor(weight: NF4Tensor):
"""Takes a quantized tensor and dequantizes it to bfloat16"""
assert isinstance(weight, NF4Tensor), "Input tensor must be of type NF4Tensor"
assert (
weight.shape.numel() % weight.block_size == 0
), "Input tensor must be a multiple of block size"
assert weight.shape.numel() % weight.block_size == 0, (
"Input tensor must be a multiple of block size"
)
out_tensor = torch.empty(weight.shape, dtype=weight.dtype, device="cuda")
numel = weight.shape.numel()
grid = (triton.cdiv(numel, (weight.block_size)),)
Expand Down
Loading

0 comments on commit 2c7e59d

Please sign in to comment.