diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 6de8a9d611362..ee632ac3b69c6 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -190,13 +190,15 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols, if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor s_sum_acc_ct1( + sycl::range<1>(32), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { norm_f32(x, dst, ncols, eps, item_ct1, - nullptr, WARP_SIZE); + s_sum_acc_ct1.get_pointer(), WARP_SIZE); }); }); } @@ -231,6 +233,8 @@ static void group_norm_f32_sycl(const float* x, float* dst, if (group_size < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(32), + cgh); const float eps_ct4 = eps; cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, @@ -239,7 +243,7 @@ static void group_norm_f32_sycl(const float* x, float* dst, [[intel::reqd_sub_group_size(WARP_SIZE)]] { group_norm_f32( x, dst, group_size, ne_elements, eps_ct4, item_ct1, - nullptr, WARP_SIZE); + s_sum_acc_ct1.get_pointer(), WARP_SIZE); }); }); } @@ -279,13 +283,15 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(32), + cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { rms_norm_f32(x, dst, ncols, eps, item_ct1, - nullptr, WARP_SIZE); + s_sum_acc_ct1.get_pointer(), WARP_SIZE); }); }); }