[sharktank] restore custom matmul kernel #896
Conversation
Cannot reproduce the numeric result with this patch; here is the log: danfix.log
This argument can be skipped; the result dtype is then inferred from the arguments. This requires iree-org/iree-turbine#451.
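A rough sketch of what that inference could look like, assuming it simply falls back to PyTorch type promotion (`infer_accum_dtype` is a hypothetical helper name, not the actual iree-turbine code):

```python
import torch

def infer_accum_dtype(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.dtype:
    # Hypothetical fallback: when no accum_dtype is passed, accumulate in the
    # dtype that both operand dtypes promote to.
    return torch.promote_types(lhs.dtype, rhs.dtype)

lhs = torch.randn(4, 8, 16, dtype=torch.bfloat16)
rhs = torch.randn(4, 32, 16, dtype=torch.float32)
print(infer_accum_dtype(lhs, rhs))  # torch.float32
```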
I added some modifications that require merging iree-org/iree-turbine#451 first. I made the
This change needs an IREE Turbine bump #916.
Are you still working on this, @sogartar?
@dan-garvey the kernel is ready. I also cleaned up linear and qlinear.
I did not notice this failing unit test. I will see what the problem is.
The kernel does not work with unsigned types and needs to be adapted. I made an exception in qlinear: if there are unsigned integer types, they are first promoted to the accumulation dtype.
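The exception described above could look roughly like the following; `maybe_promote_unsigned` is an illustrative name and the snippet only sketches the guard, not the actual qlinear code:

```python
import torch

def maybe_promote_unsigned(x: torch.Tensor, accum_dtype: torch.dtype) -> torch.Tensor:
    # Sketch of the guard: the custom kernel cannot handle unsigned integer
    # inputs, so those are promoted to the accumulation dtype up front and the
    # matmul then runs in that precision instead of going through the kernel.
    if not x.dtype.is_floating_point and not x.dtype.is_signed:
        return x.to(accum_dtype)
    return x

print(maybe_promote_unsigned(torch.ones(2, 3, dtype=torch.uint8), torch.int32).dtype)  # torch.int32
```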
@dan-garvey let me know if it looks good. I will approve the PR, as you probably can't approve your own PR.
@dan-garvey Before this is merged, we need this issue fixed. I got an error when I tried to export the input bin for the f8 model; the error traces back to batch_matmul_transpose_b.py.
paged_llm_v_export_f8_input_bin_bug.log
@aviator19941 we need your review as well, to ensure paged_llm_v1.py still works.
@AmosLewis we get this error when compiling for llvm-cpu; it compiles for amdgpu. I guess we should add an eager implementation for the kernel.
I will implement the eager path so that it does not use IREE.
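For context, an eager path for the kernel would essentially be the plain PyTorch contraction below; the function name and signature are illustrative, not the actual sharktank implementation:

```python
import torch

def batch_matmul_transpose_b_eager(
    lhs: torch.Tensor,                        # [B, M, K]
    rhs: torch.Tensor,                        # [B, N, K]
    accum_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # Cast both operands to the accumulation dtype, then contract against the
    # transposed RHS: out[b, m, n] = sum_k lhs[b, m, k] * rhs[b, n, k].
    return torch.matmul(lhs.to(accum_dtype), rhs.to(accum_dtype).transpose(-1, -2))

out = batch_matmul_transpose_b_eager(
    torch.randn(2, 4, 8, dtype=torch.bfloat16),
    torch.randn(2, 6, 8, dtype=torch.bfloat16),
)
print(out.shape)  # torch.Size([2, 4, 6])
```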
I get a similar error in paged_llama_attention_block.py after rebasing on the eager commit: https://gist.github.com/AmosLewis/3860a0371236b528b24c86c48e1e31c2?permalink_comment_id=5433466#gistcomment-5433466
I will run the model locally and debug this problem.
@sogartar git commit 5bf4636 needs fixing.
My bad. I will fix it.
Here is an IREE issue for the compilation failure: iree-org/iree#20002.
```python
page_table.index_put_(
    indices=indices, values=ops.to(cache_partition, dtype=page_table.dtype)
)
```
I'm seeing an element-wise conversion from f8 to int8 before writing to the cache. I think you need view(dtype) instead of to().
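For reference, a standalone illustration of that distinction; it uses float16/int16 rather than f8/int8 so it runs on any recent PyTorch build, but the same reasoning applies to the cache partition above:

```python
import torch

x = torch.tensor([1.0, -2.5], dtype=torch.float16)

# to() converts values: each number is numerically cast to the target dtype.
converted = x.to(torch.int16)   # tensor([ 1, -2], dtype=torch.int16)

# view() reinterprets the raw bytes: the int16 tensor holds the original
# float16 bit patterns (e.g. 1.0 -> 0x3C00), which is what a byte-typed
# cache slot actually expects.
bitcast = x.view(torch.int16)

print(converted)
print(bitcast)
```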
Thank you, I noticed this problem. I am preparing a fix, but there are other numerical issues that also need to be resolved.
I'm seeing this compilation error now with those changes:
```
fp8_dan_attn_q_bf16_f32_v4.mlir:488:12: error: 'flow.tensor.bitcast' op value set has 1 dynamic dimensions but only 0 dimension values are attached
%178 = torch.aten.view.dtype %168, %int1_151 : !torch.vtensor<[?,32,8,128],f8E4M3FNUZ>, !torch.int -> !torch.vtensor<[?,32,8,128],si8>
       ^
fp8_dan_attn_q_bf16_f32_v4.mlir:488:12: note: see current operation: %244 = "flow.tensor.bitcast"(%243) <{operandSegmentSizes = array<i32: 1, 0, 0>}> : (tensor<?x32x8x128xf8E4M3FNUZ>) -> tensor<?x32x8x128xi8>
```
This was recently fixed in IREE. You need an IREE build with iree-org/iree@d38b22e (merged yesterday).
ah ok cool thanks!
```python
extra_filenames=["tokenizer.json"],
),
),
).alias_to("llama3_8B_fp16")
```
This model cannot be removed, as it is part of the llama_serving.md release docs, which were tested in sharktank/shortfin. However, I agree that we only need one of the two llama 8B fp16 models listed here.
The docs could be updated to use standard huggingface tooling instead of sharktank.utils.hf_datasets. As written today, the hf_datasets file should only be a utility for project development and testing, not something user-facing.
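For example, the docs could point users at huggingface_hub directly; the repo id and filenames below are placeholders, not the real dataset entries:

```python
from huggingface_hub import hf_hub_download

# Placeholder repo id and filenames -- substitute the model the docs actually
# reference instead of going through sharktank.utils.hf_datasets.
repo_id = "SomeOrg/Meta-Llama-3-8B-Instruct-GGUF"
weights_path = hf_hub_download(repo_id=repo_id, filename="model.gguf")
tokenizer_path = hf_hub_download(repo_id=repo_id, filename="tokenizer.json")
print(weights_path, tokenizer_path)
```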
I fixed the doc.
llama3_8B_fp16 is the instruct version that was tested to be numerically correct for the release; the other model is the non-instruct version, and I remember it generating repetitive tokens.
This can be a separate PR, where we consult with the shortfin folks on whether this model switch can be made in the release docs.
@archana-ramalingam, although I technically reviewed the PR, could you review it, as I am now the one making changes?
Make the quantized linear op use the batch_matmul_transpose_b kernel. Make the batch_matmul_transpose_b kernel accept an accum_dtype argument that specifies the precision in which the operation is performed. Make LLM models accept a kv-cache-dtype argument that specifies only the dtype of the KV cache, which can differ from the attention dtype. This config option is used in Llama f8. For Llama f8 the IREE change iree-org/iree#20005 is required.

Fixes iree-org/iree#19859 (comment)

```
python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/sharedfile/llama3_8b_fp8.irpa \
  --output-mlir=fp8_dan1.mlir \
  --output-config=config1.json \
  --bs=1 \
  --attention-kernel torch \
  --attention-dtype=bfloat16 \
  --activation-dtype=bfloat16 \
  --kv-cache-dtype=float8_e4m3fnuz
```

---------

Signed-off-by: Boian Petkantchin <boian.petkantchin@amd.com>
Co-authored-by: aviator19941 <avinash.sharma@amd.com>
Co-authored-by: archana-ramalingam <archana.ramalingam@amd.com>