diff --git a/src/layer/arm/binaryop_arm.cpp b/src/layer/arm/binaryop_arm.cpp index 55fb165911e..a9f997ef487 100644 --- a/src/layer/arm/binaryop_arm.cpp +++ b/src/layer/arm/binaryop_arm.cpp @@ -285,6 +285,7 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, div_ps(y, x)) MAKE_FUNCTION(binary_op_rpow, (float)powf(y, x), pow_ps(y, x)) MAKE_FUNCTION(binary_op_atan2, (float)atan2f(x, y), atan2_ps(x, y)) MAKE_FUNCTION(binary_op_ratan2, (float)atan2f(y, x), atan2_ps(y, x)) +MAKE_FUNCTION(binary_op_remainder, x - floorf(x / y) * y, remainder_ps(x, y)) // *INDENT-ON* // clang-format on @@ -308,6 +309,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } diff --git a/src/layer/arm/binaryop_arm_asimdhp.cpp b/src/layer/arm/binaryop_arm_asimdhp.cpp index 8c49cadb88c..e0735878e3c 100644 --- a/src/layer/arm/binaryop_arm_asimdhp.cpp +++ b/src/layer/arm/binaryop_arm_asimdhp.cpp @@ -329,6 +329,7 @@ MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, vdiv_f16(y, x), vdivq_f16(y, x)) MAKE_FUNCTION(binary_op_rpow_fp16s, (__fp16)powf(y, x), vcvt_f16_f32(pow_ps(vcvt_f32_f16(y), vcvt_f32_f16(x))), vcombine_f16(vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_low_f16(y)), vcvt_f32_f16(vget_low_f16(x)))), vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_high_f16(y)), vcvt_f32_f16(vget_high_f16(x)))))) MAKE_FUNCTION(binary_op_atan2_fp16s, (__fp16)atan2f(x, y), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(x), vcvt_f32_f16(y))), vcombine_f16(vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_low_f16(y)))), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_high_f16(x)),
vcvt_f32_f16(vget_high_f16(y)))))) MAKE_FUNCTION(binary_op_ratan2_fp16s, (__fp16)atan2f(y, x), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(y), vcvt_f32_f16(x))), vcombine_f16(vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_low_f16(y)), vcvt_f32_f16(vget_low_f16(x)))), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_high_f16(y)), vcvt_f32_f16(vget_high_f16(x)))))) +MAKE_FUNCTION(binary_op_remainder_fp16s, (__fp16)((float)x - floorf((float)x / (float)y) * (float)y), vcvt_f16_f32(remainder_ps(vcvt_f32_f16(x), vcvt_f32_f16(y))), vcombine_f16(vcvt_f16_f32(remainder_ps(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_low_f16(y)))), vcvt_f16_f32(remainder_ps(vcvt_f32_f16(vget_high_f16(x)), vcvt_f32_f16(vget_high_f16(y)))))) // *INDENT-ON* // clang-format on @@ -352,6 +353,7 @@ static void binary_op_vector_fp16s(const __fp16* ptr, const __fp16* ptr1, __fp16 if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } diff --git a/src/layer/arm/neon_mathfun.h b/src/layer/arm/neon_mathfun.h index 2b4094a9ed7..faafa42531a 100644 --- a/src/layer/arm/neon_mathfun.h +++ b/src/layer/arm/neon_mathfun.h @@ -395,5 +395,16 @@ static inline float32x4_t atan2_ps(float32x4_t a, float32x4_t b) return vld1q_f32(tmpx); } +static inline float32x4_t remainder_ps(float32x4_t x, float32x4_t y) +{ + float tmpx[4]; + float tmpy[4]; + vst1q_f32(tmpx, x); + vst1q_f32(tmpy, y); + for (int i = 0; i < 4; i++) + tmpx[i] = tmpx[i] - floorf(tmpx[i] / tmpy[i]) * tmpy[i]; + return vld1q_f32(tmpx); +} + #include "neon_mathfun_tanh.h" #endif // NEON_MATHFUN_H diff --git a/src/layer/binaryop.cpp b/src/layer/binaryop.cpp index 52d3d083b31..e65210d977e 100644 --- a/src/layer/binaryop.cpp
+++ b/src/layer/binaryop.cpp @@ -237,6 +237,17 @@ struct binary_op_ratan2 } }; +struct binary_op_remainder +{ + float operator()(const float& x, const float& y) const + { + const float div_result = x / y; + const float floor_result = floorf(div_result); + const float mul_result = floor_result * y; + return x - mul_result; + } +}; + static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) { if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast(a, b, c, opt); @@ -251,6 +262,7 @@ static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast(b, a, c, opt); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast(a, b, c, opt); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast(b, a, c, opt); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_broadcast(a, b, c, opt); // should never reach here } @@ -269,6 +281,7 @@ static void binary_op_scalar_inplace(Mat& bottom_top_blob, float b, int op_type, if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_scalar_inplace(bottom_top_blob, b, opt); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_scalar_inplace(bottom_top_blob, b, opt); // should never reach here } diff --git a/src/layer/binaryop.h b/src/layer/binaryop.h index 5fc06918d20..f22d970be6c 100644 --- a/src/layer/binaryop.h +++ b/src/layer/binaryop.h @@ -45,7 +45,8 @@ class BinaryOp : public Layer Operation_RDIV = 8, Operation_RPOW = 9, Operation_ATAN2 = 10, - Operation_RATAN2 = 11 + Operation_RATAN2 = 11, + Operation_REMAINDER = 12 }; public: diff --git a/src/layer/loongarch/binaryop_loongarch.cpp b/src/layer/loongarch/binaryop_loongarch.cpp index 
33916d966aa..6e35b86faf6 100644 --- a/src/layer/loongarch/binaryop_loongarch.cpp +++ b/src/layer/loongarch/binaryop_loongarch.cpp @@ -312,6 +312,7 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, __lsx_vfdiv_s(y, x)) MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x)) MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y)) MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x)) +MAKE_FUNCTION(binary_op_remainder, x - floorf(x / y) * y, remainder_ps(x, y)) // *INDENT-ON* // clang-format on @@ -335,6 +336,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } diff --git a/src/layer/loongarch/lsx_mathfun.h b/src/layer/loongarch/lsx_mathfun.h index 194f63bedc3..a09cd6d711a 100644 --- a/src/layer/loongarch/lsx_mathfun.h +++ b/src/layer/loongarch/lsx_mathfun.h @@ -269,4 +269,17 @@ static inline __m128 atan2_ps(__m128 a, __m128 b) return (__m128)__lsx_vld(tmpx, 0); } +static inline __m128 remainder_ps(__m128 x, __m128 y) +{ + float tmpx[4]; + float tmpy[4]; + __lsx_vst(x, tmpx, 0); + __lsx_vst(y, tmpy, 0); + tmpx[0] = tmpx[0] - floorf(tmpx[0] / tmpy[0]) * tmpy[0]; + tmpx[1] = tmpx[1] - floorf(tmpx[1] / tmpy[1]) * tmpy[1]; + tmpx[2] = tmpx[2] - floorf(tmpx[2] / tmpy[2]) * tmpy[2]; + tmpx[3] = tmpx[3] - floorf(tmpx[3] / tmpy[3]) * tmpy[3]; + return (__m128)__lsx_vld(tmpx, 0); +} + #endif // LSX_MATHFUN_H diff --git a/src/layer/mips/binaryop_mips.cpp b/src/layer/mips/binaryop_mips.cpp index 188a0860508..7cc27a0b74f 100644 --- a/src/layer/mips/binaryop_mips.cpp +++ b/src/layer/mips/binaryop_mips.cpp @@ -312,6 +312,7 @@ MAKE_FUNCTION(binary_op_rdiv, y / x,
__msa_fdiv_w(y, x)) MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x)) MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y)) MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x)) +MAKE_FUNCTION(binary_op_remainder, x - floorf(x / y) * y, remainder_ps(x, y)) // *INDENT-ON* // clang-format on @@ -335,6 +336,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } diff --git a/src/layer/mips/msa_mathfun.h b/src/layer/mips/msa_mathfun.h index cab71acbc6b..07b30ea7f85 100644 --- a/src/layer/mips/msa_mathfun.h +++ b/src/layer/mips/msa_mathfun.h @@ -267,4 +267,17 @@ static inline v4f32 atan2_ps(v4f32 a, v4f32 b) return (v4f32)__msa_ld_w(tmpx, 0); } +static inline v4f32 remainder_ps(v4f32 x, v4f32 y) +{ + float tmpx[4]; + float tmpy[4]; + __msa_st_w((v4i32)x, tmpx, 0); + __msa_st_w((v4i32)y, tmpy, 0); + tmpx[0] = tmpx[0] - floorf(tmpx[0] / tmpy[0]) * tmpy[0]; + tmpx[1] = tmpx[1] - floorf(tmpx[1] / tmpy[1]) * tmpy[1]; + tmpx[2] = tmpx[2] - floorf(tmpx[2] / tmpy[2]) * tmpy[2]; + tmpx[3] = tmpx[3] - floorf(tmpx[3] / tmpy[3]) * tmpy[3]; + return (v4f32)__msa_ld_w(tmpx, 0); +} + #endif // MSA_MATHFUN_H diff --git a/src/layer/riscv/binaryop_riscv.cpp b/src/layer/riscv/binaryop_riscv.cpp index da4593197f4..fe7f9145e7b 100644 --- a/src/layer/riscv/binaryop_riscv.cpp +++ b/src/layer/riscv/binaryop_riscv.cpp @@ -293,6 +293,7 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, vfdiv_vv_f32m8(y, x, vl), vfrdiv_vf_f32m8(x MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x, vl), pow_ps(vfmv_v_f_f32m8(y, vl), x, vl), pow_ps(y, vfmv_v_f_f32m8(x, vl), vl))
MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y, vl), atan2_ps(x, vfmv_v_f_f32m8(y, vl), vl), atan2_ps(vfmv_v_f_f32m8(x, vl), y, vl)) MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x, vl), atan2_ps(vfmv_v_f_f32m8(y, vl), x, vl), atan2_ps(y, vfmv_v_f_f32m8(x, vl), vl)) +MAKE_FUNCTION(binary_op_remainder, x - floorf(x / y) * y, remainder_ps(x, y, vl), remainder_ps(x, vfmv_v_f_f32m8(y, vl), vl), remainder_ps(vfmv_v_f_f32m8(x, vl), y, vl)) // *INDENT-ON* // clang-format on @@ -316,6 +317,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } @@ -887,6 +889,7 @@ MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, vfdiv_vv_f16m8(y, x, vl), vfrdiv_vf_f MAKE_FUNCTION(binary_op_rpow_fp16s, (__fp16)pow((float)y, (float)x), pow_ps(y, x, vl), pow_ps(vfmv_v_f_f16m8(y, vl), x, vl), pow_ps(y, vfmv_v_f_f16m8(x, vl), vl)) MAKE_FUNCTION(binary_op_atan2_fp16s, (__fp16)atan2((float)x, (float)y), atan2_ps(x, y, vl), atan2_ps(x, vfmv_v_f_f16m8(y, vl), vl), atan2_ps(vfmv_v_f_f16m8(x, vl), y, vl)) MAKE_FUNCTION(binary_op_ratan2_fp16s, (__fp16)atan2((float)y, (float)x), atan2_ps(y, x, vl), atan2_ps(vfmv_v_f_f16m8(y, vl), x, vl), atan2_ps(y, vfmv_v_f_f16m8(x, vl), vl)) +MAKE_FUNCTION(binary_op_remainder_fp16s, (__fp16)((float)x - floorf((float)x / (float)y) * (float)y), remainder_ps(x, y, vl), remainder_ps(x, vfmv_v_f_f16m8(y, vl), vl), remainder_ps(vfmv_v_f_f16m8(x, vl), y, vl)) // *INDENT-ON* // clang-format on @@ -910,6 +913,7 @@ static void binary_op_vector_fp16s(const __fp16* ptr, const __fp16* ptr1, __fp16 if
(op_type == BinaryOp::Operation_RPOW) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector_fp16s(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } diff --git a/src/layer/riscv/rvv_mathfun.h b/src/layer/riscv/rvv_mathfun.h index 34f072788e5..1f54e1ee715 100644 --- a/src/layer/riscv/rvv_mathfun.h +++ b/src/layer/riscv/rvv_mathfun.h @@ -580,4 +580,23 @@ _RVV_FLOAT32_ATAN2_OP(2, 16) _RVV_FLOAT32_ATAN2_OP(4, 8) _RVV_FLOAT32_ATAN2_OP(8, 4) +#define _RVV_FLOAT32_REMAINDER_OP(LMUL, MLEN) \ + static inline vfloat32m##LMUL##_t remainder_ps(vfloat32m##LMUL##_t x, vfloat32m##LMUL##_t y, size_t vl) \ + { \ + std::vector<float> tmpx(vl); \ + std::vector<float> tmpy(vl); \ + vse32_v_f32m##LMUL(tmpx.data(), x, vl); \ + vse32_v_f32m##LMUL(tmpy.data(), y, vl); \ + for (size_t i = 0; i < vl; i++) \ + { \ + tmpx[i] = tmpx[i] - floorf(tmpx[i] / tmpy[i]) * tmpy[i]; \ + } \ + return vle32_v_f32m##LMUL(tmpx.data(), vl); \ + } + +_RVV_FLOAT32_REMAINDER_OP(1, 32) +_RVV_FLOAT32_REMAINDER_OP(2, 16) +_RVV_FLOAT32_REMAINDER_OP(4, 8) +_RVV_FLOAT32_REMAINDER_OP(8, 4) + #endif // RVV_MATHFUN_H diff --git a/src/layer/riscv/rvv_mathfun_fp16s.h b/src/layer/riscv/rvv_mathfun_fp16s.h index ee5ffe4a304..5838e3c4c21 100644 --- a/src/layer/riscv/rvv_mathfun_fp16s.h +++ b/src/layer/riscv/rvv_mathfun_fp16s.h @@ -416,4 +416,23 @@ _RVV_FLOAT16_ATAN2_OP(2, 16) _RVV_FLOAT16_ATAN2_OP(4, 8) _RVV_FLOAT16_ATAN2_OP(8, 4) +#define _RVV_FLOAT16_REMAINDER_OP(LMUL, MLEN) \ + static inline vfloat16m##LMUL##_t remainder_ps(vfloat16m##LMUL##_t x, vfloat16m##LMUL##_t y, size_t vl) \ + { \ + std::vector<__fp16> tmpx(vl); \ + std::vector<__fp16> tmpy(vl); \ + vse16_v_f16m##LMUL(tmpx.data(), x, vl); \ + vse16_v_f16m##LMUL(tmpy.data(), y, vl); \ +
for (size_t i = 0; i < vl; i++) \ + { \ + tmpx[i] = (__fp16)((float)tmpx[i] - floorf((float)tmpx[i] / (float)tmpy[i]) * (float)tmpy[i]); \ + } \ + return vle16_v_f16m##LMUL(tmpx.data(), vl); \ + } + +_RVV_FLOAT16_REMAINDER_OP(1, 32) +_RVV_FLOAT16_REMAINDER_OP(2, 16) +_RVV_FLOAT16_REMAINDER_OP(4, 8) +_RVV_FLOAT16_REMAINDER_OP(8, 4) + #endif // RVV_MATHFUN_FP16S_H diff --git a/src/layer/vulkan/shader/binaryop.comp b/src/layer/vulkan/shader/binaryop.comp index 18f566a2a72..a6632e0ee1d 100644 --- a/src/layer/vulkan/shader/binaryop.comp +++ b/src/layer/vulkan/shader/binaryop.comp @@ -137,6 +137,7 @@ void main() if (op_type == 10) res = atan(v1, v2); if (op_type == 11) res = atan(v2, v1); #endif + if (op_type == 12) res = v1 - floor(v1 / v2) * v2; #if NCNN_image_shader image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); diff --git a/src/layer/vulkan/shader/binaryop_broadcast.comp b/src/layer/vulkan/shader/binaryop_broadcast.comp index 732e3f50b0a..93edcb886ab 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast.comp @@ -199,6 +199,7 @@ void main() if (op_type == 10) res = atan(v1, v2); if (op_type == 11) res = atan(v2, v1); #endif + if (op_type == 12) res = v1 - floor(v1 / v2) * v2; #if NCNN_image_shader image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp b/src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp index ced3933db4a..64cd537b71f 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp @@ -130,6 +130,7 @@ void main() if (op_type == 10) res = atan(v1, v2); if (op_type == 11) res = atan(v2, v1); #endif + if (op_type == 12) res = v1 - floor(v1 / v2) * v2; #if NCNN_image_shader image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack1to8.comp b/src/layer/vulkan/shader/binaryop_broadcast_pack1to8.comp index 963f9c0030c..f497e0726b2 100644 ---
a/src/layer/vulkan/shader/binaryop_broadcast_pack1to8.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_pack1to8.comp @@ -187,6 +187,11 @@ void main() res[1] = atan(v2[1], v1[1]); #endif } + if (op_type == 12) + { + res[0] = v1[0] - floor(v1[0] / v2[0]) * v2[0]; + res[1] = v1[1] - floor(v1[1] / v2[1]) * v2[1]; + } #if NCNN_image_shader image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp b/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp index a0f0376b09e..4a9116dfc4f 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp @@ -199,6 +199,7 @@ void main() if (op_type == 10) res = atan(v1, v2); if (op_type == 11) res = atan(v2, v1); #endif + if (op_type == 12) res = v1 - floor(v1 / v2) * v2; #if NCNN_image_shader image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp b/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp index b9e7d492bb9..1b6ca17bcbb 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp @@ -253,6 +253,11 @@ void main() res[1] = atan(v2[1], v1[1]); #endif } + if (op_type == 12) + { + res[0] = v1[0] - floor(v1[0] / v2[0]) * v2[0]; + res[1] = v1[1] - floor(v1[1] / v2[1]) * v2[1]; + } #if NCNN_image_shader image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); diff --git a/src/layer/vulkan/shader/binaryop_pack4.comp b/src/layer/vulkan/shader/binaryop_pack4.comp index 0189253fb3d..f86d53ac8b7 100644 --- a/src/layer/vulkan/shader/binaryop_pack4.comp +++ b/src/layer/vulkan/shader/binaryop_pack4.comp @@ -128,6 +128,7 @@ void main() if (op_type == 10) res = atan(v1, v2); if (op_type == 11) res = atan(v2, v1); #endif + if (op_type == 12) res = v1 - floor(v1 / v2) * v2; #if NCNN_image_shader image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); diff --git
a/src/layer/vulkan/shader/binaryop_pack8.comp b/src/layer/vulkan/shader/binaryop_pack8.comp index 9fe54902bd5..1be2bd5fc18 100644 --- a/src/layer/vulkan/shader/binaryop_pack8.comp +++ b/src/layer/vulkan/shader/binaryop_pack8.comp @@ -183,6 +183,11 @@ void main() res[1] = atan(v2[1], v1[1]); #endif } + if (op_type == 12) + { + res[0] = v1[0] - floor(v1[0] / v2[0]) * v2[0]; + res[1] = v1[1] - floor(v1[1] / v2[1]) * v2[1]; + } #if NCNN_image_shader image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); diff --git a/src/layer/x86/avx512_mathfun.h b/src/layer/x86/avx512_mathfun.h index b5e47bdbe68..4bd8d074202 100644 --- a/src/layer/x86/avx512_mathfun.h +++ b/src/layer/x86/avx512_mathfun.h @@ -856,4 +856,12 @@ static NCNN_FORCEINLINE __m512 abs512_ps(__m512 x) return _mm512_andnot_ps(magic_negative_zero, x); } +static NCNN_FORCEINLINE __m512 remainder512_ps(__m512 x, __m512 y) +{ + const __m512 div_result = _mm512_div_ps(x, y); + const __m512 floor_result = _mm512_floor_ps(div_result); + const __m512 mul_result = _mm512_mul_ps(y, floor_result); + return _mm512_sub_ps(x, mul_result); +} + #endif // AVX512_MATHFUN_H diff --git a/src/layer/x86/avx_mathfun.h b/src/layer/x86/avx_mathfun.h index 65c34efc23e..d3708f7fb1b 100644 --- a/src/layer/x86/avx_mathfun.h +++ b/src/layer/x86/avx_mathfun.h @@ -1087,4 +1087,12 @@ static NCNN_FORCEINLINE __m256 abs256_ps(__m256 x) return _mm256_andnot_ps(magic_negative_zero, x); } +static NCNN_FORCEINLINE __m256 remainder256_ps(__m256 x, __m256 y) +{ + const __m256 div_result = _mm256_div_ps(x, y); + const __m256 floor_result = _mm256_floor_ps(div_result); + const __m256 mul_result = _mm256_mul_ps(y, floor_result); + return _mm256_sub_ps(x, mul_result); +} + #endif // AVX_MATHFUN_H diff --git a/src/layer/x86/binaryop_x86.cpp b/src/layer/x86/binaryop_x86.cpp index 14ad9d5f638..ab2587326c4 100644 --- a/src/layer/x86/binaryop_x86.cpp +++ b/src/layer/x86/binaryop_x86.cpp @@ -789,6 +789,35 @@ struct binary_op_ratan2 #endif // __SSE2__ }; +struct
binary_op_remainder +{ + float func(const float& x, const float& y) const + { + const float div_result = x / y; + const float floor_result = floorf(div_result); + const float mul_result = floor_result * y; + return x - mul_result; + } +#if __SSE2__ + __m128 func_pack4(const __m128& x, const __m128& y) const + { + return remainder_ps(x, y); + } +#if __AVX__ + __m256 func_pack8(const __m256& x, const __m256& y) const + { + return remainder256_ps(x, y); + } +#if __AVX512F__ + __m512 func_pack16(const __m512& x, const __m512& y) const + { + return remainder512_ps(x, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + } // namespace BinaryOp_x86_functor static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, int aw, int bw, int ap, int bp, int op_type) @@ -807,6 +836,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); + if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector(ptr, ptr1, outptr, aw, bw, ap, bp); // should never reach here } diff --git a/src/layer/x86/sse_mathfun.h b/src/layer/x86/sse_mathfun.h index b7cecfb8123..5ca090ef206 100644 --- a/src/layer/x86/sse_mathfun.h +++ b/src/layer/x86/sse_mathfun.h @@ -1157,4 +1157,17 @@ static NCNN_FORCEINLINE __m128 abs_ps(__m128 inputs) return _mm_andnot_ps(magic_negative_zero, inputs); } +static NCNN_FORCEINLINE __m128 remainder_ps(__m128 x, __m128 y) +{ + const __m128 div_result = _mm_div_ps(x, y); + // Need SSE4.1 + // const __m128 floor_result = _mm_floor_ps(div_result); + const __m128 trunc_result = _mm_cvtepi32_ps(_mm_cvttps_epi32(div_result)); + const __m128 cmp = _mm_cmplt_ps(div_result, trunc_result); + const __m128 
one = _mm_set1_ps(1.0f); + const __m128 floor_result = _mm_sub_ps(trunc_result, _mm_and_ps(cmp, one)); + const __m128 mul_result = _mm_mul_ps(y, floor_result); + return _mm_sub_ps(x, mul_result); +} + #endif // SSE_MATHFUN_H diff --git a/tests/test_binaryop.cpp b/tests/test_binaryop.cpp index 89f953eaccb..22b3ab07bf3 100644 --- a/tests/test_binaryop.cpp +++ b/tests/test_binaryop.cpp @@ -15,7 +15,7 @@ #include "layer/binaryop.h" #include "testutil.h" -#define OP_TYPE_MAX 12 +#define OP_TYPE_MAX 13 static int op_type = 0; diff --git a/tests/test_binaryop_1.cpp b/tests/test_binaryop_1.cpp index d6b20ede1a8..3fa361c686e 100644 --- a/tests/test_binaryop_1.cpp +++ b/tests/test_binaryop_1.cpp @@ -15,7 +15,7 @@ #include "layer/binaryop.h" #include "testutil.h" -#define OP_TYPE_MAX 12 +#define OP_TYPE_MAX 13 static int op_type = 0; diff --git a/tests/test_binaryop_2.cpp b/tests/test_binaryop_2.cpp index 14c5e7d3dac..5adb9443b31 100644 --- a/tests/test_binaryop_2.cpp +++ b/tests/test_binaryop_2.cpp @@ -15,7 +15,7 @@ #include "layer/binaryop.h" #include "testutil.h" -#define OP_TYPE_MAX 12 +#define OP_TYPE_MAX 13 static int op_type = 0; diff --git a/tests/test_binaryop_3.cpp b/tests/test_binaryop_3.cpp index 655c2a3ce91..1d557a2002c 100644 --- a/tests/test_binaryop_3.cpp +++ b/tests/test_binaryop_3.cpp @@ -15,7 +15,7 @@ #include "layer/binaryop.h" #include "testutil.h" -#define OP_TYPE_MAX 12 +#define OP_TYPE_MAX 13 static int op_type = 0; @@ -55,6 +55,16 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b, int flag) b[i] = 0.001f; } } + if (op_type == 12) + { + // divisor must be non-zero for remainder + b = b.clone(); + for (int i = 0; i < b.total(); i++) + { + if (b[i] == 0.f) + b[i] = 0.001f; + } + } ncnn::ParamDict pd; pd.set(0, op_type); diff --git a/tools/pnnx/src/pass_ncnn/expand_expression.cpp b/tools/pnnx/src/pass_ncnn/expand_expression.cpp index f8f97baa55c..16031e3be19 100644 --- a/tools/pnnx/src/pass_ncnn/expand_expression.cpp +++ 
b/tools/pnnx/src/pass_ncnn/expand_expression.cpp @@ -215,7 +215,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx if (t == "min" || t == "minimum") op_binary->params["0"] = 5; if (t == "floor_divide") fprintf(stderr, "BinaryOp floor_divide not supported yet\n"); // TODO if (t == "fmod") fprintf(stderr, "BinaryOp fmod not supported yet\n"); // TODO - if (t == "remainder") fprintf(stderr, "BinaryOp remainder not supported yet\n"); // TODO + if (t == "remainder") op_binary->params["0"] = 12; if (t == "pow") op_binary->params["0"] = 6; if (t == "atan2") op_binary->params["0"] = 10; diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index 04dcbeed63f..5139c9e47fd 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -186,6 +186,7 @@ pnnx_ncnn_add_test(torch_minimum) pnnx_ncnn_add_test(torch_neg) pnnx_ncnn_add_test(torch_pow) pnnx_ncnn_add_test(torch_reciprocal) +pnnx_ncnn_add_test(torch_remainder) pnnx_ncnn_add_test(torch_round) pnnx_ncnn_add_test(torch_rsqrt) pnnx_ncnn_add_test(torch_sin) diff --git a/tools/pnnx/tests/ncnn/test_torch_remainder.py b/tools/pnnx/tests/ncnn/test_torch_remainder.py new file mode 100644 index 00000000000..ffa06dba25f --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_remainder.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + out0 = torch.remainder(x, y) + out1 = torch.remainder(y, y) + out2 = torch.remainder(z, torch.ones_like(z) + 0.5) + return out0, out1, out2 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 16) + y = torch.rand(3, 16) + z = torch.rand(5, 9, 3) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_remainder.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_remainder.pt inputshape=[3,16],[3,16],[5,9,3]") + + # ncnn inference + import test_torch_remainder_ncnn + b = test_torch_remainder_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)