Skip to content

Commit

Permalink
fix norm for nullptr crash on iGPU
Browse files (browse the repository at this point in the history)
  • Loading branch information
Neo Zhang authored and committed on Jul 9, 2024
1 parent 3226bc1 commit 2551774
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions ggml/src/ggml-sycl/norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,15 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
+ sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
+ sycl::range<1>(32), cgh);
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, eps, item_ct1,
- nullptr, WARP_SIZE);
+ s_sum_acc_ct1.get_pointer(), WARP_SIZE);
});
});
}
Expand Down Expand Up @@ -231,6 +233,8 @@ static void group_norm_f32_sycl(const float* x, float* dst,
if (group_size < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
+ sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
+ cgh);
const float eps_ct4 = eps;
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
Expand All @@ -239,7 +243,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
- nullptr, WARP_SIZE);
+ s_sum_acc_ct1.get_pointer(), WARP_SIZE);
});
});
}
Expand Down Expand Up @@ -279,13 +283,15 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
+ sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
+ cgh);
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, eps, item_ct1,
- nullptr, WARP_SIZE);
+ s_sum_acc_ct1.get_pointer(), WARP_SIZE);
});
});
}
Expand Down

0 comments on commit 2551774

Please sign in to comment.