[sharktank] restore custom matmul kernel #896

Merged: 35 commits merged into main from users/dan-garvey/enable_custom_fp8_matmul on Feb 21, 2025

Conversation

@dan-garvey (Member) commented on Feb 2, 2025

Fixes iree-org/iree#19859 (comment)

```
python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/sharedfile/llama3_8b_fp8.irpa \
  --output-mlir=fp8_dan1.mlir \
  --output-config=config1.json \
  --bs=1 --attention-kernel torch \
  --attention-dtype=float8_e4m3fnuz --activation-dtype=bfloat16
```

@AmosLewis commented:

I cannot reproduce the numeric result with this patch; here is the log: danfix.log

This argument can be skipped; the result dtype is then inferred from the arguments.

This requires iree-org/iree-turbine#451
@sogartar (Contributor) commented on Feb 4, 2025

I added some modifications that require merging iree-org/iree-turbine#451 first.

I made the batch_matmul_transpose_b kernel accept an optional accumulation dtype.
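For illustration, here is a minimal eager sketch of what an optional accumulation dtype means for a batched matmul with a transposed RHS. This is not the sharktank kernel itself; the function name and signature are assumptions for the example.

```python
# Minimal eager sketch (assumed signature, not the actual sharktank kernel).
from typing import Optional

import torch


def batch_matmul_transpose_b_sketch(
    lhs: torch.Tensor, rhs: torch.Tensor, accum_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    # When accum_dtype is omitted, the result dtype is inferred from the arguments.
    if accum_dtype is not None:
        lhs = lhs.to(accum_dtype)
        rhs = rhs.to(accum_dtype)
    # Batched matmul against the transposed last two dims of rhs.
    return torch.matmul(lhs, rhs.transpose(-2, -1))


# Example: accumulate bfloat16 operands in float32.
a = torch.randn(2, 4, 8, dtype=torch.bfloat16)
b = torch.randn(2, 16, 8, dtype=torch.bfloat16)
c = batch_matmul_transpose_b_sketch(a, b, accum_dtype=torch.float32)  # float32 result
```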

@sogartar (Contributor) commented on Feb 5, 2025

This change needs an IREE Turbine bump #916.

@dan-garvey (Member, Author) commented:

Are you still working on this, @sogartar?

@sogartar (Contributor) commented on Feb 8, 2025

@dan-garvey, the kernel is ready. I also cleaned up linear and qlinear.
Should this kernel also be used in sharktank.ops.matmul? I have not added it there. The only place where this seems useful is when we want to control the accumulation dtype. I assume that transposing b and then doing a matmul would be fused by the compiler.

@sogartar (Contributor) commented on Feb 8, 2025

I did not notice this failing unit test; I will see what the problem is.
I added one xfail'd batch_matmul_transpose_b kernel test that can't be compiled for llvm-cpu. It does compile for amdgpu.
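For context, an xfail marker of that kind typically looks like the following; the test name and reason string are illustrative, not the actual test added in this PR.

```python
# Hypothetical sketch of an xfail'd kernel test; names and reason are illustrative.
import pytest


@pytest.mark.xfail(
    reason="batch_matmul_transpose_b kernel currently fails to compile for llvm-cpu"
)
def test_batch_matmul_transpose_b_llvm_cpu():
    # Would compile and run the kernel through IREE for the llvm-cpu target;
    # the same check passes when targeting amdgpu.
    ...
```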

@sogartar (Contributor) commented on Feb 8, 2025

The kernel does not work with unsigned types; it needs to be adapted.
The problem is that we reinterpret-cast to signless integer types and lose this information by the time we get to linalg.batch_matmul_transpose_b. Maybe linalg.batch_matmul_transpose_b assumes a signed type when promoting from i8 to i32, even though i8 is signless.

I made an exception in qlinear: if there are unsigned integer types, they are first promoted to the accumulation dtype.
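A purely illustrative sketch of that workaround, under the assumption that the promotion happens at the tensor level before the kernel is called (this is not the actual qlinear code):

```python
# Illustrative only: widen unsigned-integer operands to the accumulation dtype up front,
# so the kernel never has to reinterpret them as signless integers.
import torch

_UNSIGNED_DTYPES = {torch.uint8}


def promote_unsigned(t: torch.Tensor, accum_dtype: torch.dtype) -> torch.Tensor:
    return t.to(accum_dtype) if t.dtype in _UNSIGNED_DTYPES else t
```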

@sogartar (Contributor) commented on Feb 8, 2025

@dan-garvey, let me know if it looks good. I will approve the PR, since you probably can't approve your own PR.

@AmosLewis left a comment:

@dan-garvey Before this is merged, we need this issue fixed. I got an error when I tried to export the input bin for the f8 model; the error traces back to batch_matmul_transpose_b.py: paged_llm_v_export_f8_input_bin_bug.log. @aviator19941, we need your review as well to ensure paged_llm_v1.py still works.

@sogartar (Contributor) commented:

@AmosLewis, we get this error when compiling for llvm-cpu. It compiles for amdgpu. I guess we should add an eager implementation for the kernel.

@sogartar (Contributor) commented:

I will implement the eager path so it does not use IREE.

@AmosLewis commented on Feb 11, 2025

> @AmosLewis, we get this error when compiling for llvm-cpu. It compiles for amdgpu. I guess we should add an eager implementation for the kernel.

I get a similar error in paged_llama_attention_block.py after rebasing onto the eager commit: https://gist.github.com/AmosLewis/3860a0371236b528b24c86c48e1e31c2?permalink_comment_id=5433466#gistcomment-5433466

@sogartar (Contributor) commented:

I will run the model locally and debug this problem.

@archana-ramalingam (Collaborator) commented:

@sogartar, git commit 5bf4636 (fixing .index_copy_) breaks iree-compile for fp8 8b.
Error: https://gist.github.com/archana-ramalingam/7347234893c5eba773d826283d5a635d

@sogartar (Contributor) commented:

My bad. I will fix it.

@sogartar (Contributor) commented:

Here is an IREE issue for the compilation failure iree-org/iree#20002.

Comment on lines 278 to 280:

```python
page_table.index_put_(
    indices=indices, values=ops.to(cache_partition, dtype=page_table.dtype)
)
```
@IanWood1 commented on Feb 19, 2025:

I'm seeing an element-wise conversion from f8 to int8 before writing to the cache. I think you need view(dtype) instead of to().
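The distinction is easy to see with an ordinary same-width dtype pair (float32/int32 is used here only because eager float8 support varies across torch builds): to() converts values numerically, while view(dtype) reinterprets the raw bytes, which is what storing f8 payloads in the page table needs.

```python
# to() vs. view(dtype), shown with float32/int32 (same byte width) as a stand-in for f8/int8.
import torch

x = torch.tensor([1.5, -2.0], dtype=torch.float32)
numeric = x.to(torch.int32)   # element-wise conversion: tensor([ 1, -2], dtype=torch.int32)
bits = x.view(torch.int32)    # bit reinterpretation: the raw IEEE-754 encodings as int32
```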

Reply from a Contributor:

Thank you, I noticed this problem. I am preparing a fix, but there are other numerical issues that also need to be resolved.

Reply from a Collaborator:

I'm seeing this compilation error now with those changes:

```
fp8_dan_attn_q_bf16_f32_v4.mlir:488:12: error: 'flow.tensor.bitcast' op value set has 1 dynamic dimensions but only 0 dimension values are attached
    %178 = torch.aten.view.dtype %168, %int1_151 : !torch.vtensor<[?,32,8,128],f8E4M3FNUZ>, !torch.int -> !torch.vtensor<[?,32,8,128],si8>
           ^
fp8_dan_attn_q_bf16_f32_v4.mlir:488:12: note: see current operation: %244 = "flow.tensor.bitcast"(%243) <{operandSegmentSizes = array<i32: 1, 0, 0>}> : (tensor<?x32x8x128xf8E4M3FNUZ>) -> tensor<?x32x8x128xi8>
```

Reply:

This was recently fixed in IREE. You need an IREE build with iree-org/iree@d38b22e (merged yesterday).

Reply from a Collaborator:

ah ok cool thanks!

Comment on this hf_datasets entry:

```python
extra_filenames=["tokenizer.json"],
),
),
).alias_to("llama3_8B_fp16")
```
A Collaborator commented:

This model cannot be removed, as it is part of the llama_serving.md release docs that were tested in sharktank/shortfin. However, I agree that we only need one of the two llama 8b fp16 models listed here.

Reply from a Member:

The docs could be updated to use standard huggingface tooling instead of sharktank.utils.hf_datasets. As written today, the hf_datasets file should only be a utility for project development and testing, not something user-facing.
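For instance, the docs could fetch the tokenizer with huggingface_hub directly; the repo id below is only an illustrative placeholder, not necessarily the one used in llama_serving.md.

```python
# Hedged sketch: fetch a single file with standard Hugging Face tooling instead of
# sharktank.utils.hf_datasets. The repo id is an assumed placeholder.
from huggingface_hub import hf_hub_download

tokenizer_path = hf_hub_download(
    repo_id="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative repo, adjust as needed
    filename="tokenizer.json",
)
print(tokenizer_path)
```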

Reply from a Contributor:

I fixed the doc.

@archana-ramalingam (Collaborator) commented on Feb 20, 2025:

llama3_8B_fp16 is the instruct version that was tested to be numerically correct for the release; the other model is the non-instruct version, and I remember it generating repetitive tokens.
This can be a separate PR, where we consult with the shortfin folks on whether this model switch in the release docs can be made.

@sogartar (Contributor) commented:

@archana-ramalingam, although I technically reviewed the PR, could you review it, since I am now the one making the changes?

@sogartar merged commit 6397ead into main on Feb 21, 2025.
36 checks passed.
@sogartar deleted the users/dan-garvey/enable_custom_fp8_matmul branch on February 21, 2025 at 21:05.
IanNod pushed a commit that referenced this pull request Feb 22, 2025
Make the quantized linear op use the batch_matmul_transpose_b kernel.

Make the batch_matmul_transpose_b kernel accept an accum_dtype argument that specifies the precision in which the operation is performed.

Make LLM models accept a kv-cache-dtype argument that specifies the dtype of the KV cache only, which can differ from the attention dtype. This config option is used for Llama f8.

For Llama f8 the IREE change iree-org/iree#20005 is required.

Fixes iree-org/iree#19859 (comment)

```
python3 -m sharktank.examples.export_paged_llm_v1 \
--irpa-file=/sharedfile/llama3_8b_fp8.irpa \
--output-mlir=fp8_dan1.mlir \
--output-config=config1.json \
--bs=1 \
--attention-kernel torch \
--attention-dtype=bfloat16  \
--activation-dtype=bfloat16 \
--kv-cache-dtype=float8_e4m3fnuz
```

---------

Signed-off-by: Boian Petkantchin <boian.petkantchin@amd.com>
Co-authored-by: aviator19941 <avinash.sharma@amd.com>
Co-authored-by: archana-ramalingam <archana.ramalingam@amd.com>

Successfully merging this pull request may close these issues.

Numeric issue for llama_8b_fp8 model on hip
7 participants